diff --git a/.gitignore b/.gitignore index 3624d12269612..debad77ec2ad3 100644 --- a/.gitignore +++ b/.gitignore @@ -66,6 +66,7 @@ scalastyle-output.xml R-unit-tests.log R/unit-tests.out python/lib/pyspark.zip +lint-r-report.log # For Hive metastore_db/ diff --git a/.rat-excludes b/.rat-excludes index 994c7e86f8a91..0240e81c45ea2 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -28,6 +28,7 @@ spark-env.sh spark-env.cmd spark-env.sh.template log4j-defaults.properties +log4j-defaults-repl.properties bootstrap-tooltip.js jquery-1.11.1.min.js d3.min.js @@ -85,3 +86,8 @@ local-1430917381535_2 DESCRIPTION NAMESPACE test_support/* +.*Rd +help/* +html/* +INDEX +.lintr diff --git a/LICENSE b/LICENSE index d0cd0dcb4bdb7..f9e412cade345 100644 --- a/LICENSE +++ b/LICENSE @@ -948,5 +948,6 @@ The following components are provided under the MIT License. See project link fo (MIT License) SLF4J LOG4J-12 Binding (org.slf4j:slf4j-log4j12:1.7.5 - http://www.slf4j.org) (MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/) (MIT License) scopt (com.github.scopt:scopt_2.10:3.2.0 - https://github.com/scopt/scopt) - (The MIT License) Mockito (org.mockito:mockito-all:1.8.5 - http://www.mockito.org) + (The MIT License) Mockito (org.mockito:mockito-core:1.9.5 - http://www.mockito.org) (MIT License) jquery (https://jquery.org/license/) + (MIT License) AnchorJS (https://github.com/bryanbraun/anchorjs) diff --git a/R/log4j.properties b/R/log4j.properties index 701adb2a3da1d..cce8d9152d32d 100644 --- a/R/log4j.properties +++ b/R/log4j.properties @@ -19,7 +19,7 @@ log4j.rootCategory=INFO, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=true -log4j.appender.file.file=R-unit-tests.log +log4j.appender.file.file=R/target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n diff --git a/R/pkg/.lintr b/R/pkg/.lintr new file mode 100644 index 0000000000000..038236fc149e6 --- /dev/null +++ b/R/pkg/.lintr @@ -0,0 +1,2 @@ +linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) +exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R") diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index f9447f6c3288d..7f857222452d4 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -10,6 +10,11 @@ export("sparkR.init") export("sparkR.stop") export("print.jobj") +# Job group lifecycle management methods +export("setJobGroup", + "clearJobGroup", + "cancelJobGroup") + exportClasses("DataFrame") exportMethods("arrange", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 0af5cb8881e35..60702824acb46 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -38,7 +38,7 @@ setClass("DataFrame", setMethod("initialize", "DataFrame", function(.Object, sdf, isCached) { .Object@env <- new.env() .Object@env$isCached <- isCached - + .Object@sdf <- sdf .Object }) @@ -55,11 +55,11 @@ dataFrame <- function(sdf, isCached = FALSE) { ############################ DataFrame Methods ############################################## #' Print Schema of a DataFrame -#' +#' #' Prints out the schema in tree format -#' +#' #' @param x A SparkSQL DataFrame -#' +#' #' @rdname printSchema #' @export #' @examples @@ -78,11 +78,11 @@ setMethod("printSchema", }) #' Get schema object -#' +#' #' Returns the schema of this DataFrame as a structType object. -#' +#' #' @param x A SparkSQL DataFrame -#' +#' #' @rdname schema #' @export #' @examples @@ -100,9 +100,9 @@ setMethod("schema", }) #' Explain -#' +#' #' Print the logical and physical Catalyst plans to the console for debugging. -#' +#' #' @param x A SparkSQL DataFrame #' @param extended Logical. If extended is False, explain() only prints the physical plan. #' @rdname explain @@ -169,8 +169,8 @@ setMethod("isLocal", #'} setMethod("showDF", signature(x = "DataFrame"), - function(x, numRows = 20) { - s <- callJMethod(x@sdf, "showString", numToInt(numRows)) + function(x, numRows = 20, truncate = TRUE) { + s <- callJMethod(x@sdf, "showString", numToInt(numRows), truncate) cat(s) }) @@ -200,11 +200,11 @@ setMethod("show", "DataFrame", }) #' DataTypes -#' +#' #' Return all column names and their data types as a list -#' +#' #' @param x A SparkSQL DataFrame -#' +#' #' @rdname dtypes #' @export #' @examples @@ -224,11 +224,11 @@ setMethod("dtypes", }) #' Column names -#' +#' #' Return all column names as a list -#' +#' #' @param x A SparkSQL DataFrame -#' +#' #' @rdname columns #' @export #' @examples @@ -256,12 +256,12 @@ setMethod("names", }) #' Register Temporary Table -#' +#' #' Registers a DataFrame as a Temporary Table in the SQLContext -#' +#' #' @param x A SparkSQL DataFrame #' @param tableName A character vector containing the name of the table -#' +#' #' @rdname registerTempTable #' @export #' @examples @@ -306,11 +306,11 @@ setMethod("insertInto", }) #' Cache -#' +#' #' Persist with the default storage level (MEMORY_ONLY). -#' +#' #' @param x A SparkSQL DataFrame -#' +#' #' @rdname cache-methods #' @export #' @examples @@ -400,7 +400,7 @@ setMethod("repartition", signature(x = "DataFrame", numPartitions = "numeric"), function(x, numPartitions) { sdf <- callJMethod(x@sdf, "repartition", numToInt(numPartitions)) - dataFrame(sdf) + dataFrame(sdf) }) # toJSON @@ -489,7 +489,7 @@ setMethod("distinct", #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" #' df <- jsonFile(sqlContext, path) -#' collect(sample(df, FALSE, 0.5)) +#' collect(sample(df, FALSE, 0.5)) #' collect(sample(df, TRUE, 0.5)) #'} setMethod("sample", @@ -513,11 +513,11 @@ setMethod("sample_frac", }) #' Count -#' +#' #' Returns the number of rows in a DataFrame -#' +#' #' @param x A SparkSQL DataFrame -#' +#' #' @rdname count #' @export #' @examples @@ -568,13 +568,13 @@ setMethod("collect", }) #' Limit -#' +#' #' Limit the resulting DataFrame to the number of rows specified. -#' +#' #' @param x A SparkSQL DataFrame #' @param num The number of rows to return #' @return A new DataFrame containing the number of rows specified. -#' +#' #' @rdname limit #' @export #' @examples @@ -593,7 +593,7 @@ setMethod("limit", }) #' Take the first NUM rows of a DataFrame and return a the results as a data.frame -#' +#' #' @rdname take #' @export #' @examples @@ -613,8 +613,8 @@ setMethod("take", #' Head #' -#' Return the first NUM rows of a DataFrame as a data.frame. If NUM is NULL, -#' then head() returns the first 6 rows in keeping with the current data.frame +#' Return the first NUM rows of a DataFrame as a data.frame. If NUM is NULL, +#' then head() returns the first 6 rows in keeping with the current data.frame #' convention in R. #' #' @param x A SparkSQL DataFrame @@ -659,11 +659,11 @@ setMethod("first", }) # toRDD() -# +# # Converts a Spark DataFrame to an RDD while preserving column names. -# +# # @param x A Spark DataFrame -# +# # @rdname DataFrame # @export # @examples @@ -1167,7 +1167,7 @@ setMethod("where", #' #' @param x A Spark DataFrame #' @param y A Spark DataFrame -#' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a +#' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a #' Column expression. If joinExpr is omitted, join() wil perform a Cartesian join #' @param joinType The type of join to perform. The following join types are available: #' 'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'. The default joinType is "inner". @@ -1303,7 +1303,7 @@ setMethod("except", #' @param source A name for external data source #' @param mode One of 'append', 'overwrite', 'error', 'ignore' #' -#' @rdname write.df +#' @rdname write.df #' @export #' @examples #'\dontrun{ @@ -1401,7 +1401,7 @@ setMethod("saveAsTable", #' @param col A string of name #' @param ... Additional expressions #' @return A DataFrame -#' @rdname describe +#' @rdname describe #' @export #' @examples #'\dontrun{ @@ -1444,7 +1444,7 @@ setMethod("describe", #' This overwrites the how parameter. #' @param cols Optional list of column names to consider. #' @return A DataFrame -#' +#' #' @rdname nafunctions #' @export #' @examples @@ -1465,7 +1465,7 @@ setMethod("dropna", if (is.null(minNonNulls)) { minNonNulls <- if (how == "any") { length(cols) } else { 1 } } - + naFunctions <- callJMethod(x@sdf, "na") sdf <- callJMethod(naFunctions, "drop", as.integer(minNonNulls), listToSeq(as.list(cols))) @@ -1488,16 +1488,16 @@ setMethod("na.omit", #' @param value Value to replace null values with. #' Should be an integer, numeric, character or named list. #' If the value is a named list, then cols is ignored and -#' value must be a mapping from column name (character) to +#' value must be a mapping from column name (character) to #' replacement value. The replacement value must be an #' integer, numeric or character. #' @param cols optional list of column names to consider. #' Columns specified in cols that do not have matching data -#' type are ignored. For example, if value is a character, and +#' type are ignored. For example, if value is a character, and #' subset contains a non-character column, then the non-character #' column is simply ignored. #' @return A DataFrame -#' +#' #' @rdname nafunctions #' @export #' @examples @@ -1515,14 +1515,14 @@ setMethod("fillna", if (!(class(value) %in% c("integer", "numeric", "character", "list"))) { stop("value should be an integer, numeric, charactor or named list.") } - + if (class(value) == "list") { # Check column names in the named list colNames <- names(value) if (length(colNames) == 0 || !all(colNames != "")) { stop("value should be an a named list with each name being a column name.") } - + # Convert to the named list to an environment to be passed to JVM valueMap <- new.env() for (col in colNames) { @@ -1533,19 +1533,19 @@ setMethod("fillna", } valueMap[[col]] <- v } - + # When value is a named list, caller is expected not to pass in cols if (!is.null(cols)) { warning("When value is a named list, cols is ignored!") cols <- NULL } - + value <- valueMap } else if (is.integer(value)) { # Cast an integer to a numeric value <- as.numeric(value) } - + naFunctions <- callJMethod(x@sdf, "na") sdf <- if (length(cols) == 0) { callJMethod(naFunctions, "fill", value) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 0513299515644..89511141d3ef7 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -48,7 +48,7 @@ setMethod("initialize", "RDD", function(.Object, jrdd, serializedMode, # byte: The RDD stores data serialized in R. # string: The RDD stores data as strings. # row: The RDD stores the serialized rows of a DataFrame. - + # We use an environment to store mutable states inside an RDD object. # Note that R's call-by-value semantics makes modifying slots inside an # object (passed as an argument into a function, such as cache()) difficult: @@ -363,7 +363,7 @@ setMethod("collectPartition", # @description # \code{collectAsMap} returns a named list as a map that contains all of the elements -# in a key-value pair RDD. +# in a key-value pair RDD. # @examples #\dontrun{ # sc <- sparkR.init() @@ -666,7 +666,7 @@ setMethod("minimum", # rdd <- parallelize(sc, 1:10) # sumRDD(rdd) # 55 #} -# @rdname sumRDD +# @rdname sumRDD # @aliases sumRDD,RDD setMethod("sumRDD", signature(x = "RDD"), @@ -1090,11 +1090,11 @@ setMethod("sortBy", # Return: # A list of the first N elements from the RDD in the specified order. # -takeOrderedElem <- function(x, num, ascending = TRUE) { +takeOrderedElem <- function(x, num, ascending = TRUE) { if (num <= 0L) { return(list()) } - + partitionFunc <- function(part) { if (num < length(part)) { # R limitation: order works only on primitive types! @@ -1152,7 +1152,7 @@ takeOrderedElem <- function(x, num, ascending = TRUE) { # @aliases takeOrdered,RDD,RDD-method setMethod("takeOrdered", signature(x = "RDD", num = "integer"), - function(x, num) { + function(x, num) { takeOrderedElem(x, num) }) @@ -1173,7 +1173,7 @@ setMethod("takeOrdered", # @aliases top,RDD,RDD-method setMethod("top", signature(x = "RDD", num = "integer"), - function(x, num) { + function(x, num) { takeOrderedElem(x, num, FALSE) }) @@ -1181,7 +1181,7 @@ setMethod("top", # # Aggregate the elements of each partition, and then the results for all the # partitions, using a given associative function and a neutral "zero value". -# +# # @param x An RDD. # @param zeroValue A neutral "zero value". # @param op An associative function for the folding operation. @@ -1207,7 +1207,7 @@ setMethod("fold", # # Aggregate the elements of each partition, and then the results for all the # partitions, using given combine functions and a neutral "zero value". -# +# # @param x An RDD. # @param zeroValue A neutral "zero value". # @param seqOp A function to aggregate the RDD elements. It may return a different @@ -1230,11 +1230,11 @@ setMethod("fold", # @aliases aggregateRDD,RDD,RDD-method setMethod("aggregateRDD", signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", combOp = "ANY"), - function(x, zeroValue, seqOp, combOp) { + function(x, zeroValue, seqOp, combOp) { partitionFunc <- function(part) { Reduce(seqOp, part, zeroValue) } - + partitionList <- collect(lapplyPartition(x, partitionFunc), flatten = FALSE) Reduce(combOp, partitionList, zeroValue) @@ -1330,7 +1330,7 @@ setMethod("setName", #\dontrun{ # sc <- sparkR.init() # rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) -# collect(zipWithUniqueId(rdd)) +# collect(zipWithUniqueId(rdd)) # # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) #} # @rdname zipWithUniqueId @@ -1426,7 +1426,7 @@ setMethod("glom", partitionFunc <- function(part) { list(part) } - + lapplyPartition(x, partitionFunc) }) @@ -1498,16 +1498,16 @@ setMethod("zipRDD", # The jrdd's elements are of scala Tuple2 type. The serialized # flag here is used for the elements inside the tuples. rdd <- RDD(jrdd, getSerializedMode(rdds[[1]])) - + mergePartitions(rdd, TRUE) }) # Cartesian product of this RDD and another one. # -# Return the Cartesian product of this RDD and another one, -# that is, the RDD of all pairs of elements (a, b) where a +# Return the Cartesian product of this RDD and another one, +# that is, the RDD of all pairs of elements (a, b) where a # is in this and b is in other. -# +# # @param x An RDD. # @param other An RDD. # @return A new RDD which is the Cartesian product of these two RDDs. @@ -1515,7 +1515,7 @@ setMethod("zipRDD", #\dontrun{ # sc <- sparkR.init() # rdd <- parallelize(sc, 1:2) -# sortByKey(cartesian(rdd, rdd)) +# sortByKey(cartesian(rdd, rdd)) # # list(list(1, 1), list(1, 2), list(2, 1), list(2, 2)) #} # @rdname cartesian @@ -1528,7 +1528,7 @@ setMethod("cartesian", # The jrdd's elements are of scala Tuple2 type. The serialized # flag here is used for the elements inside the tuples. rdd <- RDD(jrdd, getSerializedMode(rdds[[1]])) - + mergePartitions(rdd, FALSE) }) @@ -1598,11 +1598,11 @@ setMethod("intersection", # Zips an RDD's partitions with one (or more) RDD(s). # Same as zipPartitions in Spark. -# +# # @param ... RDDs to be zipped. # @param func A function to transform zipped partitions. -# @return A new RDD by applying a function to the zipped partitions. -# Assumes that all the RDDs have the *same number of partitions*, but +# @return A new RDD by applying a function to the zipped partitions. +# Assumes that all the RDDs have the *same number of partitions*, but # does *not* require them to have the same number of elements in each partition. # @examples #\dontrun{ @@ -1610,7 +1610,7 @@ setMethod("intersection", # rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 # rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 # rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 -# collect(zipPartitions(rdd1, rdd2, rdd3, +# collect(zipPartitions(rdd1, rdd2, rdd3, # func = function(x, y, z) { list(list(x, y, z))} )) # # list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))) #} @@ -1627,7 +1627,7 @@ setMethod("zipPartitions", if (length(unique(nPart)) != 1) { stop("Can only zipPartitions RDDs which have the same number of partitions.") } - + rrdds <- lapply(rrdds, function(rdd) { mapPartitionsWithIndex(rdd, function(partIndex, part) { print(length(part)) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 22a4b5bf86ebd..9a743a3411533 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -182,7 +182,7 @@ setMethod("toDF", signature(x = "RDD"), #' Create a DataFrame from a JSON file. #' -#' Loads a JSON file (one object per line), returning the result as a DataFrame +#' Loads a JSON file (one object per line), returning the result as a DataFrame #' It goes through the entire dataset once to determine the schema. #' #' @param sqlContext SQLContext to use @@ -238,7 +238,7 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { #' Create a DataFrame from a Parquet file. -#' +#' #' Loads a Parquet file, returning the result as a DataFrame. #' #' @param sqlContext SQLContext to use @@ -278,7 +278,7 @@ sql <- function(sqlContext, sqlQuery) { } #' Create a DataFrame from a SparkSQL Table -#' +#' #' Returns the specified Table as a DataFrame. The Table must have already been registered #' in the SQLContext. #' @@ -298,7 +298,7 @@ sql <- function(sqlContext, sqlQuery) { table <- function(sqlContext, tableName) { sdf <- callJMethod(sqlContext, "table", tableName) - dataFrame(sdf) + dataFrame(sdf) } @@ -352,7 +352,7 @@ tableNames <- function(sqlContext, databaseName = NULL) { #' Cache Table -#' +#' #' Caches the specified table in-memory. #' #' @param sqlContext SQLContext to use @@ -370,11 +370,11 @@ tableNames <- function(sqlContext, databaseName = NULL) { #' } cacheTable <- function(sqlContext, tableName) { - callJMethod(sqlContext, "cacheTable", tableName) + callJMethod(sqlContext, "cacheTable", tableName) } #' Uncache Table -#' +#' #' Removes the specified table from the in-memory cache. #' #' @param sqlContext SQLContext to use diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R index 23dc38780716e..2403925b267c8 100644 --- a/R/pkg/R/broadcast.R +++ b/R/pkg/R/broadcast.R @@ -27,9 +27,9 @@ # @description Broadcast variables can be created using the broadcast # function from a \code{SparkContext}. # @rdname broadcast-class -# @seealso broadcast +# @seealso broadcast # -# @param id Id of the backing Spark broadcast variable +# @param id Id of the backing Spark broadcast variable # @export setClass("Broadcast", slots = list(id = "character")) @@ -68,7 +68,7 @@ setMethod("value", # variable on workers. Not intended for use outside the package. # # @rdname broadcast-internal -# @seealso broadcast, value +# @seealso broadcast, value # @param bcastId The id of broadcast variable to set # @param value The value to be set diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 1281c41213e32..78c7a3037ffac 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -34,24 +34,36 @@ connectBackend <- function(hostname, port, timeout = 6000) { con } -launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts) { +determineSparkSubmitBin <- function() { if (.Platform$OS.type == "unix") { sparkSubmitBinName = "spark-submit" } else { sparkSubmitBinName = "spark-submit.cmd" } + sparkSubmitBinName +} + +generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { + if (jars != "") { + jars <- paste("--jars", jars) + } + + if (packages != "") { + packages <- paste("--packages", packages) + } + combinedArgs <- paste(jars, packages, sparkSubmitOpts, args, sep = " ") + combinedArgs +} + +launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { + sparkSubmitBinName <- determineSparkSubmitBin() if (sparkHome != "") { sparkSubmitBin <- file.path(sparkHome, "bin", sparkSubmitBinName) } else { sparkSubmitBin <- sparkSubmitBinName } - - if (jars != "") { - jars <- paste("--jars", jars) - } - - combinedArgs <- paste(jars, sparkSubmitOpts, args, sep = " ") + combinedArgs <- generateSparkSubmitArgs(args, sparkHome, jars, sparkSubmitOpts, packages) cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n") invisible(system2(sparkSubmitBin, combinedArgs, wait = F)) } diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 80e92d3105a36..8e4b0f5bf1c4d 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -210,6 +210,22 @@ setMethod("cast", } }) +#' Match a column with given values. +#' +#' @rdname column +#' @return a matched values as a result of comparing with given values. +#' \dontrun{ +#' filter(df, "age in (10, 30)") +#' where(df, df$age %in% c(10, 30)) +#' } +setMethod("%in%", + signature(x = "Column"), + function(x, table) { + table <- listToSeq(as.list(table)) + jc <- callJMethod(x@jc, "in", table) + return(column(jc)) + }) + #' Approx Count Distinct #' #' @rdname column diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index 257b435607ce8..d961bbc383688 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -18,7 +18,7 @@ # Utility functions to deserialize objects from Java. # Type mapping from Java to R -# +# # void -> NULL # Int -> integer # String -> character diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 12e09176c9f92..79055b7f18558 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -130,7 +130,7 @@ setGeneric("maximum", function(x) { standardGeneric("maximum") }) # @export setGeneric("minimum", function(x) { standardGeneric("minimum") }) -# @rdname sumRDD +# @rdname sumRDD # @export setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") }) @@ -219,7 +219,7 @@ setGeneric("zipRDD", function(x, other) { standardGeneric("zipRDD") }) # @rdname zipRDD # @export -setGeneric("zipPartitions", function(..., func) { standardGeneric("zipPartitions") }, +setGeneric("zipPartitions", function(..., func) { standardGeneric("zipPartitions") }, signature = "...") # @rdname zipWithIndex @@ -364,7 +364,7 @@ setGeneric("subtract", # @rdname subtractByKey # @export -setGeneric("subtractByKey", +setGeneric("subtractByKey", function(x, other, numPartitions = 1) { standardGeneric("subtractByKey") }) @@ -399,15 +399,15 @@ setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) #' @rdname nafunctions #' @export setGeneric("dropna", - function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { - standardGeneric("dropna") + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + standardGeneric("dropna") }) #' @rdname nafunctions #' @export setGeneric("na.omit", - function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { - standardGeneric("na.omit") + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + standardGeneric("na.omit") }) #' @rdname schema @@ -656,4 +656,3 @@ setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) #' @rdname column #' @export setGeneric("upper", function(x) { standardGeneric("upper") }) - diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index b758481997574..8f1c68f7c4d28 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -136,4 +136,3 @@ createMethods <- function() { } createMethods() - diff --git a/R/pkg/R/jobj.R b/R/pkg/R/jobj.R index a8a25230b636d..0838a7bb35e0d 100644 --- a/R/pkg/R/jobj.R +++ b/R/pkg/R/jobj.R @@ -16,7 +16,7 @@ # # References to objects that exist on the JVM backend -# are maintained using the jobj. +# are maintained using the jobj. #' @include generics.R NULL diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 1e24286dbcae2..7f902ba8e683e 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -784,7 +784,7 @@ setMethod("sortByKey", newRDD <- partitionBy(x, numPartitions, rangePartitionFunc) lapplyPartition(newRDD, partitionFunc) }) - + # Subtract a pair RDD with another pair RDD. # # Return an RDD with the pairs from x whose keys are not in other. @@ -820,7 +820,7 @@ setMethod("subtractByKey", }) # Return a subset of this RDD sampled by key. -# +# # @description # \code{sampleByKey} Create a sample of this RDD using variable sampling rates # for different keys as specified by fractions, a key to sampling rate map. diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index e442119086b17..15e2bdbd55d79 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -20,7 +20,7 @@ #' structType #' -#' Create a structType object that contains the metadata for a DataFrame. Intended for +#' Create a structType object that contains the metadata for a DataFrame. Intended for #' use with createDataFrame and toDF. #' #' @param x a structField object (created with the field() function) diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 2081786e6f833..78535eff0d2f6 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -37,6 +37,14 @@ writeObject <- function(con, object, writeType = TRUE) { # passing in vectors as arrays and instead require arrays to be passed # as lists. type <- class(object)[[1]] # class of POSIXlt is c("POSIXlt", "POSIXt") + # Checking types is needed here, since ‘is.na’ only handles atomic vectors, + # lists and pairlists + if (type %in% c("integer", "character", "logical", "double", "numeric")) { + if (is.na(object)) { + object <- NULL + type <- "NULL" + } + } if (writeType) { writeType(con, type) } @@ -167,7 +175,7 @@ writeGenericList <- function(con, list) { writeObject(con, elem) } } - + # Used to pass in hash maps required on Java side. writeEnv <- function(con, env) { len <- length(env) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 5ced7c688f98a..633b869f91784 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -43,7 +43,7 @@ sparkR.stop <- function() { callJMethod(sc, "stop") rm(".sparkRjsc", envir = env) } - + if (exists(".backendLaunched", envir = env)) { callJStatic("SparkRHandler", "stopBackend") } @@ -81,6 +81,7 @@ sparkR.stop <- function() { #' @param sparkExecutorEnv Named list of environment variables to be used when launching executors. #' @param sparkJars Character string vector of jar files to pass to the worker nodes. #' @param sparkRLibDir The path where R is installed on the worker nodes. +#' @param sparkPackages Character string vector of packages from spark-packages.org #' @export #' @examples #'\dontrun{ @@ -100,7 +101,8 @@ sparkR.init <- function( sparkEnvir = list(), sparkExecutorEnv = list(), sparkJars = "", - sparkRLibDir = "") { + sparkRLibDir = "", + sparkPackages = "") { if (exists(".sparkRjsc", envir = .sparkREnv)) { cat("Re-using existing Spark Context. Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n") @@ -129,7 +131,8 @@ sparkR.init <- function( args = path, sparkHome = sparkHome, jars = jars, - sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell")) + sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"), + packages = sparkPackages) # wait atmost 100 seconds for JVM to launch wait <- 0.1 for (i in 1:25) { @@ -174,7 +177,7 @@ sparkR.init <- function( for (varname in names(sparkEnvir)) { sparkEnvirMap[[varname]] <- sparkEnvir[[varname]] } - + sparkExecutorEnvMap <- new.env() if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) { sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) @@ -214,7 +217,7 @@ sparkR.init <- function( #' Initialize a new SQLContext. #' -#' This function creates a SparkContext from an existing JavaSparkContext and +#' This function creates a SparkContext from an existing JavaSparkContext and #' then uses it to initialize a new SQLContext #' #' @param jsc The existing JavaSparkContext created with SparkR.init() @@ -278,3 +281,47 @@ sparkRHive.init <- function(jsc = NULL) { assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv) hiveCtx } + +#' Assigns a group ID to all the jobs started by this thread until the group ID is set to a +#' different value or cleared. +#' +#' @param sc existing spark context +#' @param groupid the ID to be assigned to job groups +#' @param description description for the the job group ID +#' @param interruptOnCancel flag to indicate if the job is interrupted on job cancellation +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' setJobGroup(sc, "myJobGroup", "My job group description", TRUE) +#'} + +setJobGroup <- function(sc, groupId, description, interruptOnCancel) { + callJMethod(sc, "setJobGroup", groupId, description, interruptOnCancel) +} + +#' Clear current job group ID and its description +#' +#' @param sc existing spark context +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' clearJobGroup(sc) +#'} + +clearJobGroup <- function(sc) { + callJMethod(sc, "clearJobGroup") +} + +#' Cancel active jobs for the specified group +#' +#' @param sc existing spark context +#' @param groupId the ID of job group to be cancelled +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' cancelJobGroup(sc, "myJobGroup") +#'} + +cancelJobGroup <- function(sc, groupId) { + callJMethod(sc, "cancelJobGroup", groupId) +} diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 69b2700191c9a..13cec0f712fb4 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -368,21 +368,21 @@ listToSeq <- function(l) { } # Utility function to recursively traverse the Abstract Syntax Tree (AST) of a -# user defined function (UDF), and to examine variables in the UDF to decide +# user defined function (UDF), and to examine variables in the UDF to decide # if their values should be included in the new function environment. # param # node The current AST node in the traversal. # oldEnv The original function environment. # defVars An Accumulator of variables names defined in the function's calling environment, # including function argument and local variable names. -# checkedFunc An environment of function objects examined during cleanClosure. It can +# checkedFunc An environment of function objects examined during cleanClosure. It can # be considered as a "name"-to-"list of functions" mapping. # newEnv A new function environment to store necessary function dependencies, an output argument. processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { nodeLen <- length(node) - + if (nodeLen > 1 && typeof(node) == "language") { - # Recursive case: current AST node is an internal node, check for its children. + # Recursive case: current AST node is an internal node, check for its children. if (length(node[[1]]) > 1) { for (i in 1:nodeLen) { processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) @@ -393,7 +393,7 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { for (i in 2:nodeLen) { processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } - } else if (nodeChar == "<-" || nodeChar == "=" || + } else if (nodeChar == "<-" || nodeChar == "=" || nodeChar == "<<-") { # Assignment Ops. defVar <- node[[2]] if (length(defVar) == 1 && typeof(defVar) == "symbol") { @@ -422,21 +422,21 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { } } } - } else if (nodeLen == 1 && + } else if (nodeLen == 1 && (typeof(node) == "symbol" || typeof(node) == "language")) { # Base case: current AST node is a leaf node and a symbol or a function call. nodeChar <- as.character(node) if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable. func.env <- oldEnv topEnv <- parent.env(.GlobalEnv) - # Search in function environment, and function's enclosing environments + # Search in function environment, and function's enclosing environments # up to global environment. There is no need to look into package environments - # above the global or namespace environment that is not SparkR below the global, + # above the global or namespace environment that is not SparkR below the global, # as they are assumed to be loaded on workers. while (!identical(func.env, topEnv)) { # Namespaces other than "SparkR" will not be searched. - if (!isNamespace(func.env) || - (getNamespaceName(func.env) == "SparkR" && + if (!isNamespace(func.env) || + (getNamespaceName(func.env) == "SparkR" && !(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals. # Set parameter 'inherits' to FALSE since we do not need to search in # attached package environments. @@ -444,7 +444,7 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { error = function(e) { FALSE })) { obj <- get(nodeChar, envir = func.env, inherits = FALSE) if (is.function(obj)) { # If the node is a function call. - funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F, + funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F, ifnotfound = list(list(NULL)))[[1]] found <- sapply(funcList, function(func) { ifelse(identical(func, obj), TRUE, FALSE) @@ -453,7 +453,7 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { break } # Function has not been examined, record it and recursively clean its closure. - assign(nodeChar, + assign(nodeChar, if (is.null(funcList[[1]])) { list(obj) } else { @@ -466,7 +466,7 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { break } } - + # Continue to search in enclosure. func.env <- parent.env(func.env) } @@ -474,8 +474,8 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { } } -# Utility function to get user defined function (UDF) dependencies (closure). -# More specifically, this function captures the values of free variables defined +# Utility function to get user defined function (UDF) dependencies (closure). +# More specifically, this function captures the values of free variables defined # outside a UDF, and stores them in the function's environment. # param # func A function whose closure needs to be captured. @@ -488,7 +488,7 @@ cleanClosure <- function(func, checkedFuncs = new.env()) { newEnv <- new.env(parent = .GlobalEnv) func.body <- body(func) oldEnv <- environment(func) - # defVars is an Accumulator of variables names defined in the function's calling + # defVars is an Accumulator of variables names defined in the function's calling # environment. First, function's arguments are added to defVars. defVars <- initAccumulator() argNames <- names(as.list(args(func))) @@ -509,15 +509,15 @@ cleanClosure <- function(func, checkedFuncs = new.env()) { # return value # A list of two result RDDs. appendPartitionLengths <- function(x, other) { - if (getSerializedMode(x) != getSerializedMode(other) || + if (getSerializedMode(x) != getSerializedMode(other) || getSerializedMode(x) == "byte") { # Append the number of elements in each partition to that partition so that we can later # know the boundary of elements from x and other. # - # Note that this appending also serves the purpose of reserialization, because even if + # Note that this appending also serves the purpose of reserialization, because even if # any RDD is serialized, we need to reserialize it to make sure its partitions are encoded # as a single byte array. For example, partitions of an RDD generated from partitionBy() - # may be encoded as multiple byte arrays. + # may be encoded as multiple byte arrays. appendLength <- function(part) { len <- length(part) part[[len + 1]] <- len + 1 @@ -544,23 +544,23 @@ mergePartitions <- function(rdd, zip) { lengthOfValues <- part[[len]] lengthOfKeys <- part[[len - lengthOfValues]] stopifnot(len == lengthOfKeys + lengthOfValues) - + # For zip operation, check if corresponding partitions of both RDDs have the same number of elements. if (zip && lengthOfKeys != lengthOfValues) { stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.") } - + if (lengthOfKeys > 1) { keys <- part[1 : (lengthOfKeys - 1)] } else { keys <- list() } if (lengthOfValues > 1) { - values <- part[(lengthOfKeys + 1) : (len - 1)] + values <- part[(lengthOfKeys + 1) : (len - 1)] } else { values <- list() } - + if (!zip) { return(mergeCompactLists(keys, values)) } @@ -578,6 +578,6 @@ mergePartitions <- function(rdd, zip) { part } } - + PipelinedRDD(rdd, partitionFunc) } diff --git a/R/pkg/R/zzz.R b/R/pkg/R/zzz.R index 80d796d467943..301feade65fa3 100644 --- a/R/pkg/R/zzz.R +++ b/R/pkg/R/zzz.R @@ -18,4 +18,3 @@ .onLoad <- function(libname, pkgname) { sparkR.onLoad(libname, pkgname) } - diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R index 773b6ecf582d9..7189f1a260934 100644 --- a/R/pkg/inst/profile/shell.R +++ b/R/pkg/inst/profile/shell.R @@ -27,7 +27,21 @@ sc <- SparkR::sparkR.init() assign("sc", sc, envir=.GlobalEnv) sqlContext <- SparkR::sparkRSQL.init(sc) + sparkVer <- SparkR:::callJMethod(sc, "version") assign("sqlContext", sqlContext, envir=.GlobalEnv) - cat("\n Welcome to SparkR!") + cat("\n Welcome to") + cat("\n") + cat(" ____ __", "\n") + cat(" / __/__ ___ _____/ /__", "\n") + cat(" _\\ \\/ _ \\/ _ `/ __/ '_/", "\n") + cat(" /___/ .__/\\_,_/_/ /_/\\_\\") + if (nchar(sparkVer) == 0) { + cat("\n") + } else { + cat(" version ", sparkVer, "\n") + } + cat(" /_/", "\n") + cat("\n") + cat("\n Spark context is available as sc, SQL context is available as sqlContext\n") } diff --git a/R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar b/R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar new file mode 100644 index 0000000000000..1d5c2af631aa3 Binary files /dev/null and b/R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar differ diff --git a/R/pkg/inst/tests/jarTest.R b/R/pkg/inst/tests/jarTest.R new file mode 100644 index 0000000000000..d68bb20950b00 --- /dev/null +++ b/R/pkg/inst/tests/jarTest.R @@ -0,0 +1,32 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +library(SparkR) + +sc <- sparkR.init() + +helloTest <- SparkR:::callJStatic("sparkR.test.hello", + "helloWorld", + "Dave") + +basicFunction <- SparkR:::callJStatic("sparkR.test.basicFunction", + "addStuff", + 2L, + 2L) + +sparkR.stop() +output <- c(helloTest, basicFunction) +writeLines(output) diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R index ca4218f3819f8..4db7266abc8e2 100644 --- a/R/pkg/inst/tests/test_binaryFile.R +++ b/R/pkg/inst/tests/test_binaryFile.R @@ -59,15 +59,15 @@ test_that("saveAsObjectFile()/objectFile() following RDD transformations works", wordCount <- lapply(words, function(word) { list(word, 1L) }) counts <- reduceByKey(wordCount, "+", 2L) - + saveAsObjectFile(counts, fileName2) counts <- objectFile(sc, fileName2) - + output <- collect(counts) expected <- list(list("awesome.", 1), list("Spark", 2), list("pretty.", 1), list("is", 2)) expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) - + unlink(fileName1) unlink(fileName2, recursive = TRUE) }) @@ -87,4 +87,3 @@ test_that("saveAsObjectFile()/objectFile() works with multiple paths", { unlink(fileName1, recursive = TRUE) unlink(fileName2, recursive = TRUE) }) - diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R index 6785a7bdae8cb..a1e354e567be5 100644 --- a/R/pkg/inst/tests/test_binary_function.R +++ b/R/pkg/inst/tests/test_binary_function.R @@ -30,7 +30,7 @@ mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("union on two RDDs", { actual <- collect(unionRDD(rdd, rdd)) expect_equal(actual, as.list(rep(nums, 2))) - + fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) @@ -52,14 +52,14 @@ test_that("union on two RDDs", { test_that("cogroup on two RDDs", { rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) - cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) + cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) actual <- collect(cogroup.rdd) - expect_equal(actual, + expect_equal(actual, list(list(1, list(list(1), list(2, 3))), list(2, list(list(4), list())))) - + rdd1 <- parallelize(sc, list(list("a", 1), list("a", 4))) rdd2 <- parallelize(sc, list(list("b", 2), list("a", 3))) - cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) + cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) actual <- collect(cogroup.rdd) expected <- list(list("b", list(list(), list(2))), list("a", list(list(1, 4), list(3)))) @@ -71,31 +71,31 @@ test_that("zipPartitions() on RDDs", { rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 - actual <- collect(zipPartitions(rdd1, rdd2, rdd3, + actual <- collect(zipPartitions(rdd1, rdd2, rdd3, func = function(x, y, z) { list(list(x, y, z))} )) expect_equal(actual, list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6)))) - + mockFile = c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) - + rdd <- textFile(sc, fileName, 1) - actual <- collect(zipPartitions(rdd, rdd, + actual <- collect(zipPartitions(rdd, rdd, func = function(x, y) { list(paste(x, y, sep = "\n")) })) expected <- list(paste(mockFile, mockFile, sep = "\n")) expect_equal(actual, expected) - + rdd1 <- parallelize(sc, 0:1, 1) - actual <- collect(zipPartitions(rdd1, rdd, + actual <- collect(zipPartitions(rdd1, rdd, func = function(x, y) { list(x + nchar(y)) })) expected <- list(0:1 + nchar(mockFile)) expect_equal(actual, expected) - + rdd <- map(rdd, function(x) { x }) - actual <- collect(zipPartitions(rdd, rdd1, + actual <- collect(zipPartitions(rdd, rdd1, func = function(x, y) { list(y + nchar(x)) })) expect_equal(actual, expected) - + unlink(fileName) }) diff --git a/R/pkg/inst/tests/test_client.R b/R/pkg/inst/tests/test_client.R new file mode 100644 index 0000000000000..30b05c1a2afcd --- /dev/null +++ b/R/pkg/inst/tests/test_client.R @@ -0,0 +1,32 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions in client.R") + +test_that("adding spark-testing-base as a package works", { + args <- generateSparkSubmitArgs("", "", "", "", + "holdenk:spark-testing-base:1.3.0_0.0.5") + expect_equal(gsub("[[:space:]]", "", args), + gsub("[[:space:]]", "", + "--packages holdenk:spark-testing-base:1.3.0_0.0.5")) +}) + +test_that("no package specified doesn't add packages flag", { + args <- generateSparkSubmitArgs("", "", "", "", "") + expect_equal(gsub("[[:space:]]", "", args), + "") +}) diff --git a/R/pkg/inst/tests/test_context.R b/R/pkg/inst/tests/test_context.R index e4aab37436a74..513bbc8e62059 100644 --- a/R/pkg/inst/tests/test_context.R +++ b/R/pkg/inst/tests/test_context.R @@ -48,3 +48,10 @@ test_that("rdd GC across sparkR.stop", { count(rdd3) count(rdd4) }) + +test_that("job group functions can be called", { + sc <- sparkR.init() + setJobGroup(sc, "groupId", "job description", TRUE) + cancelJobGroup(sc, "groupId") + clearJobGroup(sc) +}) diff --git a/R/pkg/inst/tests/test_includeJAR.R b/R/pkg/inst/tests/test_includeJAR.R new file mode 100644 index 0000000000000..8bc693be20c3c --- /dev/null +++ b/R/pkg/inst/tests/test_includeJAR.R @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +context("include an external JAR in SparkContext") + +runScript <- function() { + sparkHome <- Sys.getenv("SPARK_HOME") + jarPath <- paste("--jars", + shQuote(file.path(sparkHome, "R/lib/SparkR/test_support/sparktestjar_2.10-1.0.jar"))) + scriptPath <- file.path(sparkHome, "R/lib/SparkR/tests/jarTest.R") + submitPath <- file.path(sparkHome, "bin/spark-submit") + res <- system2(command = submitPath, + args = c(jarPath, scriptPath), + stdout = TRUE) + tail(res, 2) +} + +test_that("sparkJars tag in SparkContext", { + testOutput <- runScript() + helloTest <- testOutput[1] + expect_true(helloTest == "Hello, Dave") + basicFunction <- testOutput[2] + expect_true(basicFunction == 4L) +}) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index 03207353c31c6..4fe653856756e 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -477,7 +477,7 @@ test_that("cartesian() on RDDs", { list(1, 1), list(1, 2), list(1, 3), list(2, 1), list(2, 2), list(2, 3), list(3, 1), list(3, 2), list(3, 3))) - + # test case where one RDD is empty emptyRdd <- parallelize(sc, list()) actual <- collect(cartesian(rdd, emptyRdd)) @@ -486,7 +486,7 @@ test_that("cartesian() on RDDs", { mockFile = c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) - + rdd <- textFile(sc, fileName) actual <- collect(cartesian(rdd, rdd)) expected <- list( @@ -495,7 +495,7 @@ test_that("cartesian() on RDDs", { list("Spark is pretty.", "Spark is pretty."), list("Spark is pretty.", "Spark is awesome.")) expect_equal(sortKeyValueList(actual), expected) - + rdd1 <- parallelize(sc, 0:1) actual <- collect(cartesian(rdd1, rdd)) expect_equal(sortKeyValueList(actual), @@ -504,11 +504,11 @@ test_that("cartesian() on RDDs", { list(0, "Spark is awesome."), list(1, "Spark is pretty."), list(1, "Spark is awesome."))) - + rdd1 <- map(rdd, function(x) { x }) actual <- collect(cartesian(rdd, rdd1)) expect_equal(sortKeyValueList(actual), expected) - + unlink(fileName) }) @@ -760,7 +760,7 @@ test_that("collectAsMap() on a pairwise RDD", { }) test_that("show()", { - rdd <- parallelize(sc, list(1:10)) + rdd <- parallelize(sc, list(1:10)) expect_output(show(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+") }) diff --git a/R/pkg/inst/tests/test_shuffle.R b/R/pkg/inst/tests/test_shuffle.R index d7dedda553c56..adf0b91d25fe9 100644 --- a/R/pkg/inst/tests/test_shuffle.R +++ b/R/pkg/inst/tests/test_shuffle.R @@ -106,39 +106,39 @@ test_that("aggregateByKey", { zeroValue <- list(0, 0) seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } - aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) - + aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) + actual <- collect(aggregatedRDD) - + expected <- list(list(1, list(3, 2)), list(2, list(7, 2))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) # test aggregateByKey for string keys rdd <- parallelize(sc, list(list("a", 1), list("a", 2), list("b", 3), list("b", 4))) - + zeroValue <- list(0, 0) seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } - aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) + aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) actual <- collect(aggregatedRDD) - + expected <- list(list("a", list(3, 2)), list("b", list(7, 2))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) }) -test_that("foldByKey", { +test_that("foldByKey", { # test foldByKey for int keys folded <- foldByKey(intRdd, 0, "+", 2L) - + actual <- collect(folded) - + expected <- list(list(2L, 101), list(1L, 199)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) # test foldByKey for double keys folded <- foldByKey(doubleRdd, 0, "+", 2L) - + actual <- collect(folded) expected <- list(list(1.5, 199), list(2.5, 101)) @@ -146,15 +146,15 @@ test_that("foldByKey", { # test foldByKey for string keys stringKeyPairs <- list(list("a", -1), list("b", 100), list("b", 1), list("a", 200)) - + stringKeyRDD <- parallelize(sc, stringKeyPairs) folded <- foldByKey(stringKeyRDD, 0, "+", 2L) - + actual <- collect(folded) - + expected <- list(list("b", 101), list("a", 199)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) - + # test foldByKey for empty pair RDD rdd <- parallelize(sc, list()) folded <- foldByKey(rdd, 0, "+", 2L) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 30edfc8a7bd94..6a08f894313c4 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -19,6 +19,14 @@ library(testthat) context("SparkSQL functions") +# Utility function for easily checking the values of a StructField +checkStructField <- function(actual, expectedName, expectedType, expectedNullable) { + expect_equal(class(actual), "structField") + expect_equal(actual$name(), expectedName) + expect_equal(actual$dataType.toString(), expectedType) + expect_equal(actual$nullable(), expectedNullable) +} + # Tests for SparkSQL functions in SparkR sc <- sparkR.init() @@ -52,9 +60,10 @@ test_that("infer types", { list(type = 'array', elementType = "integer", containsNull = TRUE)) expect_equal(infer_type(list(1L, 2L)), list(type = 'array', elementType = "integer", containsNull = TRUE)) - expect_equal(infer_type(list(a = 1L, b = "2")), - structType(structField(x = "a", type = "integer", nullable = TRUE), - structField(x = "b", type = "string", nullable = TRUE))) + testStruct <- infer_type(list(a = 1L, b = "2")) + expect_true(class(testStruct) == "structType") + checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE) + checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE) e <- new.env() assign("a", 1L, envir = e) expect_equal(infer_type(e), @@ -67,7 +76,7 @@ test_that("structType and structField", { expect_true(inherits(testField, "structField")) expect_true(testField$name() == "a") expect_true(testField$nullable()) - + testSchema <- structType(testField, structField("b", "integer")) expect_true(inherits(testSchema, "structType")) expect_true(inherits(testSchema$fields()[[2]], "structField")) @@ -101,6 +110,43 @@ test_that("create DataFrame from RDD", { expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) }) +test_that("convert NAs to null type in DataFrames", { + rdd <- parallelize(sc, list(list(1L, 2L), list(NA, 4L))) + df <- createDataFrame(sqlContext, rdd, list("a", "b")) + expect_true(is.na(collect(df)[2, "a"])) + expect_equal(collect(df)[2, "b"], 4L) + + l <- data.frame(x = 1L, y = c(1L, NA_integer_, 3L)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(df)[2, "x"], 1L) + expect_true(is.na(collect(df)[2, "y"])) + + rdd <- parallelize(sc, list(list(1, 2), list(NA, 4))) + df <- createDataFrame(sqlContext, rdd, list("a", "b")) + expect_true(is.na(collect(df)[2, "a"])) + expect_equal(collect(df)[2, "b"], 4) + + l <- data.frame(x = 1, y = c(1, NA_real_, 3)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(df)[2, "x"], 1) + expect_true(is.na(collect(df)[2, "y"])) + + l <- list("a", "b", NA, "d") + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], "d") + + l <- list("a", "b", NA_character_, "d") + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], "d") + + l <- list(TRUE, FALSE, NA, TRUE) + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], TRUE) +}) + test_that("toDF", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- toDF(rdd, list("a", "b")) @@ -561,7 +607,7 @@ test_that("column functions", { c3 <- lower(c) + upper(c) + first(c) + last(c) c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string") c5 <- n(c) + n_distinct(c) - c5 <- acos(c) + asin(c) + atan(c) + cbrt(c) + c5 <- acos(c) + asin(c) + atan(c) + cbrt(c) c6 <- ceiling(c) + cos(c) + cosh(c) + exp(c) + expm1(c) c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c) c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c) @@ -656,6 +702,16 @@ test_that("filter() on a DataFrame", { filtered2 <- where(df, df$name != "Michael") expect_true(count(filtered2) == 2) expect_true(collect(filtered2)$age[2] == 19) + + # test suites for %in% + filtered3 <- filter(df, "age in (19)") + expect_equal(count(filtered3), 1) + filtered4 <- filter(df, "age in (19, 30)") + expect_equal(count(filtered4), 2) + filtered5 <- where(df, df$age %in% c(19)) + expect_equal(count(filtered5), 1) + filtered6 <- where(df, df$age %in% c(19, 30)) + expect_equal(count(filtered6), 2) }) test_that("join() on a DataFrame", { @@ -792,7 +848,7 @@ test_that("dropna() on a DataFrame", { rows <- collect(df) # drop with columns - + expected <- rows[!is.na(rows$name),] actual <- collect(dropna(df, cols = "name")) expect_true(identical(expected, actual)) @@ -805,7 +861,7 @@ test_that("dropna() on a DataFrame", { expect_true(identical(expected$age, actual$age)) expect_true(identical(expected$height, actual$height)) expect_true(identical(expected$name, actual$name)) - + expected <- rows[!is.na(rows$age) & !is.na(rows$height),] actual <- collect(dropna(df, cols = c("age", "height"))) expect_true(identical(expected, actual)) @@ -813,7 +869,7 @@ test_that("dropna() on a DataFrame", { expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df)) expect_true(identical(expected, actual)) - + # drop with how expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] @@ -823,7 +879,7 @@ test_that("dropna() on a DataFrame", { expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),] actual <- collect(dropna(df, "all")) expect_true(identical(expected, actual)) - + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df, "any")) expect_true(identical(expected, actual)) @@ -835,14 +891,14 @@ test_that("dropna() on a DataFrame", { expected <- rows[!is.na(rows$age) | !is.na(rows$height),] actual <- collect(dropna(df, "all", cols = c("age", "height"))) expect_true(identical(expected, actual)) - + # drop with threshold - + expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,] actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_true(identical(expected, actual)) - expected <- rows[as.integer(!is.na(rows$age)) + + expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) + as.integer(!is.na(rows$name)) >= 3,] actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height"))) @@ -852,9 +908,9 @@ test_that("dropna() on a DataFrame", { test_that("fillna() on a DataFrame", { df <- jsonFile(sqlContext, jsonPathNa) rows <- collect(df) - + # fill with value - + expected <- rows expected$age[is.na(expected$age)] <- 50 expected$height[is.na(expected$height)] <- 50.6 @@ -875,7 +931,7 @@ test_that("fillna() on a DataFrame", { expected$name[is.na(expected$name)] <- "unknown" actual <- collect(fillna(df, "unknown", c("age", "name"))) expect_true(identical(expected, actual)) - + # fill with named list expected <- rows @@ -883,7 +939,7 @@ test_that("fillna() on a DataFrame", { expected$height[is.na(expected$height)] <- 50.6 expected$name[is.na(expected$name)] <- "unknown" actual <- collect(fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown"))) - expect_true(identical(expected, actual)) + expect_true(identical(expected, actual)) }) unlink(parquetPath) diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/test_take.R index 7f4c7c315d787..c5eb417b40159 100644 --- a/R/pkg/inst/tests/test_take.R +++ b/R/pkg/inst/tests/test_take.R @@ -64,4 +64,3 @@ test_that("take() gives back the original elements in correct count and order", expect_true(length(take(numListRDD, 0)) == 0) expect_true(length(take(numVectorRDD, 0)) == 0) }) - diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R index 6b87b4b3e0b08..092ad9dc10c2e 100644 --- a/R/pkg/inst/tests/test_textFile.R +++ b/R/pkg/inst/tests/test_textFile.R @@ -58,7 +58,7 @@ test_that("textFile() word count works as expected", { expected <- list(list("pretty.", 1), list("is", 2), list("awesome.", 1), list("Spark", 2)) expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) - + unlink(fileName) }) @@ -115,13 +115,13 @@ test_that("textFile() and saveAsTextFile() word count works as expected", { saveAsTextFile(counts, fileName2) rdd <- textFile(sc, fileName2) - + output <- collect(rdd) expected <- list(list("awesome.", 1), list("Spark", 2), list("pretty.", 1), list("is", 2)) expectedStr <- lapply(expected, function(x) { toString(x) }) expect_equal(sortKeyValueList(output), sortKeyValueList(expectedStr)) - + unlink(fileName1) unlink(fileName2) }) @@ -159,4 +159,3 @@ test_that("Pipelined operations on RDDs created using textFile", { unlink(fileName) }) - diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R index 539e3a3c19df3..15030e6f1d77e 100644 --- a/R/pkg/inst/tests/test_utils.R +++ b/R/pkg/inst/tests/test_utils.R @@ -43,13 +43,13 @@ test_that("serializeToBytes on RDD", { mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) - + text.rdd <- textFile(sc, fileName) expect_true(getSerializedMode(text.rdd) == "string") ser.rdd <- serializeToBytes(text.rdd) expect_equal(collect(ser.rdd), as.list(mockFile)) expect_true(getSerializedMode(ser.rdd) == "byte") - + unlink(fileName) }) @@ -64,7 +64,7 @@ test_that("cleanClosure on R functions", { expect_equal(actual, y) actual <- get("g", envir = env, inherits = FALSE) expect_equal(actual, g) - + # Test for nested enclosures and package variables. env2 <- new.env() funcEnv <- new.env(parent = env2) @@ -106,7 +106,7 @@ test_that("cleanClosure on R functions", { expect_equal(length(ls(env)), 1) actual <- get("y", envir = env, inherits = FALSE) expect_equal(actual, y) - + # Test for function (and variable) definitions. f <- function(x) { g <- function(y) { y * 2 } @@ -115,7 +115,7 @@ test_that("cleanClosure on R functions", { newF <- cleanClosure(f) env <- environment(newF) expect_equal(length(ls(env)), 0) # "y" and "g" should not be included. - + # Test for overriding variables in base namespace (Issue: SparkR-196). nums <- as.list(1:10) rdd <- parallelize(sc, nums, 2L) @@ -128,7 +128,7 @@ test_that("cleanClosure on R functions", { actual <- collect(lapply(rdd, f)) expected <- as.list(c(rep(FALSE, 4), rep(TRUE, 6))) expect_equal(actual, expected) - + # Test for broadcast variables. a <- matrix(nrow=10, ncol=10, data=rnorm(100)) aBroadcast <- broadcast(sc, a) diff --git a/build/mvn b/build/mvn index 3561110a4c019..e8364181e8230 100755 --- a/build/mvn +++ b/build/mvn @@ -69,11 +69,14 @@ install_app() { # Install maven under the build/ folder install_mvn() { + local MVN_VERSION="3.3.3" + install_app \ - "http://archive.apache.org/dist/maven/maven-3/3.2.5/binaries" \ - "apache-maven-3.2.5-bin.tar.gz" \ - "apache-maven-3.2.5/bin/mvn" - MVN_BIN="${_DIR}/apache-maven-3.2.5/bin/mvn" + "http://archive.apache.org/dist/maven/maven-3/${MVN_VERSION}/binaries" \ + "apache-maven-${MVN_VERSION}-bin.tar.gz" \ + "apache-maven-${MVN_VERSION}/bin/mvn" + + MVN_BIN="${_DIR}/apache-maven-${MVN_VERSION}/bin/mvn" } # Install zinc under the build/ folder @@ -105,28 +108,16 @@ install_scala() { SCALA_LIBRARY="$(cd "$(dirname ${scala_bin})/../lib" && pwd)/scala-library.jar" } -# Determines if a given application is already installed. If not, will attempt -# to install -## Arg1 - application name -## Arg2 - Alternate path to local install under build/ dir -check_and_install_app() { - # create the local environment variable in uppercase - local app_bin="`echo $1 | awk '{print toupper(\$0)}'`_BIN" - # some black magic to set the generated app variable (i.e. MVN_BIN) into the - # environment - eval "${app_bin}=`which $1 2>/dev/null`" - - if [ -z "`which $1 2>/dev/null`" ]; then - install_$1 - fi -} - # Setup healthy defaults for the Zinc port if none were provided from # the environment ZINC_PORT=${ZINC_PORT:-"3030"} -# Check and install all applications necessary to build Spark -check_and_install_app "mvn" +# Install Maven if necessary +MVN_BIN="$(command -v mvn)" + +if [ ! "$MVN_BIN" ]; then + install_mvn +fi # Install the proper version of Scala and Zinc for the build install_zinc diff --git a/core/pom.xml b/core/pom.xml index 40a64beccdc24..aee0d92620606 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -69,16 +69,6 @@ org.apache.hadoop hadoop-client - - - javax.servlet - servlet-api - - - org.codehaus.jackson - jackson-mapper-asl - - org.apache.spark @@ -354,7 +344,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index ad7eb04afcd8c..764578b181422 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -139,6 +139,9 @@ public void write(Iterator> records) throws IOException { @Override public void write(scala.collection.Iterator> records) throws IOException { + // Keep track of success so we know if we ecountered an exception + // We do this rather than a standard try/catch/re-throw to handle + // generic throwables. boolean success = false; try { while (records.hasNext()) { @@ -147,8 +150,19 @@ public void write(scala.collection.Iterator> records) throws IOEx closeAndWriteOutput(); success = true; } finally { - if (!success) { - sorter.cleanupAfterError(); + if (sorter != null) { + try { + sorter.cleanupAfterError(); + } catch (Exception e) { + // Only throw this error if we won't be masking another + // error. + if (success) { + throw e; + } else { + logger.error("In addition to a failure during writing, we failed during " + + "cleanup.", e); + } + } } } } diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties new file mode 100644 index 0000000000000..b146f8a784127 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties @@ -0,0 +1,12 @@ +# Set everything to be logged to the console +log4j.rootCategory=WARN, console +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + +# Settings to quiet third party logs that are too verbose +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO +log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index e96af8768daa0..9fa53baaf4212 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -140,7 +140,8 @@ function renderDagViz(forJob) { // Find cached RDDs and mark them as such metadataContainer().selectAll(".cached-rdd").each(function(v) { - var nodeId = VizConstants.nodePrefix + d3.select(this).text(); + var rddId = d3.select(this).text().trim(); + var nodeId = VizConstants.nodePrefix + rddId; svg.selectAll("g." + nodeId).classed("cached", true); }); @@ -150,7 +151,7 @@ function renderDagViz(forJob) { /* Render the RDD DAG visualization on the stage page. */ function renderDagVizForStage(svgContainer) { var metadata = metadataContainer().select(".stage-metadata"); - var dot = metadata.select(".dot-file").text(); + var dot = metadata.select(".dot-file").text().trim(); var containerId = VizConstants.graphPrefix + metadata.attr("stage-id"); var container = svgContainer.append("g").attr("id", containerId); renderDot(dot, container, false); @@ -235,7 +236,7 @@ function renderDagVizForJob(svgContainer) { // them separately later. Note that we cannot draw them now because we need to // put these edges in a separate container that is on top of all stage graphs. metadata.selectAll(".incoming-edge").each(function(v) { - var edge = d3.select(this).text().split(","); // e.g. 3,4 => [3, 4] + var edge = d3.select(this).text().trim().split(","); // e.g. 3,4 => [3, 4] crossStageEdges.push(edge); }); }); diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 419d093d55643..7fcb7830e7b0b 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -121,13 +121,25 @@ trait Logging { if (usingLog4j12) { val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements if (!log4j12Initialized) { - val defaultLogProps = "org/apache/spark/log4j-defaults.properties" - Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { - case Some(url) => - PropertyConfigurator.configure(url) - System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") - case None => - System.err.println(s"Spark was unable to load $defaultLogProps") + if (Utils.isInInterpreter) { + val replDefaultLogProps = "org/apache/spark/log4j-defaults-repl.properties" + Option(Utils.getSparkClassLoader.getResource(replDefaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's repl log4j profile: $replDefaultLogProps") + System.err.println("To adjust logging level use sc.setLogLevel(\"INFO\")") + case None => + System.err.println(s"Spark was unable to load $replDefaultLogProps") + } + } else { + val defaultLogProps = "org/apache/spark/log4j-defaults.properties" + Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") + case None => + System.err.println(s"Spark was unable to load $defaultLogProps") + } } } } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 018422827e1c8..862ffe868f58f 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -21,7 +21,7 @@ import java.io._ import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.{HashSet, Map} +import scala.collection.mutable.{HashMap, HashSet, Map} import scala.collection.JavaConversions._ import scala.reflect.ClassTag @@ -284,6 +284,53 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId) } + /** + * Return a list of locations that each have fraction of map output greater than the specified + * threshold. + * + * @param shuffleId id of the shuffle + * @param reducerId id of the reduce task + * @param numReducers total number of reducers in the shuffle + * @param fractionThreshold fraction of total map output size that a location must have + * for it to be considered large. + * + * This method is not thread-safe. + */ + def getLocationsWithLargestOutputs( + shuffleId: Int, + reducerId: Int, + numReducers: Int, + fractionThreshold: Double) + : Option[Array[BlockManagerId]] = { + + if (mapStatuses.contains(shuffleId)) { + val statuses = mapStatuses(shuffleId) + if (statuses.nonEmpty) { + // HashMap to add up sizes of all blocks at the same location + val locs = new HashMap[BlockManagerId, Long] + var totalOutputSize = 0L + var mapIdx = 0 + while (mapIdx < statuses.length) { + val status = statuses(mapIdx) + val blockSize = status.getSizeForBlock(reducerId) + if (blockSize > 0) { + locs(status.location) = locs.getOrElse(status.location, 0L) + blockSize + totalOutputSize += blockSize + } + mapIdx = mapIdx + 1 + } + val topLocs = locs.filter { case (loc, size) => + size.toDouble / totalOutputSize >= fractionThreshold + } + // Return if we have any locations which satisfy the required threshold + if (topLocs.nonEmpty) { + return Some(topLocs.map(_._1).toArray) + } + } + } + None + } + def incrementEpoch() { epochLock.synchronized { epoch += 1 diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 8aed1e20e0686..673ef49e7c1c5 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -192,7 +192,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) // key used to store the spark secret in the Hadoop UGI private val sparkSecretLookupKey = "sparkCookie" - private val authOn = sparkConf.getBoolean("spark.authenticate", false) + private val authOn = sparkConf.getBoolean(SecurityManager.SPARK_AUTH_CONF, false) // keep spark.ui.acls.enable for backwards compatibility with 1.0 private var aclsOn = sparkConf.getBoolean("spark.acls.enable", sparkConf.getBoolean("spark.ui.acls.enable", false)) @@ -365,10 +365,12 @@ private[spark] class SecurityManager(sparkConf: SparkConf) cookie } else { // user must have set spark.authenticate.secret config - sparkConf.getOption("spark.authenticate.secret") match { + // For Master/Worker, auth secret is in conf; for Executors, it is in env variable + sys.env.get(SecurityManager.ENV_AUTH_SECRET) + .orElse(sparkConf.getOption(SecurityManager.SPARK_AUTH_SECRET_CONF)) match { case Some(value) => value case None => throw new Exception("Error: a secret key must be specified via the " + - "spark.authenticate.secret config") + SecurityManager.SPARK_AUTH_SECRET_CONF + " config") } } sCookie @@ -449,3 +451,12 @@ private[spark] class SecurityManager(sparkConf: SparkConf) override def getSaslUser(appId: String): String = getSaslUser() override def getSecretKey(appId: String): String = getSecretKey() } + +private[spark] object SecurityManager { + + val SPARK_AUTH_CONF: String = "spark.authenticate" + val SPARK_AUTH_SECRET_CONF: String = "spark.authenticate.secret" + // This is used to set auth secret to an executor's env variable. It should have the same + // value as SPARK_AUTH_SECERET_CONF set in SparkConf + val ENV_AUTH_SECRET = "_SPARK_AUTH_SECRET" +} diff --git a/core/src/main/scala/org/apache/spark/SerializableWritable.scala b/core/src/main/scala/org/apache/spark/SerializableWritable.scala index cb2cae185256a..beb2e27254725 100644 --- a/core/src/main/scala/org/apache/spark/SerializableWritable.scala +++ b/core/src/main/scala/org/apache/spark/SerializableWritable.scala @@ -41,7 +41,7 @@ class SerializableWritable[T <: Writable](@transient var t: T) extends Serializa private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { in.defaultReadObject() val ow = new ObjectWritable() - ow.setConf(new Configuration()) + ow.setConf(new Configuration(false)) ow.readFields(in) t = ow.get().asInstanceOf[T] } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 46d72841dccce..6cf36fbbd6254 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -557,7 +557,7 @@ private[spark] object SparkConf extends Logging { def isExecutorStartupConf(name: String): Boolean = { isAkkaConf(name) || name.startsWith("spark.akka") || - name.startsWith("spark.auth") || + (name.startsWith("spark.auth") && name != SecurityManager.SPARK_AUTH_SECRET_CONF) || name.startsWith("spark.ssl") || isSparkPortConf(name) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a453c9bf4864a..b3c3bf3746e18 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -315,6 +315,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _dagScheduler = ds } + /** + * A unique identifier for the Spark application. + * Its format depends on the scheduler implementation. + * (i.e. + * in case of local spark app something like 'local-1433865536131' + * in case of YARN something like 'application_1433865536131_34483' + * ) + */ def applicationId: String = _applicationId def applicationAttemptId: Option[String] = _applicationAttemptId @@ -545,7 +553,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Post init _taskScheduler.postStartHook() - _env.metricsSystem.registerSource(new DAGSchedulerSource(dagScheduler)) _env.metricsSystem.registerSource(new BlockManagerSource(_env.blockManager)) _executorAllocationManager.foreach { e => _env.metricsSystem.registerSource(e.executorAllocationManagerSource) @@ -974,7 +981,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli minPartitions: Int = defaultMinPartitions): RDD[(K, V)] = withScope { assertNotStopped() // A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it. - val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration)) + val confBroadcast = broadcast(new SerializableConfiguration(hadoopConfiguration)) val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path) new HadoopRDD( this, diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index a185954089528..b0665570e2681 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -20,6 +20,8 @@ package org.apache.spark import java.io.File import java.net.Socket +import akka.actor.ActorSystem + import scala.collection.JavaConversions._ import scala.collection.mutable import scala.util.Properties @@ -75,7 +77,8 @@ class SparkEnv ( val conf: SparkConf) extends Logging { // TODO Remove actorSystem - val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + @deprecated("Actor system is no longer supported as of 1.4") + val actorSystem: ActorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem private[spark] var isStopped = false private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 59ac82ccec53b..f5dd36cbcfe6d 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD +import org.apache.spark.util.SerializableJobConf /** * Internal helper class that saves an RDD using a Hadoop OutputFormat. @@ -42,7 +43,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) with Serializable { private val now = new Date() - private val conf = new SerializableWritable(jobConf) + private val conf = new SerializableJobConf(jobConf) private var jobID = 0 private var splitID = 0 diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index c9181a29d4756..b959b683d1674 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -19,8 +19,8 @@ package org.apache.spark.api.python import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SerializableWritable, SparkException} +import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.{Logging, SparkException} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io._ import scala.util.{Failure, Success, Try} @@ -61,7 +61,7 @@ private[python] object Converter extends Logging { * Other objects are passed through without conversion. */ private[python] class WritableToJavaConverter( - conf: Broadcast[SerializableWritable[Configuration]]) extends Converter[Any, Any] { + conf: Broadcast[SerializableConfiguration]) extends Converter[Any, Any] { /** * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 55a37f8c944b2..dc9f62f39e6d5 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -36,7 +36,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.input.PortableDataStream import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} import scala.util.control.NonFatal @@ -445,7 +445,7 @@ private[spark] object PythonRDD extends Logging { val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration())) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration())) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -471,7 +471,7 @@ private[spark] object PythonRDD extends Logging { val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, Some(path), inputFormatClass, keyClass, valueClass, mergedConf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -497,7 +497,7 @@ private[spark] object PythonRDD extends Logging { val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, None, inputFormatClass, keyClass, valueClass, conf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(conf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -540,7 +540,7 @@ private[spark] object PythonRDD extends Logging { val rdd = hadoopRDDFromClassNames[K, V, F](sc, Some(path), inputFormatClass, keyClass, valueClass, mergedConf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -566,7 +566,7 @@ private[spark] object PythonRDD extends Logging { val rdd = hadoopRDDFromClassNames[K, V, F](sc, None, inputFormatClass, keyClass, valueClass, conf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(conf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index d24c650d37bb0..1a5f2bca26c2b 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -29,7 +29,7 @@ import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.handler.codec.LengthFieldBasedFrameDecoder import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder} -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkConf} /** * Netty-based backend server that is used to communicate between R and Java. @@ -41,7 +41,8 @@ private[spark] class RBackend { private[this] var bossGroup: EventLoopGroup = null def init(): Int = { - bossGroup = new NioEventLoopGroup(2) + val conf = new SparkConf() + bossGroup = new NioEventLoopGroup(conf.getInt("spark.r.numRBackendThreads", 2)) val workerGroup = bossGroup val handler = new RBackendHandler(this) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 2e86984c66b3a..4b8f7fe9242e0 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -88,6 +88,21 @@ private[r] class RBackendHandler(server: RBackend) ctx.close() } + // Looks up a class given a class name. This function first checks the + // current class loader and if a class is not found, it looks up the class + // in the context class loader. Address [SPARK-5185] + def getStaticClass(objId: String): Class[_] = { + try { + val clsCurrent = Class.forName(objId) + clsCurrent + } catch { + // Use contextLoader if we can't find the JAR in the system class loader + case e: ClassNotFoundException => + val clsContext = Class.forName(objId, true, Thread.currentThread().getContextClassLoader) + clsContext + } + } + def handleMethodCall( isStatic: Boolean, objId: String, @@ -98,7 +113,7 @@ private[r] class RBackendHandler(server: RBackend) var obj: Object = null try { val cls = if (isStatic) { - Class.forName(objId) + getStaticClass(objId) } else { JVMObjectTracker.get(objId) match { case None => throw new IllegalArgumentException("Object not found " + objId) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 4dfa7325934ff..524676544d6f5 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -391,7 +391,7 @@ private[r] object RRDD { } private def createRProcess(rLibDir: String, port: Int, script: String): BufferedStreamThread = { - val rCommand = "Rscript" + val rCommand = SparkEnv.get.conf.get("spark.sparkr.r.command", "Rscript") val rOptions = "--vanilla" val rExecScript = rLibDir + "/SparkR/worker/" + script val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript)) diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index f8e3f1a79082e..56adc857d4ce0 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -18,7 +18,7 @@ package org.apache.spark.api.r import java.io.{DataInputStream, DataOutputStream} -import java.sql.{Date, Time} +import java.sql.{Timestamp, Date, Time} import scala.collection.JavaConversions._ @@ -107,9 +107,12 @@ private[spark] object SerDe { Date.valueOf(readString(in)) } - def readTime(in: DataInputStream): Time = { - val t = in.readDouble() - new Time((t * 1000L).toLong) + def readTime(in: DataInputStream): Timestamp = { + val seconds = in.readDouble() + val sec = Math.floor(seconds).toLong + val t = new Timestamp(sec * 1000L) + t.setNanos(((seconds - sec) * 1e9).toInt) + t } def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { @@ -227,6 +230,9 @@ private[spark] object SerDe { case "java.sql.Time" => writeType(dos, "time") writeTime(dos, value.asInstanceOf[Time]) + case "java.sql.Timestamp" => + writeType(dos, "time") + writeTime(dos, value.asInstanceOf[Timestamp]) case "[B" => writeType(dos, "raw") writeBytes(dos, value.asInstanceOf[Array[Byte]]) @@ -289,6 +295,9 @@ private[spark] object SerDe { out.writeDouble(value.getTime.toDouble / 1000.0) } + def writeTime(out: DataOutputStream, value: Timestamp): Unit = { + out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9) + } // NOTE: Only works for ASCII right now def writeString(out: DataOutputStream, value: String): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index a0eae774268ed..b1d6ec209d62b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -35,7 +35,8 @@ import org.apache.ivy.core.resolve.ResolveOptions import org.apache.ivy.core.retrieve.RetrieveOptions import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.matcher.GlobPatternMatcher -import org.apache.ivy.plugins.resolver.{ChainResolver, IBiblioResolver} +import org.apache.ivy.plugins.repository.file.FileRepository +import org.apache.ivy.plugins.resolver.{FileSystemResolver, ChainResolver, IBiblioResolver} import org.apache.spark.SPARK_VERSION import org.apache.spark.deploy.rest._ import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} @@ -324,55 +325,20 @@ object SparkSubmit { // Usage: PythonAppRunner
[app arguments] args.mainClass = "org.apache.spark.deploy.PythonRunner" args.childArgs = ArrayBuffer(args.primaryResource, args.pyFiles) ++ args.childArgs - args.files = mergeFileLists(args.files, args.primaryResource) + if (clusterManager != YARN) { + // The YARN backend distributes the primary file differently, so don't merge it. + args.files = mergeFileLists(args.files, args.primaryResource) + } + } + if (clusterManager != YARN) { + // The YARN backend handles python files differently, so don't merge the lists. + args.files = mergeFileLists(args.files, args.pyFiles) } - args.files = mergeFileLists(args.files, args.pyFiles) if (args.pyFiles != null) { sysProps("spark.submit.pyFiles") = args.pyFiles } } - // In yarn mode for a python app, add pyspark archives to files - // that can be distributed with the job - if (args.isPython && clusterManager == YARN) { - var pyArchives: String = null - val pyArchivesEnvOpt = sys.env.get("PYSPARK_ARCHIVES_PATH") - if (pyArchivesEnvOpt.isDefined) { - pyArchives = pyArchivesEnvOpt.get - } else { - if (!sys.env.contains("SPARK_HOME")) { - printErrorAndExit("SPARK_HOME does not exist for python application in yarn mode.") - } - val pythonPath = new ArrayBuffer[String] - for (sparkHome <- sys.env.get("SPARK_HOME")) { - val pyLibPath = Seq(sparkHome, "python", "lib").mkString(File.separator) - val pyArchivesFile = new File(pyLibPath, "pyspark.zip") - if (!pyArchivesFile.exists()) { - printErrorAndExit("pyspark.zip does not exist for python application in yarn mode.") - } - val py4jFile = new File(pyLibPath, "py4j-0.8.2.1-src.zip") - if (!py4jFile.exists()) { - printErrorAndExit("py4j-0.8.2.1-src.zip does not exist for python application " + - "in yarn mode.") - } - pythonPath += pyArchivesFile.getAbsolutePath() - pythonPath += py4jFile.getAbsolutePath() - } - pyArchives = pythonPath.mkString(",") - } - - pyArchives = pyArchives.split(",").map { localPath => - val localURI = Utils.resolveURI(localPath) - if (localURI.getScheme != "local") { - args.files = mergeFileLists(args.files, localURI.toString) - new Path(localPath).getName - } else { - localURI.getPath - } - }.mkString(File.pathSeparator) - sysProps("spark.submit.pyArchives") = pyArchives - } - // If we're running a R app, set the main class to our specific R runner if (args.isR && deployMode == CLIENT) { if (args.primaryResource == SPARKR_SHELL) { @@ -386,19 +352,10 @@ object SparkSubmit { } } - if (isYarnCluster) { - // In yarn-cluster mode for a python app, add primary resource and pyFiles to files - // that can be distributed with the job - if (args.isPython) { - args.files = mergeFileLists(args.files, args.primaryResource) - args.files = mergeFileLists(args.files, args.pyFiles) - } - + if (isYarnCluster && args.isR) { // In yarn-cluster mode for a R app, add primary resource to files // that can be distributed with the job - if (args.isR) { - args.files = mergeFileLists(args.files, args.primaryResource) - } + args.files = mergeFileLists(args.files, args.primaryResource) } // Special flag to avoid deprecation warnings at the client @@ -515,17 +472,18 @@ object SparkSubmit { } } + // Let YARN know it's a pyspark app, so it distributes needed libraries. + if (clusterManager == YARN && args.isPython) { + sysProps.put("spark.yarn.isPython", "true") + } + // In yarn-cluster mode, use yarn.Client as a wrapper around the user class if (isYarnCluster) { childMainClass = "org.apache.spark.deploy.yarn.Client" if (args.isPython) { - val mainPyFile = new Path(args.primaryResource).getName - childArgs += ("--primary-py-file", mainPyFile) + childArgs += ("--primary-py-file", args.primaryResource) if (args.pyFiles != null) { - // These files will be distributed to each machine's working directory, so strip the - // path prefix - val pyFilesNames = args.pyFiles.split(",").map(p => (new Path(p)).getName).mkString(",") - childArgs += ("--py-files", pyFilesNames) + childArgs += ("--py-files", args.pyFiles) } childArgs += ("--class", "org.apache.spark.deploy.PythonRunner") } else if (args.isR) { @@ -778,8 +736,14 @@ private[spark] object SparkSubmitUtils { } /** Path of the local Maven cache. */ - private[spark] def m2Path: File = new File(System.getProperty("user.home"), - ".m2" + File.separator + "repository" + File.separator) + private[spark] def m2Path: File = { + if (Utils.isTesting) { + // test builds delete the maven cache, and this can cause flakiness + new File("dummy", ".m2" + File.separator + "repository") + } else { + new File(System.getProperty("user.home"), ".m2" + File.separator + "repository") + } + } /** * Extracts maven coordinates from a comma-delimited string @@ -792,6 +756,20 @@ private[spark] object SparkSubmitUtils { val cr = new ChainResolver cr.setName("list") + val repositoryList = remoteRepos.getOrElse("") + // add any other remote repositories other than maven central + if (repositoryList.trim.nonEmpty) { + repositoryList.split(",").zipWithIndex.foreach { case (repo, i) => + val brr: IBiblioResolver = new IBiblioResolver + brr.setM2compatible(true) + brr.setUsepoms(true) + brr.setRoot(repo) + brr.setName(s"repo-${i + 1}") + cr.add(brr) + printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") + } + } + val localM2 = new IBiblioResolver localM2.setM2compatible(true) localM2.setRoot(m2Path.toURI.toString) @@ -799,12 +777,13 @@ private[spark] object SparkSubmitUtils { localM2.setName("local-m2-cache") cr.add(localM2) - val localIvy = new IBiblioResolver - localIvy.setRoot(new File(ivySettings.getDefaultIvyUserDir, - "local" + File.separator).toURI.toString) + val localIvy = new FileSystemResolver + val localIvyRoot = new File(ivySettings.getDefaultIvyUserDir, "local") + localIvy.setLocal(true) + localIvy.setRepository(new FileRepository(localIvyRoot)) val ivyPattern = Seq("[organisation]", "[module]", "[revision]", "[type]s", "[artifact](-[classifier]).[ext]").mkString(File.separator) - localIvy.setPattern(ivyPattern) + localIvy.addIvyPattern(localIvyRoot.getAbsolutePath + File.separator + ivyPattern) localIvy.setName("local-ivy-cache") cr.add(localIvy) @@ -821,20 +800,6 @@ private[spark] object SparkSubmitUtils { sp.setRoot("http://dl.bintray.com/spark-packages/maven") sp.setName("spark-packages") cr.add(sp) - - val repositoryList = remoteRepos.getOrElse("") - // add any other remote repositories other than maven central - if (repositoryList.trim.nonEmpty) { - repositoryList.split(",").zipWithIndex.foreach { case (repo, i) => - val brr: IBiblioResolver = new IBiblioResolver - brr.setM2compatible(true) - brr.setUsepoms(true) - brr.setRoot(repo) - brr.setName(s"repo-${i + 1}") - cr.add(brr) - printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") - } - } cr } @@ -875,11 +840,7 @@ private[spark] object SparkSubmitUtils { ivyConfName: String, md: DefaultModuleDescriptor): Unit = { // Add scala exclusion rule - val scalaArtifacts = new ArtifactId(new ModuleId("*", "scala-library"), "*", "*", "*") - val scalaDependencyExcludeRule = - new DefaultExcludeRule(scalaArtifacts, ivySettings.getMatcher("glob"), null) - scalaDependencyExcludeRule.addConfiguration(ivyConfName) - md.addExcludeRule(scalaDependencyExcludeRule) + md.addExcludeRule(createExclusion("*:scala-library:*", ivySettings, ivyConfName)) // We need to specify each component explicitly, otherwise we miss spark-streaming-kafka and // other spark-streaming utility components. Underscore is there to differentiate between @@ -888,13 +849,8 @@ private[spark] object SparkSubmitUtils { "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_") components.foreach { comp => - val sparkArtifacts = - new ArtifactId(new ModuleId("org.apache.spark", s"spark-$comp*"), "*", "*", "*") - val sparkDependencyExcludeRule = - new DefaultExcludeRule(sparkArtifacts, ivySettings.getMatcher("glob"), null) - sparkDependencyExcludeRule.addConfiguration(ivyConfName) - - md.addExcludeRule(sparkDependencyExcludeRule) + md.addExcludeRule(createExclusion(s"org.apache.spark:spark-$comp*:*", ivySettings, + ivyConfName)) } } @@ -907,6 +863,7 @@ private[spark] object SparkSubmitUtils { * @param coordinates Comma-delimited string of maven coordinates * @param remoteRepos Comma-delimited string of remote repositories other than maven central * @param ivyPath The path to the local ivy repository + * @param exclusions Exclusions to apply when resolving transitive dependencies * @return The comma-delimited path to the jars of the given maven artifacts including their * transitive dependencies */ @@ -914,6 +871,7 @@ private[spark] object SparkSubmitUtils { coordinates: String, remoteRepos: Option[String], ivyPath: Option[String], + exclusions: Seq[String] = Nil, isTest: Boolean = false): String = { if (coordinates == null || coordinates.trim.isEmpty) { "" @@ -964,6 +922,15 @@ private[spark] object SparkSubmitUtils { // A Module descriptor must be specified. Entries are dummy strings val md = getModuleDescriptor + // clear ivy resolution from previous launches. The resolution file is usually at + // ~/.ivy2/org.apache.spark-spark-submit-parent-default.xml. In between runs, this file + // leads to confusion with Ivy when the files can no longer be found at the repository + // declared in that file/ + val mdId = md.getModuleRevisionId + val previousResolution = new File(ivySettings.getDefaultCache, + s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml") + if (previousResolution.exists) previousResolution.delete + md.setDefaultConf(ivyConfName) // Add exclusion rules for Spark and Scala Library @@ -971,6 +938,10 @@ private[spark] object SparkSubmitUtils { // add all supplied maven artifacts as dependencies addDependenciesToIvy(md, artifacts, ivyConfName) + exclusions.foreach { e => + md.addExcludeRule(createExclusion(e + ":*", ivySettings, ivyConfName)) + } + // resolve dependencies val rr: ResolveReport = ivy.resolve(md, resolveOptions) if (rr.hasError) { @@ -987,6 +958,18 @@ private[spark] object SparkSubmitUtils { } } } + + private def createExclusion( + coords: String, + ivySettings: IvySettings, + ivyConfName: String): ExcludeRule = { + val c = extractMavenCoordinates(coords)(0) + val id = new ArtifactId(new ModuleId(c.groupId, c.artifactId), "*", "*", "*") + val rule = new DefaultExcludeRule(id, ivySettings.getMatcher("glob"), null) + rule.addConfiguration(ivyConfName) + rule + } + } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index 0a1d60f58bc58..45a3f43045437 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConversions._ import scala.collection.Map import org.apache.spark.Logging +import org.apache.spark.SecurityManager import org.apache.spark.deploy.Command import org.apache.spark.launcher.WorkerCommandBuilder import org.apache.spark.util.Utils @@ -40,12 +41,14 @@ object CommandUtils extends Logging { */ def buildProcessBuilder( command: Command, + securityMgr: SecurityManager, memory: Int, sparkHome: String, substituteArguments: String => String, classPaths: Seq[String] = Seq[String](), env: Map[String, String] = sys.env): ProcessBuilder = { - val localCommand = buildLocalCommand(command, substituteArguments, classPaths, env) + val localCommand = buildLocalCommand( + command, securityMgr, substituteArguments, classPaths, env) val commandSeq = buildCommandSeq(localCommand, memory, sparkHome) val builder = new ProcessBuilder(commandSeq: _*) val environment = builder.environment() @@ -69,6 +72,7 @@ object CommandUtils extends Logging { */ private def buildLocalCommand( command: Command, + securityMgr: SecurityManager, substituteArguments: String => String, classPath: Seq[String] = Seq[String](), env: Map[String, String]): Command = { @@ -76,20 +80,26 @@ object CommandUtils extends Logging { val libraryPathEntries = command.libraryPathEntries val cmdLibraryPath = command.environment.get(libraryPathName) - val newEnvironment = if (libraryPathEntries.nonEmpty && libraryPathName.nonEmpty) { + var newEnvironment = if (libraryPathEntries.nonEmpty && libraryPathName.nonEmpty) { val libraryPaths = libraryPathEntries ++ cmdLibraryPath ++ env.get(libraryPathName) command.environment + ((libraryPathName, libraryPaths.mkString(File.pathSeparator))) } else { command.environment } + // set auth secret to env variable if needed + if (securityMgr.isAuthenticationEnabled) { + newEnvironment += (SecurityManager.ENV_AUTH_SECRET -> securityMgr.getSecretKey) + } + Command( command.mainClass, command.arguments.map(substituteArguments), newEnvironment, command.classPathEntries ++ classPath, Seq[String](), // library path already captured in environment variable - command.javaOpts) + // filter out auth secret from java options + command.javaOpts.filterNot(_.startsWith("-D" + SecurityManager.SPARK_AUTH_SECRET_CONF))) } /** Spawn a thread that will redirect a given stream to a file */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 856f1c36a9a51..ec51c3d935d8e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -85,8 +85,8 @@ private[deploy] class DriverRunner( } // TODO: If we add ability to submit multiple jars they should also be added here - val builder = CommandUtils.buildProcessBuilder(driverDesc.command, driverDesc.mem, - sparkHome.getAbsolutePath, substituteVariables) + val builder = CommandUtils.buildProcessBuilder(driverDesc.command, securityManager, + driverDesc.mem, sparkHome.getAbsolutePath, substituteVariables) launchDriver(builder, driverDir, driverDesc.supervise) } catch { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 295fe42b34b72..29a5042285578 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -25,7 +25,7 @@ import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged import org.apache.spark.util.Utils @@ -125,8 +125,8 @@ private[deploy] class ExecutorRunner( private def fetchAndRunExecutor() { try { // Launch the process - val builder = CommandUtils.buildProcessBuilder(appDesc.command, memory, - sparkHome.getAbsolutePath, substituteVariables) + val builder = CommandUtils.buildProcessBuilder(appDesc.command, new SecurityManager(conf), + memory, sparkHome.getAbsolutePath, substituteVariables) val command = builder.command() logInfo("Launch command: " + command.mkString("\"", "\" \"", "\"")) diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 38b61d7242fce..a3b4561b07e7f 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -94,8 +94,8 @@ class TaskMetrics extends Serializable { */ private var _diskBytesSpilled: Long = _ def diskBytesSpilled: Long = _diskBytesSpilled - def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value - def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value + private[spark] def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value + private[spark] def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value /** * If this task reads from a HadoopRDD or from persisted data, metrics on how much data was read diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index a4715e3437d94..33e6998b2cb10 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -21,13 +21,12 @@ import java.io.IOException import scala.reflect.ClassTag -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} @@ -38,7 +37,7 @@ private[spark] class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) extends RDD[T](sc, Nil) { - val broadcastedConf = sc.broadcast(new SerializableWritable(sc.hadoopConfiguration)) + val broadcastedConf = sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration)) @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) @@ -87,7 +86,7 @@ private[spark] object CheckpointRDD extends Logging { def writeToFile[T: ClassTag]( path: String, - broadcastedConf: Broadcast[SerializableWritable[Configuration]], + broadcastedConf: Broadcast[SerializableConfiguration], blockSize: Int = -1 )(ctx: TaskContext, iterator: Iterator[T]) { val env = SparkEnv.get @@ -135,7 +134,7 @@ private[spark] object CheckpointRDD extends Logging { def readFromFile[T]( path: Path, - broadcastedConf: Broadcast[SerializableWritable[Configuration]], + broadcastedConf: Broadcast[SerializableConfiguration], context: TaskContext ): Iterator[T] = { val env = SparkEnv.get @@ -164,7 +163,7 @@ private[spark] object CheckpointRDD extends Logging { val path = new Path(hdfsPath, "temp") val conf = SparkHadoopUtil.get.newConfiguration(new SparkConf()) val fs = path.getFileSystem(conf) - val broadcastedConf = sc.broadcast(new SerializableWritable(conf)) + val broadcastedConf = sc.broadcast(new SerializableConfiguration(conf)) sc.runJob(rdd, CheckpointRDD.writeToFile[Int](path.toString, broadcastedConf, 1024) _) val cpRDD = new CheckpointRDD[Int](sc, path.toString) assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same") diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 2cefe63d44b20..bee59a437f120 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -44,7 +44,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{NextIterator, Utils} +import org.apache.spark.util.{SerializableConfiguration, NextIterator, Utils} import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation} import org.apache.spark.storage.StorageLevel @@ -100,7 +100,7 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp @DeveloperApi class HadoopRDD[K, V]( @transient sc: SparkContext, - broadcastedConf: Broadcast[SerializableWritable[Configuration]], + broadcastedConf: Broadcast[SerializableConfiguration], initLocalJobConfFuncOpt: Option[JobConf => Unit], inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], @@ -121,8 +121,8 @@ class HadoopRDD[K, V]( minPartitions: Int) = { this( sc, - sc.broadcast(new SerializableWritable(conf)) - .asInstanceOf[Broadcast[SerializableWritable[Configuration]]], + sc.broadcast(new SerializableConfiguration(conf)) + .asInstanceOf[Broadcast[SerializableConfiguration]], None /* initLocalJobConfFuncOpt */, inputFormatClass, keyClass, diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 84456d6d868dc..f827270ee6a44 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -33,7 +33,7 @@ import org.apache.spark._ import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.storage.StorageLevel @@ -74,7 +74,7 @@ class NewHadoopRDD[K, V]( with Logging { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it - private val confBroadcast = sc.broadcast(new SerializableWritable(conf)) + private val confBroadcast = sc.broadcast(new SerializableConfiguration(conf)) // private val serializableConf = new SerializableWritable(conf) private val jobTrackerId: String = { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index cfd3e26faf2b9..91a6a2d039852 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -44,7 +44,7 @@ import org.apache.spark.executor.{DataWriteMethod, OutputMetrics} import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.util.random.StratifiedSamplingUtils @@ -1002,7 +1002,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = self.id - val wrappedConf = new SerializableWritable(job.getConfiguration) + val wrappedConf = new SerializableConfiguration(job.getConfiguration) val outfmt = job.getOutputFormatClass val jobFormat = outfmt.newInstance @@ -1065,7 +1065,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def saveAsHadoopDataset(conf: JobConf): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf - val wrappedConf = new SerializableWritable(hadoopConf) + val wrappedConf = new SerializableConfiguration(hadoopConf) val outputFormatInstance = hadoopConf.getOutputFormat val keyClass = hadoopConf.getOutputKeyClass val valueClass = hadoopConf.getOutputValueClass diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 1722c27e55003..acbd31aacdf59 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask} +import org.apache.spark.util.SerializableConfiguration /** * Enumeration to manage state transitions of an RDD through checkpointing @@ -91,7 +92,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) // Save to file, and reload it as an RDD val broadcastedConf = rdd.context.broadcast( - new SerializableWritable(rdd.context.hadoopConfiguration)) + new SerializableConfiguration(rdd.context.hadoopConfiguration)) val newRDD = new CheckpointRDD[T](rdd.context, path.toString) if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) { rdd.context.cleaner.foreach { cleaner => diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala index 6b09dfafc889c..44667281c1063 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala @@ -95,10 +95,9 @@ private[spark] object RDDOperationScope extends Logging { private[spark] def withScope[T]( sc: SparkContext, allowNesting: Boolean = false)(body: => T): T = { - val stackTrace = Thread.currentThread.getStackTrace().tail // ignore "Thread#getStackTrace" - val ourMethodName = stackTrace(1).getMethodName // i.e. withScope - // Climb upwards to find the first method that's called something different - val callerMethodName = stackTrace + val ourMethodName = "withScope" + val callerMethodName = Thread.currentThread.getStackTrace() + .dropWhile(_.getMethodName != ourMethodName) .find(_.getMethodName != ourMethodName) .map(_.getMethodName) .getOrElse { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 75a567fb31520..a7cf0c23d9613 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -81,6 +81,8 @@ class DAGScheduler( def this(sc: SparkContext) = this(sc, sc.taskScheduler) + private[scheduler] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this) + private[scheduler] val nextJobId = new AtomicInteger(0) private[scheduler] def numTotalJobs: Int = nextJobId.get() private val nextStageId = new AtomicInteger(0) @@ -137,6 +139,22 @@ class DAGScheduler( private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) + // Flag to control if reduce tasks are assigned preferred locations + private val shuffleLocalityEnabled = + sc.getConf.getBoolean("spark.shuffle.reduceLocality.enabled", true) + // Number of map, reduce tasks above which we do not assign preferred locations + // based on map output sizes. We limit the size of jobs for which assign preferred locations + // as computing the top locations by size becomes expensive. + private[this] val SHUFFLE_PREF_MAP_THRESHOLD = 1000 + // NOTE: This should be less than 2000 as we use HighlyCompressedMapStatus beyond that + private[this] val SHUFFLE_PREF_REDUCE_THRESHOLD = 1000 + + // Fraction of total map output that must be at a location for it to considered as a preferred + // location for a reduce task. + // Making this larger will focus on fewer locations where most data can be read locally, but + // may lead to more delay in scheduling if those locations are busy. + private[scheduler] val REDUCER_PREF_LOCS_FRACTION = 0.2 + // Called by TaskScheduler to report task's starting. def taskStarted(task: Task[_], taskInfo: TaskInfo) { eventProcessLoop.post(BeginEvent(task, taskInfo)) @@ -889,22 +907,29 @@ class DAGScheduler( return } - val tasks: Seq[Task[_]] = stage match { - case stage: ShuffleMapStage => - partitionsToCompute.map { id => - val locs = getPreferredLocs(stage.rdd, id) - val part = stage.rdd.partitions(id) - new ShuffleMapTask(stage.id, taskBinary, part, locs) - } + val tasks: Seq[Task[_]] = try { + stage match { + case stage: ShuffleMapStage => + partitionsToCompute.map { id => + val locs = getPreferredLocs(stage.rdd, id) + val part = stage.rdd.partitions(id) + new ShuffleMapTask(stage.id, taskBinary, part, locs) + } - case stage: ResultStage => - val job = stage.resultOfJob.get - partitionsToCompute.map { id => - val p: Int = job.partitions(id) - val part = stage.rdd.partitions(p) - val locs = getPreferredLocs(stage.rdd, p) - new ResultTask(stage.id, taskBinary, part, locs, id) - } + case stage: ResultStage => + val job = stage.resultOfJob.get + partitionsToCompute.map { id => + val p: Int = job.partitions(id) + val part = stage.rdd.partitions(p) + val locs = getPreferredLocs(stage.rdd, p) + new ResultTask(stage.id, taskBinary, part, locs, id) + } + } + } catch { + case NonFatal(e) => + abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}") + runningStages -= stage + return } if (tasks.size > 0) { @@ -1384,17 +1409,32 @@ class DAGScheduler( if (rddPrefs.nonEmpty) { return rddPrefs.map(TaskLocation(_)) } - // If the RDD has narrow dependencies, pick the first partition of the first narrow dep - // that has any placement preferences. Ideally we would choose based on transfer sizes, - // but this will do for now. + rdd.dependencies.foreach { case n: NarrowDependency[_] => + // If the RDD has narrow dependencies, pick the first partition of the first narrow dep + // that has any placement preferences. Ideally we would choose based on transfer sizes, + // but this will do for now. for (inPart <- n.getParents(partition)) { val locs = getPreferredLocsInternal(n.rdd, inPart, visited) if (locs != Nil) { return locs } } + case s: ShuffleDependency[_, _, _] => + // For shuffle dependencies, pick locations which have at least REDUCER_PREF_LOCS_FRACTION + // of data as preferred locations + if (shuffleLocalityEnabled && + rdd.partitions.size < SHUFFLE_PREF_REDUCE_THRESHOLD && + s.rdd.partitions.size < SHUFFLE_PREF_MAP_THRESHOLD) { + // Get the preferred map output locations for this reducer + val topLocsForReducer = mapOutputTracker.getLocationsWithLargestOutputs(s.shuffleId, + partition, rdd.partitions.size, REDUCER_PREF_LOCS_FRACTION) + if (topLocsForReducer.nonEmpty) { + return topLocsForReducer.get.map(loc => TaskLocation(loc.host, loc.executorId)) + } + } + case _ => } Nil @@ -1407,17 +1447,29 @@ class DAGScheduler( taskScheduler.stop() } - // Start the event thread at the end of the constructor + // Start the event thread and register the metrics source at the end of the constructor + env.metricsSystem.registerSource(metricsSource) eventProcessLoop.start() } private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler) extends EventLoop[DAGSchedulerEvent]("dag-scheduler-event-loop") with Logging { + private[this] val timer = dagScheduler.metricsSource.messageProcessingTimer + /** * The main event loop of the DAG scheduler. */ - override def onReceive(event: DAGSchedulerEvent): Unit = event match { + override def onReceive(event: DAGSchedulerEvent): Unit = { + val timerContext = timer.time() + try { + doOnReceive(event) + } finally { + timerContext.stop() + } + } + + private def doOnReceive(event: DAGSchedulerEvent): Unit = event match { case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala index 02c67073af6a0..6b667d5d7645b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala @@ -17,11 +17,11 @@ package org.apache.spark.scheduler -import com.codahale.metrics.{Gauge, MetricRegistry} +import com.codahale.metrics.{Gauge, MetricRegistry, Timer} import org.apache.spark.metrics.source.Source -private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler) +private[scheduler] class DAGSchedulerSource(val dagScheduler: DAGScheduler) extends Source { override val metricRegistry = new MetricRegistry() override val sourceName = "DAGScheduler" @@ -45,4 +45,8 @@ private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler) metricRegistry.register(MetricRegistry.name("job", "activeJobs"), new Gauge[Int] { override def getValue: Int = dagScheduler.activeJobs.size }) + + /** Timer that tracks the time to process messages in the DAGScheduler's event loop */ + val messageProcessingTimer: Timer = + metricRegistry.timer(MetricRegistry.name("messageProcessingTime")) } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index cd8a82347a1e9..ed35cffe968f8 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -36,7 +36,7 @@ import org.apache.spark.network.nio.{GetBlock, GotBlock, PutBlock} import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ -import org.apache.spark.util.BoundedPriorityQueue +import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf} import org.apache.spark.util.collection.CompactBuffer /** @@ -94,8 +94,10 @@ class KryoSerializer(conf: SparkConf) // For results returned by asJavaIterable. See JavaIterableWrapperSerializer. kryo.register(JavaIterableWrapperSerializer.wrapperClass, new JavaIterableWrapperSerializer) - // Allow sending SerializableWritable + // Allow sending classes with custom Java serializers kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) + kryo.register(classOf[SerializableConfiguration], new KryoJavaSerializer()) + kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer()) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala index bb5db545531d2..cc2f0506817d3 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -17,7 +17,7 @@ package org.apache.spark.serializer -import java.io.{NotSerializableException, ObjectOutput, ObjectStreamClass, ObjectStreamField} +import java.io._ import java.lang.reflect.{Field, Method} import java.security.AccessController @@ -62,7 +62,7 @@ private[spark] object SerializationDebugger extends Logging { * * It does not yet handle writeObject override, but that shouldn't be too hard to do either. */ - def find(obj: Any): List[String] = { + private[serializer] def find(obj: Any): List[String] = { new SerializationDebugger().visit(obj, List.empty) } @@ -125,6 +125,12 @@ private[spark] object SerializationDebugger extends Logging { return List.empty } + /** + * Visit an externalizable object. + * Since writeExternal() can choose to add arbitrary objects at the time of serialization, + * the only way to capture all the objects it will serialize is by using a + * dummy ObjectOutput that collects all the relevant objects for further testing. + */ private def visitExternalizable(o: java.io.Externalizable, stack: List[String]): List[String] = { val fieldList = new ListObjectOutput @@ -145,17 +151,50 @@ private[spark] object SerializationDebugger extends Logging { // An object contains multiple slots in serialization. // Get the slots and visit fields in all of them. val (finalObj, desc) = findObjectAndDescriptor(o) + + // If the object has been replaced using writeReplace(), + // then call visit() on it again to test its type again. + if (!finalObj.eq(o)) { + return visit(finalObj, s"writeReplace data (class: ${finalObj.getClass.getName})" :: stack) + } + + // Every class is associated with one or more "slots", each slot refers to the parent + // classes of this class. These slots are used by the ObjectOutputStream + // serialization code to recursively serialize the fields of an object and + // its parent classes. For example, if there are the following classes. + // + // class ParentClass(parentField: Int) + // class ChildClass(childField: Int) extends ParentClass(1) + // + // Then serializing the an object Obj of type ChildClass requires first serializing the fields + // of ParentClass (that is, parentField), and then serializing the fields of ChildClass + // (that is, childField). Correspondingly, there will be two slots related to this object: + // + // 1. ParentClass slot, which will be used to serialize parentField of Obj + // 2. ChildClass slot, which will be used to serialize childField fields of Obj + // + // The following code uses the description of each slot to find the fields in the + // corresponding object to visit. + // val slotDescs = desc.getSlotDescs var i = 0 while (i < slotDescs.length) { val slotDesc = slotDescs(i) if (slotDesc.hasWriteObjectMethod) { - // TODO: Handle classes that specify writeObject method. + // If the class type corresponding to current slot has writeObject() defined, + // then its not obvious which fields of the class will be serialized as the writeObject() + // can choose arbitrary fields for serialization. This case is handled separately. + val elem = s"writeObject data (class: ${slotDesc.getName})" + val childStack = visitSerializableWithWriteObjectMethod(finalObj, elem :: stack) + if (childStack.nonEmpty) { + return childStack + } } else { + // Visit all the fields objects of the class corresponding to the current slot. val fields: Array[ObjectStreamField] = slotDesc.getFields val objFieldValues: Array[Object] = new Array[Object](slotDesc.getNumObjFields) val numPrims = fields.length - objFieldValues.length - desc.getObjFieldValues(finalObj, objFieldValues) + slotDesc.getObjFieldValues(finalObj, objFieldValues) var j = 0 while (j < objFieldValues.length) { @@ -169,18 +208,54 @@ private[spark] object SerializationDebugger extends Logging { } j += 1 } - } i += 1 } return List.empty } + + /** + * Visit a serializable object which has the writeObject() defined. + * Since writeObject() can choose to add arbitrary objects at the time of serialization, + * the only way to capture all the objects it will serialize is by using a + * dummy ObjectOutputStream that collects all the relevant fields for further testing. + * This is similar to how externalizable objects are visited. + */ + private def visitSerializableWithWriteObjectMethod( + o: Object, stack: List[String]): List[String] = { + val innerObjectsCatcher = new ListObjectOutputStream + var notSerializableFound = false + try { + innerObjectsCatcher.writeObject(o) + } catch { + case io: IOException => + notSerializableFound = true + } + + // If something was not serializable, then visit the captured objects. + // Otherwise, all the captured objects are safely serializable, so no need to visit them. + // As an optimization, just added them to the visited list. + if (notSerializableFound) { + val innerObjects = innerObjectsCatcher.outputArray + var k = 0 + while (k < innerObjects.length) { + val childStack = visit(innerObjects(k), stack) + if (childStack.nonEmpty) { + return childStack + } + k += 1 + } + } else { + visited ++= innerObjectsCatcher.outputArray + } + return List.empty + } } /** * Find the object to serialize and the associated [[ObjectStreamClass]]. This method handles * writeReplace in Serializable. It starts with the object itself, and keeps calling the - * writeReplace method until there is no more + * writeReplace method until there is no more. */ @tailrec private def findObjectAndDescriptor(o: Object): (Object, ObjectStreamClass) = { @@ -220,6 +295,31 @@ private[spark] object SerializationDebugger extends Logging { override def writeByte(i: Int): Unit = {} } + /** An output stream that emulates /dev/null */ + private class NullOutputStream extends OutputStream { + override def write(b: Int) { } + } + + /** + * A dummy [[ObjectOutputStream]] that saves the list of objects written to it and returns + * them through `outputArray`. This works by using the [[ObjectOutputStream]]'s `replaceObject()` + * method which gets called on every object, only if replacing is enabled. So this subclass + * of [[ObjectOutputStream]] enabled replacing, and uses replaceObject to get the objects that + * are being serializabled. The serialized bytes are ignored by sending them to a + * [[NullOutputStream]], which acts like a /dev/null. + */ + private class ListObjectOutputStream extends ObjectOutputStream(new NullOutputStream) { + private val output = new mutable.ArrayBuffer[Any] + this.enableReplaceObject(true) + + def outputArray: Array[Any] = output.toArray + + override def replaceObject(obj: Object): Object = { + output += obj + obj + } + } + /** An implicit class that allows us to call private methods of ObjectStreamClass. */ implicit class ObjectStreamClassMethods(val desc: ObjectStreamClass) extends AnyVal { def getSlotDescs: Array[ObjectStreamClass] = { diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index f1bdff96d3df1..bd2704dc81871 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -182,6 +182,7 @@ abstract class DeserializationStream { } catch { case eof: EOFException => finished = true + null } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 597d46a3d2223..9d8e7e9f03aea 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -17,29 +17,29 @@ package org.apache.spark.shuffle.hash -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.util.{Failure, Success, Try} +import java.io.InputStream + +import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.util.{Failure, Success} import org.apache.spark._ -import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} -import org.apache.spark.util.CompletionIterator +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator, + ShuffleBlockId} private[hash] object BlockStoreShuffleFetcher extends Logging { - def fetch[T]( + def fetchBlockStreams( shuffleId: Int, reduceId: Int, context: TaskContext, - serializer: Serializer) - : Iterator[T] = + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker) + : Iterator[(BlockId, InputStream)] = { logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) - val blockManager = SparkEnv.get.blockManager val startTime = System.currentTimeMillis - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId) + val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId) logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format( shuffleId, reduceId, System.currentTimeMillis - startTime)) @@ -53,12 +53,21 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) } - def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = { + val blockFetcherItr = new ShuffleBlockFetcherIterator( + context, + blockManager.shuffleClient, + blockManager, + blocksByAddress, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) + + // Make sure that fetch failures are wrapped inside a FetchFailedException for the scheduler + blockFetcherItr.map { blockPair => val blockId = blockPair._1 val blockOption = blockPair._2 blockOption match { - case Success(block) => { - block.asInstanceOf[Iterator[T]] + case Success(inputStream) => { + (blockId, inputStream) } case Failure(e) => { blockId match { @@ -72,27 +81,5 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { } } } - - val blockFetcherItr = new ShuffleBlockFetcherIterator( - context, - SparkEnv.get.blockManager.shuffleClient, - blockManager, - blocksByAddress, - serializer, - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) - val itr = blockFetcherItr.flatMap(unpackBlock) - - val completionIter = CompletionIterator[T, Iterator[T]](itr, { - context.taskMetrics.updateShuffleReadMetrics() - }) - - new InterruptibleIterator[T](context, completionIter) { - val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() - override def next(): T = { - readMetrics.incRecordsRead(1) - delegate.next() - } - } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 41bafabde05b9..d5c9880659dd3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,16 +17,20 @@ package org.apache.spark.shuffle.hash -import org.apache.spark.{InterruptibleIterator, TaskContext} +import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} +import org.apache.spark.storage.BlockManager +import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter private[spark] class HashShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, - context: TaskContext) + context: TaskContext, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) extends ShuffleReader[K, C] { require(endPartition == startPartition + 1, @@ -36,20 +40,52 @@ private[spark] class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { + val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams( + handle.shuffleId, startPartition, context, blockManager, mapOutputTracker) + + // Wrap the streams for compression based on configuration + val wrappedStreams = blockStreams.map { case (blockId, inputStream) => + blockManager.wrapForCompression(blockId, inputStream) + } + val ser = Serializer.getSerializer(dep.serializer) - val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser) + val serializerInstance = ser.newInstance() + + // Create a key/value iterator for each stream + val recordIter = wrappedStreams.flatMap { wrappedStream => + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + } + + // Update the context task metrics for each record read. + val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map(record => { + readMetrics.incRecordsRead(1) + record + }), + context.taskMetrics().updateShuffleReadMetrics()) + + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { - new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context)) + // We are reading values that are already combined + val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) } else { - new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context)) + // We don't know the value type, but also don't care -- the dependency *should* + // have made sure its compatible w/ this aggregator, which will convert the value + // type to the combined type C + val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) } } else { require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") - - // Convert the Product2s to pairs since this is what downstream RDDs currently expect - iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2)) + interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] } // Sort the output if there is a sort ordering defined. diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala index f2bfef376d3ca..df7bbd64247dd 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -56,9 +56,6 @@ private[spark] object UnsafeShuffleManager extends Logging { } else if (dependency.aggregator.isDefined) { log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined") false - } else if (dependency.keyOrdering.isDefined) { - log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because a key ordering is defined") - false } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) { log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " + s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 5048c7dab240b..1beafa1771448 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -83,8 +83,10 @@ private[spark] class BlockManager( private var externalBlockStoreInitialized = false private[spark] val memoryStore = new MemoryStore(this, maxMemory) private[spark] val diskStore = new DiskStore(this, diskBlockManager) - private[spark] lazy val externalBlockStore: ExternalBlockStore = + private[spark] lazy val externalBlockStore: ExternalBlockStore = { + externalBlockStoreInitialized = true new ExternalBlockStore(this, executorId) + } private[spark] val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index d0faab62c9e9e..e49e39679e940 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,23 +17,23 @@ package org.apache.spark.storage +import java.io.InputStream import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} import scala.util.{Failure, Try} import org.apache.spark.{Logging, TaskContext} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.serializer.{SerializerInstance, Serializer} -import org.apache.spark.util.{CompletionIterator, Utils} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.util.Utils /** * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block * manager. For remote blocks, it fetches them using the provided BlockTransferService. * - * This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a - * pipelined fashion as they are received. + * This creates an iterator of (BlockID, Try[InputStream]) tuples so the caller can handle blocks + * in a pipelined fashion as they are received. * * The implementation throttles the remote fetches to they don't exceed maxBytesInFlight to avoid * using too much memory. @@ -44,7 +44,6 @@ import org.apache.spark.util.{CompletionIterator, Utils} * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in * order to throttle the memory usage. - * @param serializer serializer used to deserialize the data. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. */ private[spark] @@ -53,9 +52,8 @@ final class ShuffleBlockFetcherIterator( shuffleClient: ShuffleClient, blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer, maxBytesInFlight: Long) - extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging { + extends Iterator[(BlockId, Try[InputStream])] with Logging { import ShuffleBlockFetcherIterator._ @@ -83,7 +81,7 @@ final class ShuffleBlockFetcherIterator( /** * A queue to hold our results. This turns the asynchronous model provided by - * [[BlockTransferService]] into a synchronous model (iterator). + * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). */ private[this] val results = new LinkedBlockingQueue[FetchResult] @@ -102,9 +100,7 @@ final class ShuffleBlockFetcherIterator( /** Current bytes in flight from our requests */ private[this] var bytesInFlight = 0L - private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() - - private[this] val serializerInstance: SerializerInstance = serializer.newInstance() + private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency() /** * Whether the iterator is still active. If isZombie is true, the callback interface will no @@ -114,17 +110,23 @@ final class ShuffleBlockFetcherIterator( initialize() - /** - * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. - */ - private[this] def cleanup() { - isZombie = true + // Decrements the buffer reference count. + // The currentResult is set to null to prevent releasing the buffer again on cleanup() + private[storage] def releaseCurrentResultBuffer(): Unit = { // Release the current buffer if necessary currentResult match { case SuccessFetchResult(_, _, buf) => buf.release() case _ => } + currentResult = null + } + /** + * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. + */ + private[this] def cleanup() { + isZombie = true + releaseCurrentResultBuffer() // Release buffers in the results queue val iter = results.iterator() while (iter.hasNext) { @@ -272,7 +274,13 @@ final class ShuffleBlockFetcherIterator( override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch - override def next(): (BlockId, Try[Iterator[Any]]) = { + /** + * Fetches the next (BlockId, Try[InputStream]). If a task fails, the ManagedBuffers + * underlying each InputStream will be freed by the cleanup() method registered with the + * TaskCompletionListener. However, callers should close() these InputStreams + * as soon as they are no longer needed, in order to release memory as early as possible. + */ + override def next(): (BlockId, Try[InputStream]) = { numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() currentResult = results.take() @@ -290,22 +298,15 @@ final class ShuffleBlockFetcherIterator( sendRequest(fetchRequests.dequeue()) } - val iteratorTry: Try[Iterator[Any]] = result match { + val iteratorTry: Try[InputStream] = result match { case FailureFetchResult(_, e) => Failure(e) case SuccessFetchResult(blockId, _, buf) => // There is a chance that createInputStream can fail (e.g. fetching a local file that does // not exist, SPARK-4085). In that case, we should propagate the right exception so // the scheduler gets a FetchFailedException. - Try(buf.createInputStream()).map { is0 => - val is = blockManager.wrapForCompression(blockId, is0) - val iter = serializerInstance.deserializeStream(is).asKeyValueIterator - CompletionIterator[Any, Iterator[Any]](iter, { - // Once the iterator is exhausted, release the buffer and set currentResult to null - // so we don't release it again in cleanup. - currentResult = null - buf.release() - }) + Try(buf.createInputStream()).map { inputStream => + new BufferReleasingInputStream(inputStream, this) } } @@ -313,6 +314,39 @@ final class ShuffleBlockFetcherIterator( } } +/** + * Helper class that ensures a ManagedBuffer is release upon InputStream.close() + */ +private class BufferReleasingInputStream( + private val delegate: InputStream, + private val iterator: ShuffleBlockFetcherIterator) + extends InputStream { + private[this] var closed = false + + override def read(): Int = delegate.read() + + override def close(): Unit = { + if (!closed) { + delegate.close() + iterator.releaseCurrentResultBuffer() + closed = true + } + } + + override def available(): Int = delegate.available() + + override def mark(readlimit: Int): Unit = delegate.mark(readlimit) + + override def skip(n: Long): Long = delegate.skip(n) + + override def markSupported(): Boolean = delegate.markSupported() + + override def read(b: Array[Byte]): Int = delegate.read(b) + + override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len) + + override def reset(): Unit = delegate.reset() +} private[storage] object ShuffleBlockFetcherIterator { diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 063e2a1f8b18e..e2d25e36365fa 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -35,6 +35,10 @@ private[spark] object ToolTips { val OUTPUT = "Bytes and records written to Hadoop." + val STORAGE_MEMORY = + "Memory used / total available memory for storage of data " + + "like RDD partitions cached in memory. " + val SHUFFLE_WRITE = "Bytes and records written to disk in order to be read by a shuffle in a future stage." diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 65162f4fdcd62..7898039519201 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -362,7 +362,7 @@ private[spark] object UIUtils extends Logging { { g.incomingEdges.map { e =>
{e.fromId},{e.toId}
} } { g.outgoingEdges.map { e =>
{e.fromId},{e.toId}
} } { - g.rootCluster.getAllNodes.filter(_.cached).map { n => + g.rootCluster.getCachedNodes.map { n =>
{n.id}
} } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index b247e4cdc3bd4..01cddda4c62cd 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -67,7 +67,7 @@ private[ui] class ExecutorsPage( Executor ID Address RDD Blocks - Memory Used + Storage Memory Disk Used Active Tasks Failed Tasks diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 730f9806e518e..0c854f04890b6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -539,11 +539,11 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { /** * For testing only. Wait until at least `numExecutors` executors are up, or throw * `TimeoutException` if the waiting time elapsed before `numExecutors` executors up. + * Exposed for testing. * * @param numExecutors the number of executors to wait at least * @param timeout time to wait in milliseconds */ - @VisibleForTesting private[spark] def waitUntilExecutorsUp(numExecutors: Int, timeout: Long): Unit = { val finishTime = System.currentTimeMillis() + timeout while (System.currentTimeMillis() < finishTime) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index b83a49f79c8a8..e96bf49d0dd14 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -572,55 +572,55 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val attempt = taskInfo.attempt val timelineObject = s""" - { - 'className': 'task task-assignment-timeline-object', - 'group': '$executorId', - 'content': '
' + - 'Status: ${taskInfo.status}
' + - 'Launch Time: ${UIUtils.formatDate(new Date(launchTime))}' + - '${ + |{ + |'className': 'task task-assignment-timeline-object', + |'group': '$executorId', + |'content': '
+ |Status: ${taskInfo.status}
+ |Launch Time: ${UIUtils.formatDate(new Date(launchTime))} + |${ if (!taskInfo.running) { s"""
Finish Time: ${UIUtils.formatDate(new Date(finishTime))}""" } else { "" } - }' + - '
Scheduler Delay: $schedulerDelay ms' + - '
Task Deserialization Time: ${UIUtils.formatDuration(deserializationTime)}' + - '
Shuffle Read Time: ${UIUtils.formatDuration(shuffleReadTime)}' + - '
Executor Computing Time: ${UIUtils.formatDuration(executorComputingTime)}' + - '
Shuffle Write Time: ${UIUtils.formatDuration(shuffleWriteTime)}' + - '
Result Serialization Time: ${UIUtils.formatDuration(serializationTime)}' + - '
Getting Result Time: ${UIUtils.formatDuration(gettingResultTime)}">' + - '' + - '' + - '' + - '' + - '' + - '' + - '' + - '', - 'start': new Date($launchTime), - 'end': new Date($finishTime) - } - """ + } + |
Scheduler Delay: $schedulerDelay ms + |
Task Deserialization Time: ${UIUtils.formatDuration(deserializationTime)} + |
Shuffle Read Time: ${UIUtils.formatDuration(shuffleReadTime)} + |
Executor Computing Time: ${UIUtils.formatDuration(executorComputingTime)} + |
Shuffle Write Time: ${UIUtils.formatDuration(shuffleWriteTime)} + |
Result Serialization Time: ${UIUtils.formatDuration(serializationTime)} + |
Getting Result Time: ${UIUtils.formatDuration(gettingResultTime)}"> + | + | + | + | + | + | + | + |', + |'start': new Date($launchTime), + |'end': new Date($finishTime) + |} + |""".stripMargin.replaceAll("\n", " ") timelineObject }.mkString("[", ",", "]") diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index d6a5085db1efb..ffea9817c0b08 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -66,9 +66,9 @@ private[ui] class RDDOperationCluster(val id: String, private var _name: String) _childClusters += childCluster } - /** Return all the nodes container in this cluster, including ones nested in other clusters. */ - def getAllNodes: Seq[RDDOperationNode] = { - _childNodes ++ _childClusters.flatMap(_.childNodes) + /** Return all the nodes which are cached. */ + def getCachedNodes: Seq[RDDOperationNode] = { + _childNodes.filter(_.cached) ++ _childClusters.flatMap(_.getCachedNodes) } } diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 6f2966bd4fd31..305de4c75539d 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -109,7 +109,14 @@ private[spark] object ClosureCleaner extends Logging { private def createNullValue(cls: Class[_]): AnyRef = { if (cls.isPrimitive) { - new java.lang.Byte(0: Byte) // Should be convertible to any primitive type + cls match { + case java.lang.Boolean.TYPE => new java.lang.Boolean(false) + case java.lang.Character.TYPE => new java.lang.Character('\0') + case java.lang.Void.TYPE => + // This should not happen because `Foo(void x) {}` does not compile. + throw new IllegalStateException("Unexpected void parameter in constructor") + case _ => new java.lang.Byte(0: Byte) + } } else { null } @@ -319,28 +326,17 @@ private[spark] object ClosureCleaner extends Logging { private def instantiateClass( cls: Class[_], enclosingObject: AnyRef): AnyRef = { - if (!Utils.isInInterpreter) { - // This is a bona fide closure class, whose constructor has no effects - // other than to set its fields, so use its constructor - val cons = cls.getConstructors()(0) - val params = cons.getParameterTypes.map(createNullValue).toArray - if (enclosingObject != null) { - params(0) = enclosingObject // First param is always enclosing object - } - return cons.newInstance(params: _*).asInstanceOf[AnyRef] - } else { - // Use reflection to instantiate object without calling constructor - val rf = sun.reflect.ReflectionFactory.getReflectionFactory() - val parentCtor = classOf[java.lang.Object].getDeclaredConstructor() - val newCtor = rf.newConstructorForSerialization(cls, parentCtor) - val obj = newCtor.newInstance().asInstanceOf[AnyRef] - if (enclosingObject != null) { - val field = cls.getDeclaredField("$outer") - field.setAccessible(true) - field.set(obj, enclosingObject) - } - obj + // Use reflection to instantiate object without calling constructor + val rf = sun.reflect.ReflectionFactory.getReflectionFactory() + val parentCtor = classOf[java.lang.Object].getDeclaredConstructor() + val newCtor = rf.newConstructorForSerialization(cls, parentCtor) + val obj = newCtor.newInstance().asInstanceOf[AnyRef] + if (enclosingObject != null) { + val field = cls.getDeclaredField("$outer") + field.setAccessible(true) + field.set(obj, enclosingObject) } + obj } } diff --git a/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala b/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala new file mode 100644 index 0000000000000..30bcf1d2f24d5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +import java.io.{ObjectInputStream, ObjectOutputStream} + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.util.Utils + +private[spark] +class SerializableConfiguration(@transient var value: Configuration) extends Serializable { + private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { + out.defaultWriteObject() + value.write(out) + } + + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { + value = new Configuration(false) + value.readFields(in) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala b/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala new file mode 100644 index 0000000000000..afbcc6efc850c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.{ObjectInputStream, ObjectOutputStream} + +import org.apache.hadoop.mapred.JobConf + +import org.apache.spark.util.Utils + +private[spark] +class SerializableJobConf(@transient var value: JobConf) extends Serializable { + private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { + out.defaultWriteObject() + value.write(out) + } + + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { + value = new JobConf(false) + value.readFields(in) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 153ece6224a6d..a7fc749a2b0c6 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1804,15 +1804,10 @@ private[spark] object Utils extends Logging { lazy val isInInterpreter: Boolean = { try { - val interpClass = classForName("spark.repl.Main") + val interpClass = classForName("org.apache.spark.repl.Main") interpClass.getMethod("interp").invoke(null) != null } catch { - // Returning true seems to be a mistake. - // Currently changing it to false causes tests failures in Streaming. - // For a more detailed discussion, please, refer to - // https://github.com/apache/spark/pull/5835#issuecomment-101042271 and subsequent comments. - // Addressing this changed is tracked as https://issues.apache.org/jira/browse/SPARK-7527 - case _: ClassNotFoundException => true + case _: ClassNotFoundException => false } } @@ -2338,3 +2333,36 @@ private[spark] class RedirectThread( } } } + +/** + * An [[OutputStream]] that will store the last 10 kilobytes (by default) written to it + * in a circular buffer. The current contents of the buffer can be accessed using + * the toString method. + */ +private[spark] class CircularBuffer(sizeInBytes: Int = 10240) extends java.io.OutputStream { + var pos: Int = 0 + var buffer = new Array[Int](sizeInBytes) + + def write(i: Int): Unit = { + buffer(pos) = i + pos = (pos + 1) % buffer.length + } + + override def toString: String = { + val (end, start) = buffer.splitAt(pos) + val input = new java.io.InputStream { + val iterator = (start ++ end).iterator + + def read(): Int = if (iterator.hasNext) iterator.next() else -1 + } + val reader = new BufferedReader(new InputStreamReader(input)) + val stringBuilder = new StringBuilder + var line = reader.readLine() + while (line != null) { + stringBuilder.append(line) + stringBuilder.append("\n") + line = reader.readLine() + } + stringBuilder.toString() + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala index 290282c9c2e28..4c1e16155462e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala @@ -32,12 +32,18 @@ import org.apache.spark.annotation.DeveloperApi * size, which is guaranteed to explore all spaces for each key (see * http://en.wikipedia.org/wiki/Quadratic_probing). * + * The map can support up to `375809638 (0.7 * 2 ^ 29)` elements. + * * TODO: Cache the hash values of each key? java.util.HashMap does that. */ @DeveloperApi class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] with Serializable { - require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") + + import AppendOnlyMap._ + + require(initialCapacity <= MAXIMUM_CAPACITY, + s"Can't make capacity bigger than ${MAXIMUM_CAPACITY} elements") require(initialCapacity >= 1, "Invalid initial capacity") private val LOAD_FACTOR = 0.7 @@ -206,12 +212,9 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) /** Double the table's size and re-hash everything */ protected def growTable() { + // capacity < MAXIMUM_CAPACITY (2 ^ 29) so capacity * 2 won't overflow val newCapacity = capacity * 2 - if (newCapacity >= (1 << 30)) { - // We can't make the table this big because we want an array of 2x - // that size for our data, but array sizes are at most Int.MaxValue - throw new Exception("Can't make capacity bigger than 2^29 elements") - } + require(newCapacity <= MAXIMUM_CAPACITY, s"Can't contain more than ${growThreshold} elements") val newData = new Array[AnyRef](2 * newCapacity) val newMask = newCapacity - 1 // Insert all our old values into the new array. Note that because our old keys are @@ -292,3 +295,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) */ def atGrowThreshold: Boolean = curSize == growThreshold } + +private object AppendOnlyMap { + val MAXIMUM_CAPACITY = (1 << 29) +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 64e7102e3654c..60bf4dd7469f1 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -45,7 +45,8 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( loadFactor: Double) extends Serializable { - require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") + require(initialCapacity <= OpenHashSet.MAX_CAPACITY, + s"Can't make capacity bigger than ${OpenHashSet.MAX_CAPACITY} elements") require(initialCapacity >= 1, "Invalid initial capacity") require(loadFactor < 1.0, "Load factor must be less than 1.0") require(loadFactor > 0.0, "Load factor must be greater than 0.0") @@ -223,6 +224,8 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( */ private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) { val newCapacity = _capacity * 2 + require(newCapacity > 0 && newCapacity <= OpenHashSet.MAX_CAPACITY, + s"Can't contain more than ${(loadFactor * OpenHashSet.MAX_CAPACITY).toInt} elements") allocateFunc(newCapacity) val newBitset = new BitSet(newCapacity) val newData = new Array[T](newCapacity) @@ -276,9 +279,10 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( private[spark] object OpenHashSet { + val MAX_CAPACITY = 1 << 30 val INVALID_POS = -1 - val NONEXISTENCE_MASK = 0x80000000 - val POSITION_MASK = 0xEFFFFFF + val NONEXISTENCE_MASK = 1 << 31 + val POSITION_MASK = (1 << 31) - 1 /** * A set of specialized hash function implementation to avoid boxing hash code computation diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala index 5a6e9a9580e9b..04bb7fc78c13b 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala @@ -25,11 +25,16 @@ import org.apache.spark.util.collection.WritablePartitionedPairCollection._ /** * Append-only buffer of key-value pairs, each with a corresponding partition ID, that keeps track * of its estimated size in bytes. + * + * The buffer can support up to `1073741823 (2 ^ 30 - 1)` elements. */ private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64) extends WritablePartitionedPairCollection[K, V] with SizeTracker { - require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") + import PartitionedPairBuffer._ + + require(initialCapacity <= MAXIMUM_CAPACITY, + s"Can't make capacity bigger than ${MAXIMUM_CAPACITY} elements") require(initialCapacity >= 1, "Invalid initial capacity") // Basic growable array data structure. We use a single array of AnyRef to hold both the keys @@ -51,11 +56,15 @@ private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64) /** Double the size of the array because we've reached capacity */ private def growArray(): Unit = { - if (capacity == (1 << 29)) { - // Doubling the capacity would create an array bigger than Int.MaxValue, so don't - throw new Exception("Can't grow buffer beyond 2^29 elements") + if (capacity >= MAXIMUM_CAPACITY) { + throw new IllegalStateException(s"Can't insert more than ${MAXIMUM_CAPACITY} elements") } - val newCapacity = capacity * 2 + val newCapacity = + if (capacity * 2 < 0 || capacity * 2 > MAXIMUM_CAPACITY) { // Overflow + MAXIMUM_CAPACITY + } else { + capacity * 2 + } val newArray = new Array[AnyRef](2 * newCapacity) System.arraycopy(data, 0, newArray, 0, 2 * capacity) data = newArray @@ -86,3 +95,7 @@ private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64) } } } + +private object PartitionedPairBuffer { + val MAXIMUM_CAPACITY = Int.MaxValue / 2 // 2 ^ 30 - 1 +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala index 862408b7a4d21..ae9a48729e201 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala @@ -48,6 +48,8 @@ import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._ * | keyStart | keyValLen | partitionId | * +-------------+------------+------------+-------------+ * + * The buffer can support up to `536870911 (2 ^ 29 - 1)` records. + * * @param metaInitialRecords The initial number of entries in the metadata buffer. * @param kvBlockSize The size of each byte buffer in the ChainedBuffer used to store the records. * @param serializerInstance the serializer used for serializing inserted records. @@ -63,6 +65,8 @@ private[spark] class PartitionedSerializedPairBuffer[K, V]( " Java-serialized objects.") } + require(metaInitialRecords <= MAXIMUM_RECORDS, + s"Can't make capacity bigger than ${MAXIMUM_RECORDS} records") private var metaBuffer = IntBuffer.allocate(metaInitialRecords * RECORD_SIZE) private val kvBuffer: ChainedBuffer = new ChainedBuffer(kvBlockSize) @@ -89,11 +93,17 @@ private[spark] class PartitionedSerializedPairBuffer[K, V]( /** Double the size of the array because we've reached capacity */ private def growMetaBuffer(): Unit = { - if (metaBuffer.capacity.toLong * 2 > Int.MaxValue) { - // Doubling the capacity would create an array bigger than Int.MaxValue, so don't - throw new Exception(s"Can't grow buffer beyond ${Int.MaxValue} bytes") + if (metaBuffer.capacity >= MAXIMUM_META_BUFFER_CAPACITY) { + throw new IllegalStateException(s"Can't insert more than ${MAXIMUM_RECORDS} records") } - val newMetaBuffer = IntBuffer.allocate(metaBuffer.capacity * 2) + val newCapacity = + if (metaBuffer.capacity * 2 < 0 || metaBuffer.capacity * 2 > MAXIMUM_META_BUFFER_CAPACITY) { + // Overflow + MAXIMUM_META_BUFFER_CAPACITY + } else { + metaBuffer.capacity * 2 + } + val newMetaBuffer = IntBuffer.allocate(newCapacity) newMetaBuffer.put(metaBuffer.array) metaBuffer = newMetaBuffer } @@ -247,12 +257,15 @@ private[spark] class SerializedSortDataFormat extends SortDataFormat[Int, IntBuf } } -private[spark] object PartitionedSerializedPairBuffer { +private object PartitionedSerializedPairBuffer { val KEY_START = 0 // keyStart, a long, gets split across two ints val KEY_VAL_LEN = 2 val PARTITION = 3 val RECORD_SIZE = PARTITION + 1 // num ints of metadata + val MAXIMUM_RECORDS = Int.MaxValue / RECORD_SIZE // (2 ^ 29) - 1 + val MAXIMUM_META_BUFFER_CAPACITY = MAXIMUM_RECORDS * RECORD_SIZE // (2 ^ 31) - 4 + def getKeyStartPos(metaBuffer: IntBuffer, metaBufferPos: Int): Long = { val lower32 = metaBuffer.get(metaBufferPos + KEY_START) val upper32 = metaBuffer.get(metaBufferPos + KEY_START + 1) diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 83d109115aa5c..10c3eedbf4b46 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -253,6 +253,23 @@ public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException { createWriter(false).stop(false); } + class PandaException extends RuntimeException { + } + + @Test(expected=PandaException.class) + public void writeFailurePropagates() throws Exception { + class BadRecords extends scala.collection.AbstractIterator> { + @Override public boolean hasNext() { + throw new PandaException(); + } + @Override public Product2 next() { + return null; + } + } + final UnsafeShuffleWriter writer = createWriter(true); + writer.write(new BadRecords()); + } + @Test public void writeEmptyIterator() throws Exception { final UnsafeShuffleWriter writer = createWriter(true); diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 1fab69678d040..7a1961137cce5 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -205,4 +205,39 @@ class MapOutputTrackerSuite extends SparkFunSuite { // masterTracker.stop() // this throws an exception rpcEnv.shutdown() } + + test("getLocationsWithLargestOutputs with multiple outputs in same machine") { + val rpcEnv = createRpcEnv("test") + val tracker = new MapOutputTrackerMaster(conf) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + // Setup 3 map tasks + // on hostA with output size 2 + // on hostA with output size 2 + // on hostB with output size 3 + tracker.registerShuffle(10, 3) + tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), + Array(2L))) + tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000), + Array(2L))) + tracker.registerMapOutput(10, 2, MapStatus(BlockManagerId("b", "hostB", 1000), + Array(3L))) + + // When the threshold is 50%, only host A should be returned as a preferred location + // as it has 4 out of 7 bytes of output. + val topLocs50 = tracker.getLocationsWithLargestOutputs(10, 0, 1, 0.5) + assert(topLocs50.nonEmpty) + assert(topLocs50.get.size === 1) + assert(topLocs50.get.head === BlockManagerId("a", "hostA", 1000)) + + // When the threshold is 20%, both hosts should be returned as preferred locations. + val topLocs20 = tracker.getLocationsWithLargestOutputs(10, 0, 1, 0.2) + assert(topLocs20.nonEmpty) + assert(topLocs20.get.size === 2) + assert(topLocs20.get.toSet === + Seq(BlockManagerId("a", "hostA", 1000), BlockManagerId("b", "hostB", 1000)).toSet) + + tracker.stop() + rpcEnv.shutdown() + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala index 7d39984424842..823050b0aabbe 100644 --- a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala @@ -24,6 +24,8 @@ import com.google.common.io.{Files, ByteStreams} import org.apache.commons.io.FileUtils +import org.apache.ivy.core.settings.IvySettings + import org.apache.spark.TestUtils.{createCompiledClass, JavaSourceFromString} import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate @@ -44,13 +46,30 @@ private[deploy] object IvyTestUtils { if (!useIvyLayout) { Seq(groupDirs, artifactDirs, artifact.version).mkString(File.separator) } else { - Seq(groupDirs, artifactDirs, artifact.version, ext + "s").mkString(File.separator) + Seq(artifact.groupId, artifactDirs, artifact.version, ext + "s").mkString(File.separator) } new File(prefix, artifactPath) } - private def artifactName(artifact: MavenCoordinate, ext: String = ".jar"): String = { - s"${artifact.artifactId}-${artifact.version}$ext" + /** Returns the artifact naming based on standard ivy or maven format. */ + private def artifactName( + artifact: MavenCoordinate, + useIvyLayout: Boolean, + ext: String = ".jar"): String = { + if (!useIvyLayout) { + s"${artifact.artifactId}-${artifact.version}$ext" + } else { + s"${artifact.artifactId}$ext" + } + } + + /** Returns the directory for the given groupId based on standard ivy or maven format. */ + private def getBaseGroupDirectory(artifact: MavenCoordinate, useIvyLayout: Boolean): String = { + if (!useIvyLayout) { + artifact.groupId.replace(".", File.separator) + } else { + artifact.groupId + } } /** Write the contents to a file to the supplied directory. */ @@ -92,6 +111,22 @@ private[deploy] object IvyTestUtils { createCompiledClass(className, dir, sourceFile, Seq.empty) } + private def createDescriptor( + tempPath: File, + artifact: MavenCoordinate, + dependencies: Option[Seq[MavenCoordinate]], + useIvyLayout: Boolean): File = { + if (useIvyLayout) { + val ivyXmlPath = pathFromCoordinate(artifact, tempPath, "ivy", true) + Files.createParentDirs(new File(ivyXmlPath, "dummy")) + createIvyDescriptor(ivyXmlPath, artifact, dependencies) + } else { + val pomPath = pathFromCoordinate(artifact, tempPath, "pom", useIvyLayout) + Files.createParentDirs(new File(pomPath, "dummy")) + createPom(pomPath, artifact, dependencies) + } + } + /** Helper method to write artifact information in the pom. */ private def pomArtifactWriter(artifact: MavenCoordinate, tabCount: Int = 1): String = { var result = "\n" + " " * tabCount + s"${artifact.groupId}" @@ -121,15 +156,55 @@ private[deploy] object IvyTestUtils { "\n \n" + inside + "\n " }.getOrElse("") content += "\n" - writeFile(dir, artifactName(artifact, ".pom"), content.trim) + writeFile(dir, artifactName(artifact, false, ".pom"), content.trim) + } + + /** Helper method to write artifact information in the ivy.xml. */ + private def ivyArtifactWriter(artifact: MavenCoordinate): String = { + s"""""".stripMargin + } + + /** Create a pom file for this artifact. */ + private def createIvyDescriptor( + dir: File, + artifact: MavenCoordinate, + dependencies: Option[Seq[MavenCoordinate]]): File = { + var content = s""" + | + | + | + | + | + | + | + | + | + | + | + | + | + """.stripMargin.trim + content += dependencies.map { deps => + val inside = deps.map(ivyArtifactWriter).mkString("\n") + "\n \n" + inside + "\n " + }.getOrElse("") + content += "\n" + writeFile(dir, "ivy.xml", content.trim) } /** Create the jar for the given maven coordinate, using the supplied files. */ private def packJar( dir: File, artifact: MavenCoordinate, - files: Seq[(String, File)]): File = { - val jarFile = new File(dir, artifactName(artifact)) + files: Seq[(String, File)], + useIvyLayout: Boolean): File = { + val jarFile = new File(dir, artifactName(artifact, useIvyLayout)) val jarFileStream = new FileOutputStream(jarFile) val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest()) @@ -187,12 +262,10 @@ private[deploy] object IvyTestUtils { } else { Seq(javaFile) } - val jarFile = packJar(jarPath, artifact, allFiles) + val jarFile = packJar(jarPath, artifact, allFiles, useIvyLayout) assert(jarFile.exists(), "Problem creating Jar file") - val pomPath = pathFromCoordinate(artifact, tempPath, "pom", useIvyLayout) - Files.createParentDirs(new File(pomPath, "dummy")) - val pomFile = createPom(pomPath, artifact, dependencies) - assert(pomFile.exists(), "Problem creating Pom file") + val descriptor = createDescriptor(tempPath, artifact, dependencies, useIvyLayout) + assert(descriptor.exists(), "Problem creating Pom file") } finally { FileUtils.deleteDirectory(root) } @@ -237,7 +310,10 @@ private[deploy] object IvyTestUtils { dependencies: Option[String], rootDir: Option[File], useIvyLayout: Boolean = false, - withPython: Boolean = false)(f: String => Unit): Unit = { + withPython: Boolean = false, + ivySettings: IvySettings = new IvySettings)(f: String => Unit): Unit = { + val deps = dependencies.map(SparkSubmitUtils.extractMavenCoordinates) + purgeLocalIvyCache(artifact, deps, ivySettings) val repo = createLocalRepositoryForTests(artifact, dependencies, rootDir, useIvyLayout, withPython) try { @@ -245,17 +321,29 @@ private[deploy] object IvyTestUtils { } finally { // Clean up if (repo.toString.contains(".m2") || repo.toString.contains(".ivy2")) { - FileUtils.deleteDirectory(new File(repo, - artifact.groupId.replace(".", File.separator) + File.separator + artifact.artifactId)) - dependencies.map(SparkSubmitUtils.extractMavenCoordinates).foreach { seq => - seq.foreach { dep => - FileUtils.deleteDirectory(new File(repo, - dep.artifactId.replace(".", File.separator))) + val groupDir = getBaseGroupDirectory(artifact, useIvyLayout) + FileUtils.deleteDirectory(new File(repo, groupDir + File.separator + artifact.artifactId)) + deps.foreach { _.foreach { dep => + FileUtils.deleteDirectory(new File(repo, getBaseGroupDirectory(dep, useIvyLayout))) } } } else { FileUtils.deleteDirectory(repo) } + purgeLocalIvyCache(artifact, deps, ivySettings) + } + } + + /** Deletes the test packages from the ivy cache */ + private def purgeLocalIvyCache( + artifact: MavenCoordinate, + dependencies: Option[Seq[MavenCoordinate]], + ivySettings: IvySettings): Unit = { + // delete the artifact from the cache as well if it already exists + FileUtils.deleteDirectory(new File(ivySettings.getDefaultCache, artifact.groupId)) + dependencies.foreach { _.foreach { dep => + FileUtils.deleteDirectory(new File(ivySettings.getDefaultCache, dep.groupId)) + } } } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 46ea28d0f18f6..2e05dec99b6bf 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -325,7 +325,7 @@ class SparkSubmitSuite runSparkSubmit(args) } - ignore("includes jars passed in through --jars") { + test("includes jars passed in through --jars") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) @@ -340,7 +340,7 @@ class SparkSubmitSuite } // SPARK-7287 - ignore("includes jars passed in through --packages") { + test("includes jars passed in through --packages") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val main = MavenCoordinate("my.great.lib", "mylib", "0.1") val dep = MavenCoordinate("my.great.dep", "mylib", "0.1") @@ -499,9 +499,16 @@ class SparkSubmitSuite Seq("./bin/spark-submit") ++ args, new File(sparkHome), Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) - failAfter(60 seconds) { process.waitFor() } - // Ensure we still kill the process in case it timed out - process.destroy() + + try { + val exitCode = failAfter(60 seconds) { process.waitFor() } + if (exitCode != 0) { + fail(s"Process returned with exit code $exitCode. See the log4j logs for more detail.") + } + } finally { + // Ensure we still kill the process in case it timed out + process.destroy() + } } private def forConfDir(defaults: Map[String, String]) (f: String => Unit) = { @@ -541,6 +548,7 @@ object JarCreationTest extends Logging { if (result.nonEmpty) { throw new Exception("Could not load user class from jar:\n" + result(0)) } + sc.stop() } } @@ -566,6 +574,7 @@ object SimpleApplicationTest { s"Master had $config=$masterValue but executor had $config=$executorValue") } } + sc.stop() } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 8fda5c8b472c9..c9b435a9228d3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -24,13 +24,16 @@ import org.scalatest.BeforeAndAfterAll import org.apache.ivy.core.module.descriptor.MDArtifact import org.apache.ivy.core.settings.IvySettings -import org.apache.ivy.plugins.resolver.IBiblioResolver +import org.apache.ivy.plugins.resolver.{AbstractResolver, FileSystemResolver, IBiblioResolver} import org.apache.spark.SparkFunSuite import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate +import org.apache.spark.util.Utils class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { + private var tempIvyPath: String = _ + private val noOpOutputStream = new OutputStream { def write(b: Int) = {} } @@ -47,6 +50,7 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { super.beforeAll() // We don't want to write logs during testing SparkSubmitUtils.printStream = new BufferPrintStream + tempIvyPath = Utils.createTempDir(namePrefix = "ivy").getAbsolutePath() } test("incorrect maven coordinate throws error") { @@ -64,7 +68,7 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { // should have central and spark-packages by default assert(res1.getResolvers.size() === 4) assert(res1.getResolvers.get(0).asInstanceOf[IBiblioResolver].getName === "local-m2-cache") - assert(res1.getResolvers.get(1).asInstanceOf[IBiblioResolver].getName === "local-ivy-cache") + assert(res1.getResolvers.get(1).asInstanceOf[FileSystemResolver].getName === "local-ivy-cache") assert(res1.getResolvers.get(2).asInstanceOf[IBiblioResolver].getName === "central") assert(res1.getResolvers.get(3).asInstanceOf[IBiblioResolver].getName === "spark-packages") @@ -72,10 +76,10 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { val resolver2 = SparkSubmitUtils.createRepoResolvers(Option(repos), settings) assert(resolver2.getResolvers.size() === 7) val expected = repos.split(",").map(r => s"$r/") - resolver2.getResolvers.toArray.zipWithIndex.foreach { case (resolver: IBiblioResolver, i) => - if (i > 3) { - assert(resolver.getName === s"repo-${i - 3}") - assert(resolver.getRoot === expected(i - 4)) + resolver2.getResolvers.toArray.zipWithIndex.foreach { case (resolver: AbstractResolver, i) => + if (i < 3) { + assert(resolver.getName === s"repo-${i + 1}") + assert(resolver.asInstanceOf[IBiblioResolver].getRoot === expected(i)) } } } @@ -90,52 +94,58 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { } test("ivy path works correctly") { - val ivyPath = "dummy" + File.separator + "ivy" val md = SparkSubmitUtils.getModuleDescriptor val artifacts = for (i <- 0 until 3) yield new MDArtifact(md, s"jar-$i", "jar", "jar") - var jPaths = SparkSubmitUtils.resolveDependencyPaths(artifacts.toArray, new File(ivyPath)) + var jPaths = SparkSubmitUtils.resolveDependencyPaths(artifacts.toArray, new File(tempIvyPath)) for (i <- 0 until 3) { - val index = jPaths.indexOf(ivyPath) + val index = jPaths.indexOf(tempIvyPath) assert(index >= 0) - jPaths = jPaths.substring(index + ivyPath.length) + jPaths = jPaths.substring(index + tempIvyPath.length) } val main = MavenCoordinate("my.awesome.lib", "mylib", "0.1") IvyTestUtils.withRepository(main, None, None) { repo => // end to end val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, Option(repo), - Option(ivyPath), true) - assert(jarPath.indexOf(ivyPath) >= 0, "should use non-default ivy path") + Option(tempIvyPath), isTest = true) + assert(jarPath.indexOf(tempIvyPath) >= 0, "should use non-default ivy path") } } test("search for artifact at local repositories") { - val main = new MavenCoordinate("my.awesome.lib", "mylib", "0.1") + val main = new MavenCoordinate("my.great.lib", "mylib", "0.1") + val dep = "my.great.dep:mydep:0.5" // Local M2 repository - IvyTestUtils.withRepository(main, None, Some(SparkSubmitUtils.m2Path)) { repo => - val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, None, true) + IvyTestUtils.withRepository(main, Some(dep), Some(SparkSubmitUtils.m2Path)) { repo => + val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, None, + isTest = true) assert(jarPath.indexOf("mylib") >= 0, "should find artifact") + assert(jarPath.indexOf("mydep") >= 0, "should find dependency") } // Local Ivy Repository val settings = new IvySettings val ivyLocal = new File(settings.getDefaultIvyUserDir, "local" + File.separator) - IvyTestUtils.withRepository(main, None, Some(ivyLocal), true) { repo => - val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, None, true) + IvyTestUtils.withRepository(main, Some(dep), Some(ivyLocal), useIvyLayout = true) { repo => + val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, None, + isTest = true) assert(jarPath.indexOf("mylib") >= 0, "should find artifact") + assert(jarPath.indexOf("mydep") >= 0, "should find dependency") } // Local ivy repository with modified home - val dummyIvyPath = "dummy" + File.separator + "ivy" - val dummyIvyLocal = new File(dummyIvyPath, "local" + File.separator) - IvyTestUtils.withRepository(main, None, Some(dummyIvyLocal), true) { repo => + val dummyIvyLocal = new File(tempIvyPath, "local" + File.separator) + settings.setDefaultIvyUserDir(new File(tempIvyPath)) + IvyTestUtils.withRepository(main, Some(dep), Some(dummyIvyLocal), useIvyLayout = true, + ivySettings = settings) { repo => val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, - Some(dummyIvyPath), true) + Some(tempIvyPath), isTest = true) assert(jarPath.indexOf("mylib") >= 0, "should find artifact") - assert(jarPath.indexOf(dummyIvyPath) >= 0, "should be in new ivy path") + assert(jarPath.indexOf(tempIvyPath) >= 0, "should be in new ivy path") + assert(jarPath.indexOf("mydep") >= 0, "should find dependency") } } test("dependency not found throws RuntimeException") { intercept[RuntimeException] { - SparkSubmitUtils.resolveMavenCoordinates("a:b:c", None, None, true) + SparkSubmitUtils.resolveMavenCoordinates("a:b:c", None, None, isTest = true) } } @@ -147,12 +157,12 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { components.map(comp => s"org.apache.spark:spark-${comp}2.10:1.2.0").mkString(",") + ",org.apache.spark:spark-core_fake:1.2.0" - val path = SparkSubmitUtils.resolveMavenCoordinates(coordinates, None, None, true) + val path = SparkSubmitUtils.resolveMavenCoordinates(coordinates, None, None, isTest = true) assert(path === "", "should return empty path") val main = MavenCoordinate("org.apache.spark", "spark-streaming-kafka-assembly_2.10", "1.2.0") IvyTestUtils.withRepository(main, None, None) { repo => val files = SparkSubmitUtils.resolveMavenCoordinates(coordinates + "," + main.toString, - Some(repo), None, true) + Some(repo), None, isTest = true) assert(files.indexOf(main.artifactId) >= 0, "Did not return artifact") } } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala index 5b3930c0b0132..7101cb9978df3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala @@ -17,21 +17,52 @@ package org.apache.spark.deploy.worker -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.Command import org.apache.spark.util.Utils -import org.scalatest.Matchers +import org.scalatest.{Matchers, PrivateMethodTester} -class CommandUtilsSuite extends SparkFunSuite with Matchers { +class CommandUtilsSuite extends SparkFunSuite with Matchers with PrivateMethodTester { test("set libraryPath correctly") { val appId = "12345-worker321-9876" val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) val cmd = new Command("mainClass", Seq(), Map(), Seq(), Seq("libraryPathToB"), Seq()) - val builder = CommandUtils.buildProcessBuilder(cmd, 512, sparkHome, t => t) + val builder = CommandUtils.buildProcessBuilder( + cmd, new SecurityManager(new SparkConf), 512, sparkHome, t => t) val libraryPath = Utils.libraryPathEnvName val env = builder.environment env.keySet should contain(libraryPath) assert(env.get(libraryPath).startsWith("libraryPathToB")) } + + test("auth secret shouldn't appear in java opts") { + val buildLocalCommand = PrivateMethod[Command]('buildLocalCommand) + val conf = new SparkConf + val secret = "This is the secret sauce" + // set auth secret + conf.set(SecurityManager.SPARK_AUTH_SECRET_CONF, secret) + val command = new Command("mainClass", Seq(), Map(), Seq(), Seq("lib"), + Seq("-D" + SecurityManager.SPARK_AUTH_SECRET_CONF + "=" + secret)) + + // auth is not set + var cmd = CommandUtils invokePrivate buildLocalCommand( + command, new SecurityManager(conf), (t: String) => t, Seq(), Map()) + assert(!cmd.javaOpts.exists(_.startsWith("-D" + SecurityManager.SPARK_AUTH_SECRET_CONF))) + assert(!cmd.environment.contains(SecurityManager.ENV_AUTH_SECRET)) + + // auth is set to false + conf.set(SecurityManager.SPARK_AUTH_CONF, "false") + cmd = CommandUtils invokePrivate buildLocalCommand( + command, new SecurityManager(conf), (t: String) => t, Seq(), Map()) + assert(!cmd.javaOpts.exists(_.startsWith("-D" + SecurityManager.SPARK_AUTH_SECRET_CONF))) + assert(!cmd.environment.contains(SecurityManager.ENV_AUTH_SECRET)) + + // auth is set to true + conf.set(SecurityManager.SPARK_AUTH_CONF, "true") + cmd = CommandUtils invokePrivate buildLocalCommand( + command, new SecurityManager(conf), (t: String) => t, Seq(), Map()) + assert(!cmd.javaOpts.exists(_.startsWith("-D" + SecurityManager.SPARK_AUTH_SECRET_CONF))) + assert(cmd.environment(SecurityManager.ENV_AUTH_SECRET) === secret) + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index 3da992788962b..bed6f3ea61241 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -22,19 +22,20 @@ import java.io.File import scala.collection.JavaConversions._ import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} class ExecutorRunnerTest extends SparkFunSuite { test("command includes appId") { val appId = "12345-worker321-9876" + val conf = new SparkConf val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) val appDesc = new ApplicationDescription("app name", Some(8), 500, Command("foo", Seq(appId), Map(), Seq(), Seq(), Seq()), "appUiUrl") val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", 123, - "publicAddr", new File(sparkHome), new File("ooga"), "blah", new SparkConf, Seq("localDir"), + "publicAddr", new File(sparkHome), new File("ooga"), "blah", conf, Seq("localDir"), ExecutorState.RUNNING) val builder = CommandUtils.buildProcessBuilder( - appDesc.command, 512, sparkHome, er.substituteVariables) + appDesc.command, new SecurityManager(conf), 512, sparkHome, er.substituteVariables) assert(builder.command().last === appId) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 47b2868753c0e..6bc45f249f975 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -490,8 +490,8 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) // the 2nd ResultTask failed complete(taskSets(1), Seq( (Success, 42), @@ -501,7 +501,7 @@ class DAGSchedulerSuite // ask the scheduler to try it again scheduler.resubmitFailedStages() // have the 2nd attempt pass - complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size)))) // we can see both result blocks now assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) @@ -517,8 +517,8 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) // The MapOutputTracker should know about both map output locations. assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) @@ -560,18 +560,18 @@ class DAGSchedulerSuite assert(newEpoch > oldEpoch) val taskSet = taskSets(0) // should be ignored for being too old - runEvent(CompletionEvent( - taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", + reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) // should work because it's a non-failed host - runEvent(CompletionEvent( - taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", + reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) // should be ignored for being too old - runEvent(CompletionEvent( - taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", + reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) // should work because it's a new epoch taskSet.tasks(1).epoch = newEpoch - runEvent(CompletionEvent( - taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", + reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) complete(taskSets(1), Seq((Success, 42), (Success, 43))) @@ -784,6 +784,37 @@ class DAGSchedulerSuite assert(sc.parallelize(1 to 10, 2).first() === 1) } + test("getPartitions exceptions should not crash DAGScheduler and SparkContext (SPARK-8606)") { + val e1 = intercept[DAGSchedulerSuiteDummyException] { + val rdd = new MyRDD(sc, 2, Nil) { + override def getPartitions: Array[Partition] = { + throw new DAGSchedulerSuiteDummyException + } + } + rdd.reduceByKey(_ + _, 1).count() + } + + // Make sure we can still run local commands as well as cluster commands. + assert(sc.parallelize(1 to 10, 2).count() === 10) + assert(sc.parallelize(1 to 10, 2).first() === 1) + } + + test("getPreferredLocations errors should not crash DAGScheduler and SparkContext (SPARK-8606)") { + val e1 = intercept[SparkException] { + val rdd = new MyRDD(sc, 2, Nil) { + override def getPreferredLocations(split: Partition): Seq[String] = { + throw new DAGSchedulerSuiteDummyException + } + } + rdd.count() + } + assert(e1.getMessage.contains(classOf[DAGSchedulerSuiteDummyException].getName)) + + // Make sure we can still run local commands as well as cluster commands. + assert(sc.parallelize(1 to 10, 2).count() === 10) + assert(sc.parallelize(1 to 10, 2).first() === 1) + } + test("accumulator not calculated for resubmitted result stage") { // just for register val accum = new Accumulator[Int](0, AccumulatorParam.IntAccumulatorParam) @@ -800,6 +831,50 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + test("reduce tasks should be placed locally with map output") { + // Create an shuffleMapRdd with 1 partition + val shuffleMapRdd = new MyRDD(sc, 1, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + submit(reduceRdd, Array(0)) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 1)))) + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostA"))) + + // Reducer should run on the same host that map task ran + val reduceTaskSet = taskSets(1) + assertLocations(reduceTaskSet, Seq(Seq("hostA"))) + complete(reduceTaskSet, Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assertDataStructuresEmpty + } + + test("reduce task locality preferences should only include machines with largest map outputs") { + val numMapTasks = 4 + // Create an shuffleMapRdd with more partitions + val shuffleMapRdd = new MyRDD(sc, numMapTasks, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + submit(reduceRdd, Array(0)) + + val statuses = (1 to numMapTasks).map { i => + (Success, makeMapStatus("host" + i, 1, (10*i).toByte)) + } + complete(taskSets(0), statuses) + + // Reducer should prefer the last 3 hosts as they have 20%, 30% and 40% of data + val hosts = (1 to numMapTasks).map(i => "host" + i).reverse.take(numMapTasks - 1) + + val reduceTaskSet = taskSets(1) + assertLocations(reduceTaskSet, Seq(hosts)) + complete(reduceTaskSet, Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assertDataStructuresEmpty + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. @@ -807,12 +882,12 @@ class DAGSchedulerSuite private def assertLocations(taskSet: TaskSet, hosts: Seq[Seq[String]]) { assert(hosts.size === taskSet.tasks.size) for ((taskLocs, expectedLocs) <- taskSet.tasks.map(_.preferredLocations).zip(hosts)) { - assert(taskLocs.map(_.host) === expectedLocs) + assert(taskLocs.map(_.host).toSet === expectedLocs.toSet) } } - private def makeMapStatus(host: String, reduces: Int): MapStatus = - MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(2)) + private def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2): MapStatus = + MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes)) private def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala index 2707bb53bc383..2d5e9d66b2e15 100644 --- a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.serializer -import java.io.{ObjectOutput, ObjectInput} +import java.io._ import org.scalatest.BeforeAndAfterEach @@ -98,7 +98,7 @@ class SerializationDebuggerSuite extends SparkFunSuite with BeforeAndAfterEach { } test("externalizable class writing out not serializable object") { - val s = find(new ExternalizableClass) + val s = find(new ExternalizableClass(new SerializableClass2(new NotSerializable))) assert(s.size === 5) assert(s(0).contains("NotSerializable")) assert(s(1).contains("objectField")) @@ -106,6 +106,93 @@ class SerializationDebuggerSuite extends SparkFunSuite with BeforeAndAfterEach { assert(s(3).contains("writeExternal")) assert(s(4).contains("ExternalizableClass")) } + + test("externalizable class writing out serializable objects") { + assert(find(new ExternalizableClass(new SerializableClass1)).isEmpty) + } + + test("object containing writeReplace() which returns not serializable object") { + val s = find(new SerializableClassWithWriteReplace(new NotSerializable)) + assert(s.size === 3) + assert(s(0).contains("NotSerializable")) + assert(s(1).contains("writeReplace")) + assert(s(2).contains("SerializableClassWithWriteReplace")) + } + + test("object containing writeReplace() which returns serializable object") { + assert(find(new SerializableClassWithWriteReplace(new SerializableClass1)).isEmpty) + } + + test("object containing writeObject() and not serializable field") { + val s = find(new SerializableClassWithWriteObject(new NotSerializable)) + assert(s.size === 3) + assert(s(0).contains("NotSerializable")) + assert(s(1).contains("writeObject data")) + assert(s(2).contains("SerializableClassWithWriteObject")) + } + + test("object containing writeObject() and serializable field") { + assert(find(new SerializableClassWithWriteObject(new SerializableClass1)).isEmpty) + } + + test("object of serializable subclass with more fields than superclass (SPARK-7180)") { + // This should not throw ArrayOutOfBoundsException + find(new SerializableSubclass(new SerializableClass1)) + } + + test("crazy nested objects") { + + def findAndAssert(shouldSerialize: Boolean, obj: Any): Unit = { + val s = find(obj) + if (shouldSerialize) { + assert(s.isEmpty) + } else { + assert(s.nonEmpty) + assert(s.head.contains("NotSerializable")) + } + } + + findAndAssert(false, + new SerializableClassWithWriteReplace(new ExternalizableClass(new SerializableSubclass( + new SerializableArray( + Array(new SerializableClass1, new SerializableClass2(new NotSerializable)) + ) + ))) + ) + + findAndAssert(true, + new SerializableClassWithWriteReplace(new ExternalizableClass(new SerializableSubclass( + new SerializableArray( + Array(new SerializableClass1, new SerializableClass2(new SerializableClass1)) + ) + ))) + ) + } + + test("improveException") { + val e = SerializationDebugger.improveException( + new SerializableClass2(new NotSerializable), new NotSerializableException("someClass")) + assert(e.getMessage.contains("someClass")) // original exception message should be present + assert(e.getMessage.contains("SerializableClass2")) // found debug trace should be present + } + + test("improveException with error in debugger") { + // Object that throws exception in the SerializationDebugger + val o = new SerializableClass1 { + private def writeReplace(): Object = { + throw new Exception() + } + } + withClue("requirement: SerializationDebugger should fail trying debug this object") { + intercept[Exception] { + SerializationDebugger.find(o) + } + } + + val originalException = new NotSerializableException("someClass") + // verify thaht original exception is returned on failure + assert(SerializationDebugger.improveException(o, originalException).eq(originalException)) + } } @@ -118,10 +205,34 @@ class SerializableClass2(val objectField: Object) extends Serializable class SerializableArray(val arrayField: Array[Object]) extends Serializable -class ExternalizableClass extends java.io.Externalizable { +class SerializableSubclass(val objectField: Object) extends SerializableClass1 + + +class SerializableClassWithWriteObject(val objectField: Object) extends Serializable { + val serializableObjectField = new SerializableClass1 + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream): Unit = { + oos.defaultWriteObject() + } +} + + +class SerializableClassWithWriteReplace(@transient replacementFieldObject: Object) + extends Serializable { + private def writeReplace(): Object = { + replacementFieldObject + } +} + + +class ExternalizableClass(objectField: Object) extends java.io.Externalizable { + val serializableObjectField = new SerializableClass1 + override def writeExternal(out: ObjectOutput): Unit = { out.writeInt(1) - out.writeObject(new SerializableClass2(new NotSerializable)) + out.writeObject(serializableObjectField) + out.writeObject(objectField) } override def readExternal(in: ObjectInput): Unit = {} diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala new file mode 100644 index 0000000000000..28ca68698e3dc --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.hash + +import java.io.{ByteArrayOutputStream, InputStream} +import java.nio.ByteBuffer + +import org.mockito.Matchers.{eq => meq, _} +import org.mockito.Mockito.{mock, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer + +import org.apache.spark._ +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.shuffle.BaseShuffleHandle +import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} + +/** + * Wrapper for a managed buffer that keeps track of how many times retain and release are called. + * + * We need to define this class ourselves instead of using a spy because the NioManagedBuffer class + * is final (final classes cannot be spied on). + */ +class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends ManagedBuffer { + var callsToRetain = 0 + var callsToRelease = 0 + + override def size(): Long = underlyingBuffer.size() + override def nioByteBuffer(): ByteBuffer = underlyingBuffer.nioByteBuffer() + override def createInputStream(): InputStream = underlyingBuffer.createInputStream() + override def convertToNetty(): AnyRef = underlyingBuffer.convertToNetty() + + override def retain(): ManagedBuffer = { + callsToRetain += 1 + underlyingBuffer.retain() + } + override def release(): ManagedBuffer = { + callsToRelease += 1 + underlyingBuffer.release() + } +} + +class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { + + /** + * This test makes sure that, when data is read from a HashShuffleReader, the underlying + * ManagedBuffers that contain the data are eventually released. + */ + test("read() releases resources on completion") { + val testConf = new SparkConf(false) + // Create a SparkContext as a convenient way of setting SparkEnv (needed because some of the + // shuffle code calls SparkEnv.get()). + sc = new SparkContext("local", "test", testConf) + + val reduceId = 15 + val shuffleId = 22 + val numMaps = 6 + val keyValuePairsPerMap = 10 + val serializer = new JavaSerializer(testConf) + + // Make a mock BlockManager that will return RecordingManagedByteBuffers of data, so that we + // can ensure retain() and release() are properly called. + val blockManager = mock(classOf[BlockManager]) + + // Create a return function to use for the mocked wrapForCompression method that just returns + // the original input stream. + val dummyCompressionFunction = new Answer[InputStream] { + override def answer(invocation: InvocationOnMock): InputStream = + invocation.getArguments()(1).asInstanceOf[InputStream] + } + + // Create a buffer with some randomly generated key-value pairs to use as the shuffle data + // from each mappers (all mappers return the same shuffle data). + val byteOutputStream = new ByteArrayOutputStream() + val serializationStream = serializer.newInstance().serializeStream(byteOutputStream) + (0 until keyValuePairsPerMap).foreach { i => + serializationStream.writeKey(i) + serializationStream.writeValue(2*i) + } + + // Setup the mocked BlockManager to return RecordingManagedBuffers. + val localBlockManagerId = BlockManagerId("test-client", "test-client", 1) + when(blockManager.blockManagerId).thenReturn(localBlockManagerId) + val buffers = (0 until numMaps).map { mapId => + // Create a ManagedBuffer with the shuffle data. + val nioBuffer = new NioManagedBuffer(ByteBuffer.wrap(byteOutputStream.toByteArray)) + val managedBuffer = new RecordingManagedBuffer(nioBuffer) + + // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to + // fetch shuffle data. + val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) + when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer) + when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream]))) + .thenAnswer(dummyCompressionFunction) + + managedBuffer + } + + // Make a mocked MapOutputTracker for the shuffle reader to use to determine what + // shuffle data to read. + val mapOutputTracker = mock(classOf[MapOutputTracker]) + // Test a scenario where all data is local, just to avoid creating a bunch of additional mocks + // for the code to read data over the network. + val statuses: Array[(BlockManagerId, Long)] = + Array.fill(numMaps)((localBlockManagerId, byteOutputStream.size().toLong)) + when(mapOutputTracker.getServerStatuses(shuffleId, reduceId)).thenReturn(statuses) + + // Create a mocked shuffle handle to pass into HashShuffleReader. + val shuffleHandle = { + val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]]) + when(dependency.serializer).thenReturn(Some(serializer)) + when(dependency.aggregator).thenReturn(None) + when(dependency.keyOrdering).thenReturn(None) + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } + + val shuffleReader = new HashShuffleReader( + shuffleHandle, + reduceId, + reduceId + 1, + new TaskContextImpl(0, 0, 0, 0, null), + blockManager, + mapOutputTracker) + + assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps) + + // Calling .length above will have exhausted the iterator; make sure that exhausting the + // iterator caused retain and release to be called on each buffer. + buffers.foreach { buffer => + assert(buffer.callsToRetain === 1) + assert(buffer.callsToRelease === 1) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala index a73e94e05575e..6727934d8c7ca 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala @@ -76,6 +76,15 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { mapSideCombine = false ))) + // Shuffles with key orderings are supported as long as no aggregator is specified + assert(canUseUnsafeShuffle(shuffleDep( + partitioner = new HashPartitioner(2), + serializer = kryo, + keyOrdering = Some(mock(classOf[Ordering[Any]])), + aggregator = None, + mapSideCombine = false + ))) + } test("unsupported shuffle dependencies") { @@ -100,14 +109,7 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { mapSideCombine = false ))) - // We do not support shuffles that perform any kind of aggregation or sorting of keys - assert(!canUseUnsafeShuffle(shuffleDep( - partitioner = new HashPartitioner(2), - serializer = kryo, - keyOrdering = Some(mock(classOf[Ordering[Any]])), - aggregator = None, - mapSideCombine = false - ))) + // We do not support shuffles that perform aggregation assert(!canUseUnsafeShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo, @@ -115,7 +117,6 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])), mapSideCombine = false ))) - // We do not support shuffles that perform any kind of aggregation or sorting of keys assert(!canUseUnsafeShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo, diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 2a7fe67ad8585..9ced4148d7206 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -17,23 +17,25 @@ package org.apache.spark.storage +import java.io.InputStream import java.util.concurrent.Semaphore -import scala.concurrent.future import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.future import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer +import org.scalatest.PrivateMethodTester -import org.apache.spark.{SparkConf, SparkFunSuite, TaskContextImpl} +import org.apache.spark.{SparkFunSuite, TaskContextImpl} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.BlockFetchingListener -import org.apache.spark.serializer.TestSerializer -class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { + +class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester { // Some of the tests are quite tricky because we are testing the cleanup behavior // in the presence of faults. @@ -57,7 +59,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { transfer } - private val conf = new SparkConf + // Create a mock managed buffer for testing + def createMockManagedBuffer(): ManagedBuffer = { + val mockManagedBuffer = mock(classOf[ManagedBuffer]) + when(mockManagedBuffer.createInputStream()).thenReturn(mock(classOf[InputStream])) + mockManagedBuffer + } test("successful 3 local reads + 2 remote reads") { val blockManager = mock(classOf[BlockManager]) @@ -66,9 +73,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure blockManager.getBlockData would return the blocks val localBlocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])) + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()) localBlocks.foreach { case (blockId, buf) => doReturn(buf).when(blockManager).getBlockData(meq(blockId)) } @@ -76,9 +83,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val remoteBlocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 3, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 4, 0) -> mock(classOf[ManagedBuffer]) - ) + ShuffleBlockId(0, 3, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 4, 0) -> createMockManagedBuffer()) val transfer = createMockTransfer(remoteBlocks) @@ -92,7 +98,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { transfer, blockManager, blocksByAddress, - new TestSerializer, 48 * 1024 * 1024) // 3 local blocks fetched in initialization @@ -100,15 +105,24 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { for (i <- 0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") - val (blockId, subIterator) = iterator.next() - assert(subIterator.isSuccess, + val (blockId, inputStream) = iterator.next() + assert(inputStream.isSuccess, s"iterator should have 5 elements defined but actually has $i elements") - // Make sure we release the buffer once the iterator is exhausted. + // Make sure we release buffers when a wrapped input stream is closed. val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId)) + // Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream + val wrappedInputStream = inputStream.get.asInstanceOf[BufferReleasingInputStream] verify(mockBuf, times(0)).release() - subIterator.get.foreach(_ => Unit) // exhaust the iterator + val delegateAccess = PrivateMethod[InputStream]('delegate) + + verify(wrappedInputStream.invokePrivate(delegateAccess()), times(0)).close() + wrappedInputStream.close() + verify(mockBuf, times(1)).release() + verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() + wrappedInputStream.close() // close should be idempotent verify(mockBuf, times(1)).release() + verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() } // 3 local blocks, and 2 remote blocks @@ -125,10 +139,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]) - ) + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()) // Semaphore to coordinate event sequence in two different threads. val sem = new Semaphore(0) @@ -159,11 +172,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { transfer, blockManager, blocksByAddress, - new TestSerializer, 48 * 1024 * 1024) - // Exhaust the first block, and then it should be released. - iterator.next()._2.get.foreach(_ => Unit) + verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() + iterator.next()._2.get.close() // close() first block's input stream verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release() // Get the 2nd block but do not exhaust the iterator @@ -222,7 +234,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { transfer, blockManager, blocksByAddress, - new TestSerializer, 48 * 1024 * 1024) // Continue only after the mock calls onBlockFetchFailure diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 33712f1bfa782..3aa672f8b713c 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -23,6 +23,7 @@ import javax.servlet.http.{HttpServletResponse, HttpServletRequest} import scala.collection.JavaConversions._ import scala.xml.Node +import com.gargoylesoftware.htmlunit.DefaultCssErrorHandler import org.json4s._ import org.json4s.jackson.JsonMethods import org.openqa.selenium.htmlunit.HtmlUnitDriver @@ -31,6 +32,7 @@ import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.selenium.WebBrowser import org.scalatest.time.SpanSugar._ +import org.w3c.css.sac.CSSParseException import org.apache.spark.LocalSparkContext._ import org.apache.spark._ @@ -39,6 +41,31 @@ import org.apache.spark.deploy.history.HistoryServerSuite import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.status.api.v1.{JacksonMessageWriter, StageStatus} +private[spark] class SparkUICssErrorHandler extends DefaultCssErrorHandler { + + private val cssWhiteList = List("bootstrap.min.css", "vis.min.css") + + private def isInWhileList(uri: String): Boolean = cssWhiteList.exists(uri.endsWith) + + override def warning(e: CSSParseException): Unit = { + if (!isInWhileList(e.getURI)) { + super.warning(e) + } + } + + override def fatalError(e: CSSParseException): Unit = { + if (!isInWhileList(e.getURI)) { + super.fatalError(e) + } + } + + override def error(e: CSSParseException): Unit = { + if (!isInWhileList(e.getURI)) { + super.error(e) + } + } +} + /** * Selenium tests for the Spark Web UI. */ @@ -49,7 +76,9 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B override def beforeAll(): Unit = { - webDriver = new HtmlUnitDriver + webDriver = new HtmlUnitDriver { + getWebClient.setCssErrorHandler(new SparkUICssErrorHandler) + } } override def afterAll(): Unit = { diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 70cd27b04347d..1053c6caf7718 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -121,6 +121,10 @@ class ClosureCleanerSuite extends SparkFunSuite { expectCorrectException { TestUserClosuresActuallyCleaned.testSubmitJob(sc) } } } + + test("createNullValue") { + new TestCreateNullValue().run() + } } // A non-serializable class we create in closures to make sure that we aren't @@ -350,3 +354,43 @@ private object TestUserClosuresActuallyCleaned { ) } } + +class TestCreateNullValue { + + var x = 5 + + def getX: Int = x + + def run(): Unit = { + val bo: Boolean = true + val c: Char = '1' + val b: Byte = 1 + val s: Short = 1 + val i: Int = 1 + val l: Long = 1 + val f: Float = 1 + val d: Double = 1 + + // Bring in all primitive types into the closure such that they become + // parameters of the closure constructor. This allows us to test whether + // null values are created correctly for each type. + val nestedClosure = () => { + if (s.toString == "123") { // Don't really output them to avoid noisy + println(bo) + println(c) + println(b) + println(s) + println(i) + println(l) + println(f) + println(d) + } + + val closure = () => { + println(getX) + } + ClosureCleaner.clean(closure) + } + nestedClosure() + } +} diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index a61ea3918f46a..baa4c661cc21e 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -673,4 +673,12 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(!Utils.isInDirectory(nullFile, parentDir)) assert(!Utils.isInDirectory(nullFile, childFile3)) } + + test("circular buffer") { + val buffer = new CircularBuffer(25) + val stream = new java.io.PrintStream(buffer, true, "UTF-8") + + stream.println("test circular test circular test circular test circular test circular") + assert(buffer.toString === "t circular test circular\n") + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index 94e011799921b..3066e9996abda 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -44,7 +44,7 @@ class OpenHashMapSuite extends SparkFunSuite with Matchers { val goodMap3 = new OpenHashMap[String, String](256) assert(goodMap3.size === 0) intercept[IllegalArgumentException] { - new OpenHashMap[String, Int](1 << 30) // Invalid map size: bigger than 2^29 + new OpenHashMap[String, Int](1 << 30 + 1) // Invalid map size: bigger than 2^30 } intercept[IllegalArgumentException] { new OpenHashMap[String, Int](-1) @@ -186,4 +186,14 @@ class OpenHashMapSuite extends SparkFunSuite with Matchers { map(null) = 0 assert(map.contains(null)) } + + test("support for more than 12M items") { + val cnt = 12000000 // 12M + val map = new OpenHashMap[Int, Int](cnt) + for (i <- 0 until cnt) { + map(i) = 1 + } + val numInvalidValues = map.iterator.count(_._2 == 0) + assertResult(0)(numInvalidValues) + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala index 462bc2f29f9f8..508e737b725bc 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala @@ -44,7 +44,7 @@ class PrimitiveKeyOpenHashMapSuite extends SparkFunSuite with Matchers { val goodMap3 = new PrimitiveKeyOpenHashMap[Int, Int](256) assert(goodMap3.size === 0) intercept[IllegalArgumentException] { - new PrimitiveKeyOpenHashMap[Int, Int](1 << 30) // Invalid map size: bigger than 2^29 + new PrimitiveKeyOpenHashMap[Int, Int](1 << 30 + 1) // Invalid map size: bigger than 2^30 } intercept[IllegalArgumentException] { new PrimitiveKeyOpenHashMap[Int, Int](-1) diff --git a/dev/create-release/known_translations b/dev/create-release/known_translations index 0a599b5a65549..5f2671a6e5053 100644 --- a/dev/create-release/known_translations +++ b/dev/create-release/known_translations @@ -91,3 +91,41 @@ zapletal-martin - Martin Zapletal zuxqoj - Shekhar Bansal mingyukim - Mingyu Kim sigmoidanalytics - Mayur Rustagi +AiHe - Ai He +BenFradet - Ben Fradet +FavioVazquez - Favio Vazquez +JaysonSunshine - Jayson Sunshine +Liuchang0812 - Liu Chang +Sephiroth-Lin - Sephiroth Lin +dobashim - Masaru Dobashi +ehnalis - Zoltan Zvara +emres - Emre Sevinc +gchen - Guancheng Chen +haiyangsea - Haiyang Sea +hlin09 - Hao Lin +hqzizania - Qian Huang +jeanlyn - Jean Lyn +jerluc - Jeremy A. Lucas +jrabary - Jaonary Rabarisoa +judynash - Judy Nash +kaka1992 - Chen Song +ksonj - Kalle Jepsen +kuromatsu-nobuyuki - Nobuyuki Kuromatsu +lazyman500 - Dong Xu +leahmcguire - Leah McGuire +mbittmann - Mark Bittmann +mbonaci - Marko Bonaci +meawoppl - Matthew Goodman +nyaapa - Arsenii Krasikov +phatak-dev - Madhukara Phatak +prabeesh - Prabeesh K +rakeshchalasani - Rakesh Chalasani +rekhajoshm - Rekha Joshi +sisihj - June He +szheng79 - Shuai Zheng +texasmichelle - Michelle Casbon +vinodkc - Vinod KC +yongtang - Yong Tang +ypcat - Pei-Lun Lee +zhichao-li - Zhichao Li +zzcclp - Zhichao Zhang diff --git a/dev/lint-python b/dev/lint-python index f50d149dc4d44..0c3586462cb37 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -19,7 +19,8 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" -PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/" +PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/ ./dev/sparktestsupport" +PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/run-tests.py" PYTHON_LINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/python-lint-report.txt" cd "$SPARK_ROOT_DIR" diff --git a/dev/lint-r b/dev/lint-r new file mode 100755 index 0000000000000..7d5f4cd31153d --- /dev/null +++ b/dev/lint-r @@ -0,0 +1,30 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" +SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" +LINT_R_REPORT_FILE_NAME="$SPARK_ROOT_DIR/dev/lint-r-report.log" + + +if ! type "Rscript" > /dev/null; then + echo "ERROR: You should install R" + exit +fi + +`which Rscript` --vanilla "$SPARK_ROOT_DIR/dev/lint-r.R" "$SPARK_ROOT_DIR" | tee "$LINT_R_REPORT_FILE_NAME" diff --git a/dev/lint-r.R b/dev/lint-r.R new file mode 100644 index 0000000000000..dcb1a184291e1 --- /dev/null +++ b/dev/lint-r.R @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Installs lintr from Github. +# NOTE: The CRAN's version is too old to adapt to our rules. +if ("lintr" %in% row.names(installed.packages()) == FALSE) { + devtools::install_github("jimhester/lintr") +} +library(lintr) + +argv <- commandArgs(TRUE) +SPARK_ROOT_DIR <- as.character(argv[1]) + +path.to.package <- file.path(SPARK_ROOT_DIR, "R", "pkg") +lint_package(path.to.package, cache = FALSE) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index cd83b352c1bfb..cf827ce89b857 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -431,6 +431,8 @@ def main(): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) main() diff --git a/dev/run-tests b/dev/run-tests index d178e2a4601ea..257d1e8d50bb4 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -17,224 +17,7 @@ # limitations under the License. # -# Go to the Spark project root directory FWDIR="$(cd "`dirname $0`"/..; pwd)" cd "$FWDIR" -# Clean up work directory and caches -rm -rf ./work -rm -rf ~/.ivy2/local/org.apache.spark -rm -rf ~/.ivy2/cache/org.apache.spark - -source "$FWDIR/dev/run-tests-codes.sh" - -CURRENT_BLOCK=$BLOCK_GENERAL - -function handle_error () { - echo "[error] Got a return code of $? on line $1 of the run-tests script." - exit $CURRENT_BLOCK -} - - -# Build against the right version of Hadoop. -{ - if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then - if [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop1.0" ]; then - export SBT_MAVEN_PROFILES_ARGS="-Phadoop-1 -Dhadoop.version=1.2.1" - elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.0" ]; then - export SBT_MAVEN_PROFILES_ARGS="-Phadoop-1 -Dhadoop.version=2.0.0-mr1-cdh4.1.1" - elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.2" ]; then - export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.2" - elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.3" ]; then - export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" - fi - fi - - if [ -z "$SBT_MAVEN_PROFILES_ARGS" ]; then - export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" - fi -} - -export SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Pkinesis-asl" - -# Determine Java path and version. -{ - if test -x "$JAVA_HOME/bin/java"; then - declare java_cmd="$JAVA_HOME/bin/java" - else - declare java_cmd=java - fi - - # We can't use sed -r -e due to OS X / BSD compatibility; hence, all the parentheses. - JAVA_VERSION=$( - $java_cmd -version 2>&1 \ - | grep -e "^java version" --max-count=1 \ - | sed "s/java version \"\(.*\)\.\(.*\)\.\(.*\)\"/\1\2/" - ) - - if [ "$JAVA_VERSION" -lt 18 ]; then - echo "[warn] Java 8 tests will not run because JDK version is < 1.8." - fi -} - -# Only run Hive tests if there are SQL changes. -# Partial solution for SPARK-1455. -if [ -n "$AMPLAB_JENKINS" ]; then - target_branch="$ghprbTargetBranch" - git fetch origin "$target_branch":"$target_branch" - - # AMP_JENKINS_PRB indicates if the current build is a pull request build. - if [ -n "$AMP_JENKINS_PRB" ]; then - # It is a pull request build. - sql_diffs=$( - git diff --name-only "$target_branch" \ - | grep -e "^sql/" -e "^bin/spark-sql" -e "^sbin/start-thriftserver.sh" - ) - - non_sql_diffs=$( - git diff --name-only "$target_branch" \ - | grep -v -e "^sql/" -e "^bin/spark-sql" -e "^sbin/start-thriftserver.sh" - ) - - if [ -n "$sql_diffs" ]; then - echo "[info] Detected changes in SQL. Will run Hive test suite." - _RUN_SQL_TESTS=true - - if [ -z "$non_sql_diffs" ]; then - echo "[info] Detected no changes except in SQL. Will only run SQL tests." - _SQL_TESTS_ONLY=true - fi - fi - else - # It is a regular build. We should run SQL tests. - _RUN_SQL_TESTS=true - fi -fi - -set -o pipefail -trap 'handle_error $LINENO' ERR - -echo "" -echo "=========================================================================" -echo "Running Apache RAT checks" -echo "=========================================================================" - -CURRENT_BLOCK=$BLOCK_RAT - -./dev/check-license - -echo "" -echo "=========================================================================" -echo "Running Scala style checks" -echo "=========================================================================" - -CURRENT_BLOCK=$BLOCK_SCALA_STYLE - -./dev/lint-scala - -echo "" -echo "=========================================================================" -echo "Running Python style checks" -echo "=========================================================================" - -CURRENT_BLOCK=$BLOCK_PYTHON_STYLE - -./dev/lint-python - -echo "" -echo "=========================================================================" -echo "Building Spark" -echo "=========================================================================" - -CURRENT_BLOCK=$BLOCK_BUILD - -{ - HIVE_BUILD_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver" - echo "[info] Compile with Hive 0.13.1" - [ -d "lib_managed" ] && rm -rf lib_managed - echo "[info] Building Spark with these arguments: $HIVE_BUILD_ARGS" - - if [ "${AMPLAB_JENKINS_BUILD_TOOL}" == "maven" ]; then - build/mvn $HIVE_BUILD_ARGS clean package -DskipTests - else - echo -e "q\n" \ - | build/sbt $HIVE_BUILD_ARGS package assembly/assembly streaming-kafka-assembly/assembly \ - | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" - fi -} - -echo "" -echo "=========================================================================" -echo "Detecting binary incompatibilities with MiMa" -echo "=========================================================================" - -CURRENT_BLOCK=$BLOCK_MIMA - -./dev/mima - -echo "" -echo "=========================================================================" -echo "Running Spark unit tests" -echo "=========================================================================" - -CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS - -{ - # If the Spark SQL tests are enabled, run the tests with the Hive profiles enabled. - # This must be a single argument, as it is. - if [ -n "$_RUN_SQL_TESTS" ]; then - SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver" - fi - - if [ -n "$_SQL_TESTS_ONLY" ]; then - # This must be an array of individual arguments. Otherwise, having one long string - # will be interpreted as a single test, which doesn't work. - SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "hive-thriftserver/test" "mllib/test") - else - SBT_MAVEN_TEST_ARGS=("test") - fi - - echo "[info] Running Spark tests with these arguments: $SBT_MAVEN_PROFILES_ARGS ${SBT_MAVEN_TEST_ARGS[@]}" - - if [ "${AMPLAB_JENKINS_BUILD_TOOL}" == "maven" ]; then - build/mvn test $SBT_MAVEN_PROFILES_ARGS --fail-at-end - else - # NOTE: echo "q" is needed because sbt on encountering a build file with failure - # (either resolution or compilation) prompts the user for input either q, r, etc - # to quit or retry. This echo is there to make it not block. - # NOTE: Do not quote $SBT_MAVEN_PROFILES_ARGS or else it will be interpreted as a - # single argument! - # "${SBT_MAVEN_TEST_ARGS[@]}" is cool because it's an array. - # QUESTION: Why doesn't 'yes "q"' work? - # QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work? - echo -e "q\n" \ - | build/sbt $SBT_MAVEN_PROFILES_ARGS "${SBT_MAVEN_TEST_ARGS[@]}" \ - | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" - fi -} - -echo "" -echo "=========================================================================" -echo "Running PySpark tests" -echo "=========================================================================" - -CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS - -# add path for python 3 in jenkins -export PATH="${PATH}:/home/anaconda/envs/py3k/bin" -./python/run-tests - -echo "" -echo "=========================================================================" -echo "Running SparkR tests" -echo "=========================================================================" - -CURRENT_BLOCK=$BLOCK_SPARKR_UNIT_TESTS - -if [ $(command -v R) ]; then - ./R/install-dev.sh - ./R/run-tests.sh -else - echo "Ignoring SparkR tests as R was not found in PATH" -fi - +exec python -u ./dev/run-tests.py "$@" diff --git a/dev/run-tests-codes.sh b/dev/run-tests-codes.sh index 154e01255b2ef..f4b238e1b78a7 100644 --- a/dev/run-tests-codes.sh +++ b/dev/run-tests-codes.sh @@ -21,8 +21,9 @@ readonly BLOCK_GENERAL=10 readonly BLOCK_RAT=11 readonly BLOCK_SCALA_STYLE=12 readonly BLOCK_PYTHON_STYLE=13 -readonly BLOCK_BUILD=14 -readonly BLOCK_MIMA=15 -readonly BLOCK_SPARK_UNIT_TESTS=16 -readonly BLOCK_PYSPARK_UNIT_TESTS=17 -readonly BLOCK_SPARKR_UNIT_TESTS=18 +readonly BLOCK_DOCUMENTATION=14 +readonly BLOCK_BUILD=15 +readonly BLOCK_MIMA=16 +readonly BLOCK_SPARK_UNIT_TESTS=17 +readonly BLOCK_PYSPARK_UNIT_TESTS=18 +readonly BLOCK_SPARKR_UNIT_TESTS=19 diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 641b0ff3c4be4..c4d39d95d5890 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -210,6 +210,8 @@ done failing_test="Scala style tests" elif [ "$test_result" -eq "$BLOCK_PYTHON_STYLE" ]; then failing_test="Python style tests" + elif [ "$test_result" -eq "$BLOCK_DOCUMENTATION" ]; then + failing_test="to generate documentation" elif [ "$test_result" -eq "$BLOCK_BUILD" ]; then failing_test="to build" elif [ "$test_result" -eq "$BLOCK_MIMA" ]; then diff --git a/dev/run-tests.py b/dev/run-tests.py new file mode 100755 index 0000000000000..4596e07014733 --- /dev/null +++ b/dev/run-tests.py @@ -0,0 +1,497 @@ +#!/usr/bin/env python2 + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function +import itertools +from optparse import OptionParser +import os +import re +import sys +import subprocess +from collections import namedtuple + +from sparktestsupport import SPARK_HOME, USER_HOME +from sparktestsupport.shellutils import exit_from_command_with_retcode, run_cmd, rm_r, which +import sparktestsupport.modules as modules + +# ------------------------------------------------------------------------------------------------- +# Functions for traversing module dependency graph +# ------------------------------------------------------------------------------------------------- + + +def determine_modules_for_files(filenames): + """ + Given a list of filenames, return the set of modules that contain those files. + If a file is not associated with a more specific submodule, then this method will consider that + file to belong to the 'root' module. + + >>> sorted(x.name for x in determine_modules_for_files(["python/pyspark/a.py", "sql/test/foo"])) + ['pyspark-core', 'sql'] + >>> [x.name for x in determine_modules_for_files(["file_not_matched_by_any_subproject"])] + ['root'] + """ + changed_modules = set() + for filename in filenames: + matched_at_least_one_module = False + for module in modules.all_modules: + if module.contains_file(filename): + changed_modules.add(module) + matched_at_least_one_module = True + if not matched_at_least_one_module: + changed_modules.add(modules.root) + return changed_modules + + +def identify_changed_files_from_git_commits(patch_sha, target_branch=None, target_ref=None): + """ + Given a git commit and target ref, use the set of files changed in the diff in order to + determine which modules' tests should be run. + + >>> [x.name for x in determine_modules_for_files( \ + identify_changed_files_from_git_commits("fc0a1475ef", target_ref="5da21f07"))] + ['graphx'] + >>> 'root' in [x.name for x in determine_modules_for_files( \ + identify_changed_files_from_git_commits("50a0496a43", target_ref="6765ef9"))] + True + """ + if target_branch is None and target_ref is None: + raise AttributeError("must specify either target_branch or target_ref") + elif target_branch is not None and target_ref is not None: + raise AttributeError("must specify either target_branch or target_ref, not both") + if target_branch is not None: + diff_target = target_branch + run_cmd(['git', 'fetch', 'origin', str(target_branch+':'+target_branch)]) + else: + diff_target = target_ref + raw_output = subprocess.check_output(['git', 'diff', '--name-only', patch_sha, diff_target], + universal_newlines=True) + # Remove any empty strings + return [f for f in raw_output.split('\n') if f] + + +def determine_modules_to_test(changed_modules): + """ + Given a set of modules that have changed, compute the transitive closure of those modules' + dependent modules in order to determine the set of modules that should be tested. + + >>> sorted(x.name for x in determine_modules_to_test([modules.root])) + ['root'] + >>> sorted(x.name for x in determine_modules_to_test([modules.graphx])) + ['examples', 'graphx'] + >>> x = sorted(x.name for x in determine_modules_to_test([modules.sql])) + >>> x # doctest: +NORMALIZE_WHITESPACE + ['examples', 'hive-thriftserver', 'mllib', 'pyspark-core', 'pyspark-ml', \ + 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming', 'sparkr', 'sql'] + """ + # If we're going to have to run all of the tests, then we can just short-circuit + # and return 'root'. No module depends on root, so if it appears then it will be + # in changed_modules. + if modules.root in changed_modules: + return [modules.root] + modules_to_test = set() + for module in changed_modules: + modules_to_test = modules_to_test.union(determine_modules_to_test(module.dependent_modules)) + return modules_to_test.union(set(changed_modules)) + + +# ------------------------------------------------------------------------------------------------- +# Functions for working with subprocesses and shell tools +# ------------------------------------------------------------------------------------------------- + +def get_error_codes(err_code_file): + """Function to retrieve all block numbers from the `run-tests-codes.sh` + file to maintain backwards compatibility with the `run-tests-jenkins` + script""" + + with open(err_code_file, 'r') as f: + err_codes = [e.split()[1].strip().split('=') + for e in f if e.startswith("readonly")] + return dict(err_codes) + + +ERROR_CODES = get_error_codes(os.path.join(SPARK_HOME, "dev/run-tests-codes.sh")) + + +def determine_java_executable(): + """Will return the path of the java executable that will be used by Spark's + tests or `None`""" + + # Any changes in the way that Spark's build detects java must be reflected + # here. Currently the build looks for $JAVA_HOME/bin/java then falls back to + # the `java` executable on the path + + java_home = os.environ.get("JAVA_HOME") + + # check if there is an executable at $JAVA_HOME/bin/java + java_exe = which(os.path.join(java_home, "bin", "java")) if java_home else None + # if the java_exe wasn't set, check for a `java` version on the $PATH + return java_exe if java_exe else which("java") + + +JavaVersion = namedtuple('JavaVersion', ['major', 'minor', 'patch', 'update']) + + +def determine_java_version(java_exe): + """Given a valid java executable will return its version in named tuple format + with accessors '.major', '.minor', '.patch', '.update'""" + + raw_output = subprocess.check_output([java_exe, "-version"], + stderr=subprocess.STDOUT, + universal_newlines=True) + + raw_output_lines = raw_output.split('\n') + + # find raw version string, eg 'java version "1.8.0_25"' + raw_version_str = next(x for x in raw_output_lines if " version " in x) + + version_str = raw_version_str.split()[-1].strip('"') # eg '1.8.0_25' + version, update = version_str.split('_') # eg ['1.8.0', '25'] + + # map over the values and convert them to integers + version_info = [int(x) for x in version.split('.') + [update]] + + return JavaVersion(major=version_info[0], + minor=version_info[1], + patch=version_info[2], + update=version_info[3]) + + +# ------------------------------------------------------------------------------------------------- +# Functions for running the other build and test scripts +# ------------------------------------------------------------------------------------------------- + + +def set_title_and_block(title, err_block): + os.environ["CURRENT_BLOCK"] = ERROR_CODES[err_block] + line_str = '=' * 72 + + print('') + print(line_str) + print(title) + print(line_str) + + +def run_apache_rat_checks(): + set_title_and_block("Running Apache RAT checks", "BLOCK_RAT") + run_cmd([os.path.join(SPARK_HOME, "dev", "check-license")]) + + +def run_scala_style_checks(): + set_title_and_block("Running Scala style checks", "BLOCK_SCALA_STYLE") + run_cmd([os.path.join(SPARK_HOME, "dev", "lint-scala")]) + + +def run_python_style_checks(): + set_title_and_block("Running Python style checks", "BLOCK_PYTHON_STYLE") + run_cmd([os.path.join(SPARK_HOME, "dev", "lint-python")]) + + +def build_spark_documentation(): + set_title_and_block("Building Spark Documentation", "BLOCK_DOCUMENTATION") + os.environ["PRODUCTION"] = "1 jekyll build" + + os.chdir(os.path.join(SPARK_HOME, "docs")) + + jekyll_bin = which("jekyll") + + if not jekyll_bin: + print("[error] Cannot find a version of `jekyll` on the system; please", + " install one and retry to build documentation.") + sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) + else: + run_cmd([jekyll_bin, "build"]) + + os.chdir(SPARK_HOME) + + +def exec_maven(mvn_args=()): + """Will call Maven in the current directory with the list of mvn_args passed + in and returns the subprocess for any further processing""" + + run_cmd([os.path.join(SPARK_HOME, "build", "mvn")] + mvn_args) + + +def exec_sbt(sbt_args=()): + """Will call SBT in the current directory with the list of mvn_args passed + in and returns the subprocess for any further processing""" + + sbt_cmd = [os.path.join(SPARK_HOME, "build", "sbt")] + sbt_args + + sbt_output_filter = re.compile("^.*[info].*Resolving" + "|" + + "^.*[warn].*Merging" + "|" + + "^.*[info].*Including") + + # NOTE: echo "q" is needed because sbt on encountering a build file + # with failure (either resolution or compilation) prompts the user for + # input either q, r, etc to quit or retry. This echo is there to make it + # not block. + echo_proc = subprocess.Popen(["echo", "\"q\n\""], stdout=subprocess.PIPE) + sbt_proc = subprocess.Popen(sbt_cmd, + stdin=echo_proc.stdout, + stdout=subprocess.PIPE) + echo_proc.wait() + for line in iter(sbt_proc.stdout.readline, ''): + if not sbt_output_filter.match(line): + print(line, end='') + retcode = sbt_proc.wait() + + if retcode > 0: + exit_from_command_with_retcode(sbt_cmd, retcode) + + +def get_hadoop_profiles(hadoop_version): + """ + For the given Hadoop version tag, return a list of SBT profile flags for + building and testing against that Hadoop version. + """ + + sbt_maven_hadoop_profiles = { + "hadoop1.0": ["-Phadoop-1", "-Dhadoop.version=1.2.1"], + "hadoop2.0": ["-Phadoop-1", "-Dhadoop.version=2.0.0-mr1-cdh4.1.1"], + "hadoop2.2": ["-Pyarn", "-Phadoop-2.2"], + "hadoop2.3": ["-Pyarn", "-Phadoop-2.3", "-Dhadoop.version=2.3.0"], + } + + if hadoop_version in sbt_maven_hadoop_profiles: + return sbt_maven_hadoop_profiles[hadoop_version] + else: + print("[error] Could not find", hadoop_version, "in the list. Valid options", + " are", sbt_maven_hadoop_profiles.keys()) + sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) + + +def build_spark_maven(hadoop_version): + # Enable all of the profiles for the build: + build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags + mvn_goals = ["clean", "package", "-DskipTests"] + profiles_and_goals = build_profiles + mvn_goals + + print("[info] Building Spark (w/Hive 0.13.1) using Maven with these arguments: ", + " ".join(profiles_and_goals)) + + exec_maven(profiles_and_goals) + + +def build_spark_sbt(hadoop_version): + # Enable all of the profiles for the build: + build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags + sbt_goals = ["package", + "assembly/assembly", + "streaming-kafka-assembly/assembly"] + profiles_and_goals = build_profiles + sbt_goals + + print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: ", + " ".join(profiles_and_goals)) + + exec_sbt(profiles_and_goals) + + +def build_apache_spark(build_tool, hadoop_version): + """Will build Spark against Hive v0.13.1 given the passed in build tool (either `sbt` or + `maven`). Defaults to using `sbt`.""" + + set_title_and_block("Building Spark", "BLOCK_BUILD") + + rm_r("lib_managed") + + if build_tool == "maven": + build_spark_maven(hadoop_version) + else: + build_spark_sbt(hadoop_version) + + +def detect_binary_inop_with_mima(): + set_title_and_block("Detecting binary incompatibilities with MiMa", "BLOCK_MIMA") + run_cmd([os.path.join(SPARK_HOME, "dev", "mima")]) + + +def run_scala_tests_maven(test_profiles): + mvn_test_goals = ["test", "--fail-at-end"] + profiles_and_goals = test_profiles + mvn_test_goals + + print("[info] Running Spark tests using Maven with these arguments: ", + " ".join(profiles_and_goals)) + + exec_maven(profiles_and_goals) + + +def run_scala_tests_sbt(test_modules, test_profiles): + + sbt_test_goals = set(itertools.chain.from_iterable(m.sbt_test_goals for m in test_modules)) + + if not sbt_test_goals: + return + + profiles_and_goals = test_profiles + list(sbt_test_goals) + + print("[info] Running Spark tests using SBT with these arguments: ", + " ".join(profiles_and_goals)) + + exec_sbt(profiles_and_goals) + + +def run_scala_tests(build_tool, hadoop_version, test_modules): + """Function to properly execute all tests passed in as a set from the + `determine_test_suites` function""" + set_title_and_block("Running Spark unit tests", "BLOCK_SPARK_UNIT_TESTS") + + test_modules = set(test_modules) + + test_profiles = get_hadoop_profiles(hadoop_version) + \ + list(set(itertools.chain.from_iterable(m.build_profile_flags for m in test_modules))) + if build_tool == "maven": + run_scala_tests_maven(test_profiles) + else: + run_scala_tests_sbt(test_modules, test_profiles) + + +def run_python_tests(test_modules, parallelism): + set_title_and_block("Running PySpark tests", "BLOCK_PYSPARK_UNIT_TESTS") + + command = [os.path.join(SPARK_HOME, "python", "run-tests")] + if test_modules != [modules.root]: + command.append("--modules=%s" % ','.join(m.name for m in test_modules)) + command.append("--parallelism=%i" % parallelism) + run_cmd(command) + + +def run_sparkr_tests(): + set_title_and_block("Running SparkR tests", "BLOCK_SPARKR_UNIT_TESTS") + + if which("R"): + run_cmd([os.path.join(SPARK_HOME, "R", "install-dev.sh")]) + run_cmd([os.path.join(SPARK_HOME, "R", "run-tests.sh")]) + else: + print("Ignoring SparkR tests as R was not found in PATH") + + +def parse_opts(): + parser = OptionParser( + prog="run-tests" + ) + parser.add_option( + "-p", "--parallelism", type="int", default=4, + help="The number of suites to test in parallel (default %default)" + ) + + (opts, args) = parser.parse_args() + if args: + parser.error("Unsupported arguments: %s" % ' '.join(args)) + if opts.parallelism < 1: + parser.error("Parallelism cannot be less than 1") + return opts + + +def main(): + opts = parse_opts() + # Ensure the user home directory (HOME) is valid and is an absolute directory + if not USER_HOME or not os.path.isabs(USER_HOME): + print("[error] Cannot determine your home directory as an absolute path;", + " ensure the $HOME environment variable is set properly.") + sys.exit(1) + + os.chdir(SPARK_HOME) + + rm_r(os.path.join(SPARK_HOME, "work")) + rm_r(os.path.join(USER_HOME, ".ivy2", "local", "org.apache.spark")) + rm_r(os.path.join(USER_HOME, ".ivy2", "cache", "org.apache.spark")) + + os.environ["CURRENT_BLOCK"] = ERROR_CODES["BLOCK_GENERAL"] + + java_exe = determine_java_executable() + + if not java_exe: + print("[error] Cannot find a version of `java` on the system; please", + " install one and retry.") + sys.exit(2) + + java_version = determine_java_version(java_exe) + + if java_version.minor < 8: + print("[warn] Java 8 tests will not run because JDK version is < 1.8.") + + if os.environ.get("AMPLAB_JENKINS"): + # if we're on the Amplab Jenkins build servers setup variables + # to reflect the environment settings + build_tool = os.environ.get("AMPLAB_JENKINS_BUILD_TOOL", "sbt") + hadoop_version = os.environ.get("AMPLAB_JENKINS_BUILD_PROFILE", "hadoop2.3") + test_env = "amplab_jenkins" + # add path for Python3 in Jenkins if we're calling from a Jenkins machine + os.environ["PATH"] = "/home/anaconda/envs/py3k/bin:" + os.environ.get("PATH") + else: + # else we're running locally and can use local settings + build_tool = "sbt" + hadoop_version = "hadoop2.3" + test_env = "local" + + print("[info] Using build tool", build_tool, "with Hadoop profile", hadoop_version, + "under environment", test_env) + + changed_modules = None + changed_files = None + if test_env == "amplab_jenkins" and os.environ.get("AMP_JENKINS_PRB"): + target_branch = os.environ["ghprbTargetBranch"] + changed_files = identify_changed_files_from_git_commits("HEAD", target_branch=target_branch) + changed_modules = determine_modules_for_files(changed_files) + if not changed_modules: + changed_modules = [modules.root] + print("[info] Found the following changed modules:", + ", ".join(x.name for x in changed_modules)) + + test_modules = determine_modules_to_test(changed_modules) + + # license checks + run_apache_rat_checks() + + # style checks + if not changed_files or any(f.endswith(".scala") for f in changed_files): + run_scala_style_checks() + if not changed_files or any(f.endswith(".py") for f in changed_files): + run_python_style_checks() + + # determine if docs were changed and if we're inside the amplab environment + # note - the below commented out until *all* Jenkins workers can get `jekyll` installed + # if "DOCS" in changed_modules and test_env == "amplab_jenkins": + # build_spark_documentation() + + # spark build + build_apache_spark(build_tool, hadoop_version) + + # backwards compatibility checks + detect_binary_inop_with_mima() + + # run the test suites + run_scala_tests(build_tool, hadoop_version, test_modules) + + modules_with_python_tests = [m for m in test_modules if m.python_test_goals] + if modules_with_python_tests: + run_python_tests(modules_with_python_tests, opts.parallelism) + if any(m.should_run_r_tests for m in test_modules): + run_sparkr_tests() + + +def _test(): + import doctest + failure_count = doctest.testmod()[0] + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() + main() diff --git a/dev/sparktestsupport/__init__.py b/dev/sparktestsupport/__init__.py new file mode 100644 index 0000000000000..12696d98fb988 --- /dev/null +++ b/dev/sparktestsupport/__init__.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os + +SPARK_HOME = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../")) +USER_HOME = os.environ.get("HOME") diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py new file mode 100644 index 0000000000000..efe3a897e9c10 --- /dev/null +++ b/dev/sparktestsupport/modules.py @@ -0,0 +1,385 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import itertools +import re + +all_modules = [] + + +class Module(object): + """ + A module is the basic abstraction in our test runner script. Each module consists of a set of + source files, a set of test commands, and a set of dependencies on other modules. We use modules + to define a dependency graph that lets determine which tests to run based on which files have + changed. + """ + + def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), + sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(), + should_run_r_tests=False): + """ + Define a new module. + + :param name: A short module name, for display in logging and error messages. + :param dependencies: A set of dependencies for this module. This should only include direct + dependencies; transitive dependencies are resolved automatically. + :param source_file_regexes: a set of regexes that match source files belonging to this + module. These regexes are applied by attempting to match at the beginning of the + filename strings. + :param build_profile_flags: A set of profile flags that should be passed to Maven or SBT in + order to build and test this module (e.g. '-PprofileName'). + :param sbt_test_goals: A set of SBT test goals for testing this module. + :param python_test_goals: A set of Python test goals for testing this module. + :param blacklisted_python_implementations: A set of Python implementations that are not + supported by this module's Python components. The values in this set should match + strings returned by Python's `platform.python_implementation()`. + :param should_run_r_tests: If true, changes in this module will trigger all R tests. + """ + self.name = name + self.dependencies = dependencies + self.source_file_prefixes = source_file_regexes + self.sbt_test_goals = sbt_test_goals + self.build_profile_flags = build_profile_flags + self.python_test_goals = python_test_goals + self.blacklisted_python_implementations = blacklisted_python_implementations + self.should_run_r_tests = should_run_r_tests + + self.dependent_modules = set() + for dep in dependencies: + dep.dependent_modules.add(self) + all_modules.append(self) + + def contains_file(self, filename): + return any(re.match(p, filename) for p in self.source_file_prefixes) + + +sql = Module( + name="sql", + dependencies=[], + source_file_regexes=[ + "sql/(?!hive-thriftserver)", + "bin/spark-sql", + ], + build_profile_flags=[ + "-Phive", + ], + sbt_test_goals=[ + "catalyst/test", + "sql/test", + "hive/test", + ] +) + + +hive_thriftserver = Module( + name="hive-thriftserver", + dependencies=[sql], + source_file_regexes=[ + "sql/hive-thriftserver", + "sbin/start-thriftserver.sh", + ], + build_profile_flags=[ + "-Phive-thriftserver", + ], + sbt_test_goals=[ + "hive-thriftserver/test", + ] +) + + +graphx = Module( + name="graphx", + dependencies=[], + source_file_regexes=[ + "graphx/", + ], + sbt_test_goals=[ + "graphx/test" + ] +) + + +streaming = Module( + name="streaming", + dependencies=[], + source_file_regexes=[ + "streaming", + ], + sbt_test_goals=[ + "streaming/test", + ] +) + + +streaming_kinesis_asl = Module( + name="kinesis-asl", + dependencies=[streaming], + source_file_regexes=[ + "extras/kinesis-asl/", + ], + build_profile_flags=[ + "-Pkinesis-asl", + ], + sbt_test_goals=[ + "kinesis-asl/test", + ] +) + + +streaming_zeromq = Module( + name="streaming-zeromq", + dependencies=[streaming], + source_file_regexes=[ + "external/zeromq", + ], + sbt_test_goals=[ + "streaming-zeromq/test", + ] +) + + +streaming_twitter = Module( + name="streaming-twitter", + dependencies=[streaming], + source_file_regexes=[ + "external/twitter", + ], + sbt_test_goals=[ + "streaming-twitter/test", + ] +) + + +streaming_mqtt = Module( + name="streaming-mqtt", + dependencies=[streaming], + source_file_regexes=[ + "external/mqtt", + ], + sbt_test_goals=[ + "streaming-mqtt/test", + ] +) + + +streaming_kafka = Module( + name="streaming-kafka", + dependencies=[streaming], + source_file_regexes=[ + "external/kafka", + "external/kafka-assembly", + ], + sbt_test_goals=[ + "streaming-kafka/test", + ] +) + + +streaming_flume_sink = Module( + name="streaming-flume-sink", + dependencies=[streaming], + source_file_regexes=[ + "external/flume-sink", + ], + sbt_test_goals=[ + "streaming-flume-sink/test", + ] +) + + +streaming_flume = Module( + name="streaming_flume", + dependencies=[streaming], + source_file_regexes=[ + "external/flume", + ], + sbt_test_goals=[ + "streaming-flume/test", + ] +) + + +mllib = Module( + name="mllib", + dependencies=[streaming, sql], + source_file_regexes=[ + "data/mllib/", + "mllib/", + ], + sbt_test_goals=[ + "mllib/test", + ] +) + + +examples = Module( + name="examples", + dependencies=[graphx, mllib, streaming, sql], + source_file_regexes=[ + "examples/", + ], + sbt_test_goals=[ + "examples/test", + ] +) + + +pyspark_core = Module( + name="pyspark-core", + dependencies=[mllib, streaming, streaming_kafka], + source_file_regexes=[ + "python/(?!pyspark/(ml|mllib|sql|streaming))" + ], + python_test_goals=[ + "pyspark.rdd", + "pyspark.context", + "pyspark.conf", + "pyspark.broadcast", + "pyspark.accumulators", + "pyspark.serializers", + "pyspark.profiler", + "pyspark.shuffle", + "pyspark.tests", + ] +) + + +pyspark_sql = Module( + name="pyspark-sql", + dependencies=[pyspark_core, sql], + source_file_regexes=[ + "python/pyspark/sql" + ], + python_test_goals=[ + "pyspark.sql.types", + "pyspark.sql.context", + "pyspark.sql.column", + "pyspark.sql.dataframe", + "pyspark.sql.group", + "pyspark.sql.functions", + "pyspark.sql.readwriter", + "pyspark.sql.window", + "pyspark.sql.tests", + ] +) + + +pyspark_streaming = Module( + name="pyspark-streaming", + dependencies=[pyspark_core, streaming, streaming_kafka], + source_file_regexes=[ + "python/pyspark/streaming" + ], + python_test_goals=[ + "pyspark.streaming.util", + "pyspark.streaming.tests", + ] +) + + +pyspark_mllib = Module( + name="pyspark-mllib", + dependencies=[pyspark_core, pyspark_streaming, pyspark_sql, mllib], + source_file_regexes=[ + "python/pyspark/mllib" + ], + python_test_goals=[ + "pyspark.mllib.classification", + "pyspark.mllib.clustering", + "pyspark.mllib.evaluation", + "pyspark.mllib.feature", + "pyspark.mllib.fpm", + "pyspark.mllib.linalg", + "pyspark.mllib.random", + "pyspark.mllib.recommendation", + "pyspark.mllib.regression", + "pyspark.mllib.stat._statistics", + "pyspark.mllib.stat.KernelDensity", + "pyspark.mllib.tree", + "pyspark.mllib.util", + "pyspark.mllib.tests", + ], + blacklisted_python_implementations=[ + "PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there + ] +) + + +pyspark_ml = Module( + name="pyspark-ml", + dependencies=[pyspark_core, pyspark_mllib], + source_file_regexes=[ + "python/pyspark/ml/" + ], + python_test_goals=[ + "pyspark.ml.feature", + "pyspark.ml.classification", + "pyspark.ml.recommendation", + "pyspark.ml.regression", + "pyspark.ml.tuning", + "pyspark.ml.tests", + "pyspark.ml.evaluation", + ], + blacklisted_python_implementations=[ + "PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there + ] +) + +sparkr = Module( + name="sparkr", + dependencies=[sql, mllib], + source_file_regexes=[ + "R/", + ], + should_run_r_tests=True +) + + +docs = Module( + name="docs", + dependencies=[], + source_file_regexes=[ + "docs/", + ] +) + + +ec2 = Module( + name="ec2", + dependencies=[], + source_file_regexes=[ + "ec2/", + ] +) + + +# The root module is a dummy module which is used to run all of the tests. +# No other modules should directly depend on this module. +root = Module( + name="root", + dependencies=[], + source_file_regexes=[], + # In order to run all of the tests, enable every test profile: + build_profile_flags=list(set( + itertools.chain.from_iterable(m.build_profile_flags for m in all_modules))), + sbt_test_goals=[ + "test", + ], + python_test_goals=list(itertools.chain.from_iterable(m.python_test_goals for m in all_modules)), + should_run_r_tests=True +) diff --git a/dev/sparktestsupport/shellutils.py b/dev/sparktestsupport/shellutils.py new file mode 100644 index 0000000000000..12bd0bf3a4fe9 --- /dev/null +++ b/dev/sparktestsupport/shellutils.py @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function +import os +import shutil +import subprocess +import sys + + +def exit_from_command_with_retcode(cmd, retcode): + print("[error] running", ' '.join(cmd), "; received return code", retcode) + sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) + + +def rm_r(path): + """ + Given an arbitrary path, properly remove it with the correct Python construct if it exists. + From: http://stackoverflow.com/a/9559881 + """ + + if os.path.isdir(path): + shutil.rmtree(path) + elif os.path.exists(path): + os.remove(path) + + +def run_cmd(cmd): + """ + Given a command as a list of arguments will attempt to execute the command + and, on failure, print an error message and exit. + """ + + if not isinstance(cmd, list): + cmd = cmd.split() + try: + subprocess.check_call(cmd) + except subprocess.CalledProcessError as e: + exit_from_command_with_retcode(e.cmd, e.returncode) + + +def is_exe(path): + """ + Check if a given path is an executable file. + From: http://stackoverflow.com/a/377028 + """ + + return os.path.isfile(path) and os.access(path, os.X_OK) + + +def which(program): + """ + Find and return the given program by its absolute path or 'None' if the program cannot be found. + From: http://stackoverflow.com/a/377028 + """ + + fpath = os.path.split(program)[0] + + if fpath: + if is_exe(program): + return program + else: + for path in os.environ.get("PATH").split(os.pathsep): + path = path.strip('"') + exe_file = os.path.join(path, program) + if is_exe(exe_file): + return exe_file + return None diff --git a/docs/README.md b/docs/README.md index 5852f972a051d..d7652e921f7df 100644 --- a/docs/README.md +++ b/docs/README.md @@ -28,7 +28,7 @@ in some cases: $ sudo gem install jekyll $ sudo gem install jekyll-redirect-from -Execute `jekyll` from the `docs/` directory. Compiling the site with Jekyll will create a directory +Execute `jekyll build` from the `docs/` directory to compile the site. Compiling the site with Jekyll will create a directory called `_site` containing index.html as well as the rest of the compiled files. You can modify the default Jekyll build as follows: diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index eebb3faf90fc0..b4952fe97ca0e 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -138,6 +138,7 @@

{{ page.title }}

+ diff --git a/docs/api.md b/docs/api.md index 45df77ac05f78..ae7d51c2aefbf 100644 --- a/docs/api.md +++ b/docs/api.md @@ -3,7 +3,7 @@ layout: global title: Spark API Documentation --- -Here you can API docs for Spark and its submodules. +Here you can read API docs for Spark and its submodules. - [Spark Scala API (Scaladoc)](api/scala/index.html) - [Spark Java API (Javadoc)](api/java/index.html) diff --git a/docs/configuration.md b/docs/configuration.md index 3960e7e78bde1..affcd21514d88 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -334,7 +334,7 @@ Apart from these, the following properties are also available, and may be useful Enable profiling in Python worker, the profile result will show up by `sc.show_profiles()`, or it will be displayed before the driver exiting. It also can be dumped into disk by - `sc.dump_profiles(path)`. If some of the profile results had been displayed maually, + `sc.dump_profiles(path)`. If some of the profile results had been displayed manually, they will not be displayed automatically before driver exiting. By default the `pyspark.profiler.BasicProfiler` will be used, but this can be overridden by @@ -1495,6 +1495,18 @@ Apart from these, the following properties are also available, and may be useful +#### SparkR + + + + + + + +
Property NameDefaultMeaning
spark.r.numRBackendThreads2 + Number of threads used by RBackend to handle RPC calls from SparkR package. +
+ #### Cluster Managers Each cluster manager in Spark has additional configuration options. Configurations can be found on the pages for each mode: diff --git a/docs/css/main.css b/docs/css/main.css index f6fe7d5f07da1..89305a7d3a358 100755 --- a/docs/css/main.css +++ b/docs/css/main.css @@ -146,3 +146,8 @@ ul.nav li.dropdown ul.dropdown-menu li.dropdown-submenu ul.dropdown-menu { .MathJax .mi { color: inherit } .MathJax .mf { color: inherit } .MathJax .mh { color: inherit } + +/** + * AnchorJS (anchor links when hovering over headers) + */ +a.anchorjs-link:hover { text-decoration: none; } diff --git a/docs/hadoop-provided.md b/docs/hadoop-provided.md new file mode 100644 index 0000000000000..bbd26b343e2e6 --- /dev/null +++ b/docs/hadoop-provided.md @@ -0,0 +1,26 @@ +--- +layout: global +displayTitle: Using Spark's "Hadoop Free" Build +title: Using Spark's "Hadoop Free" Build +--- + +Spark uses Hadoop client libraries for HDFS and YARN. Starting in version Spark 1.4, the project packages "Hadoop free" builds that lets you more easily connect a single Spark binary to any Hadoop version. To use these builds, you need to modify `SPARK_DIST_CLASSPATH` to include Hadoop's package jars. The most convenient place to do this is by adding an entry in `conf/spark-env.sh`. + +This page describes how to connect Spark to Hadoop for different types of distributions. + +# Apache Hadoop +For Apache distributions, you can use Hadoop's 'classpath' command. For instance: + +{% highlight bash %} +### in conf/spark-env.sh ### + +# If 'hadoop' binary is on your PATH +export SPARK_DIST_CLASSPATH=$(hadoop classpath) + +# With explicit path to 'hadoop' binary +export SPARK_DIST_CLASSPATH=$(/path/to/hadoop/bin/hadoop classpath) + +# Passing a Hadoop configuration directory +export SPARK_DIST_CLASSPATH=$(hadoop --config /path/to/configs classpath) + +{% endhighlight %} diff --git a/docs/index.md b/docs/index.md index 7939657915fc9..d85cf12defefd 100644 --- a/docs/index.md +++ b/docs/index.md @@ -12,9 +12,13 @@ It also supports a rich set of higher-level tools including [Spark SQL](sql-prog # Downloading -Get Spark from the [downloads page](http://spark.apache.org/downloads.html) of the project website. This documentation is for Spark version {{site.SPARK_VERSION}}. The downloads page -contains Spark packages for many popular HDFS versions. If you'd like to build Spark from -scratch, visit [Building Spark](building-spark.html). +Get Spark from the [downloads page](http://spark.apache.org/downloads.html) of the project website. This documentation is for Spark version {{site.SPARK_VERSION}}. Spark uses Hadoop's client libraries for HDFS and YARN. Downloads are pre-packaged for a handful of popular Hadoop versions. +Users can also download a "Hadoop free" binary and run Spark with any Hadoop version +[by augmenting Spark's classpath](hadoop-provided.html). + +If you'd like to build Spark from +source, visit [Building Spark](building-spark.html). + Spark runs on both Windows and UNIX-like systems (e.g. Linux, Mac OS). It's easy to run locally on one machine --- all you need is to have `java` installed on your system `PATH`, diff --git a/docs/js/main.js b/docs/js/main.js index f1a90e47e89a7..f5d66b16f7b21 100755 --- a/docs/js/main.js +++ b/docs/js/main.js @@ -68,38 +68,11 @@ function codeTabs() { }); } -function makeCollapsable(elt, accordionClass, accordionBodyId, title) { - $(elt).addClass("accordion-inner"); - $(elt).wrap('
') - $(elt).wrap('
') - $(elt).wrap('
') - $(elt).parent().before( - '
' + - '' + - title + - '' + - '
' - ); -} - -// Enable "view solution" sections (for exercises) -function viewSolution() { - var counter = 0 - $("div.solution").each(function() { - var id = "solution_" + counter - makeCollapsable(this, "", id, - '' + - '' + "View Solution"); - counter++; - }); -} // A script to fix internal hash links because we have an overlapping top bar. // Based on https://github.com/twitter/bootstrap/issues/193#issuecomment-2281510 function maybeScrollToHash() { - console.log("HERE"); if (window.location.hash && $(window.location.hash).length) { - console.log("HERE2", $(window.location.hash), $(window.location.hash).offset().top); var newTop = $(window.location.hash).offset().top - 57; $(window).scrollTop(newTop); } @@ -107,7 +80,12 @@ function maybeScrollToHash() { $(function() { codeTabs(); - viewSolution(); + // Display anchor links when hovering over headers. For documentation of the + // configuration options, see the AnchorJS documentation. + anchors.options = { + placement: 'left' + }; + anchors.add(); $(window).bind('hashchange', function() { maybeScrollToHash(); diff --git a/docs/js/vendor/anchor.min.js b/docs/js/vendor/anchor.min.js new file mode 100755 index 0000000000000..68c3cb7073b6d --- /dev/null +++ b/docs/js/vendor/anchor.min.js @@ -0,0 +1,6 @@ +/*! + * AnchorJS - v1.1.1 - 2015-05-23 + * https://github.com/bryanbraun/anchorjs + * Copyright (c) 2015 Bryan Braun; Licensed MIT + */ +function AnchorJS(A){"use strict";this.options=A||{},this._applyRemainingDefaultOptions=function(A){this.options.icon=this.options.hasOwnProperty("icon")?A.icon:"",this.options.visible=this.options.hasOwnProperty("visible")?A.visible:"hover",this.options.placement=this.options.hasOwnProperty("placement")?A.placement:"right",this.options.class=this.options.hasOwnProperty("class")?A.class:""},this._applyRemainingDefaultOptions(A),this.add=function(A){var e,t,o,n,i,s,a,l,c,r,h,g,B,Q;if(this._applyRemainingDefaultOptions(this.options),A){if("string"!=typeof A)throw new Error("The selector provided to AnchorJS was invalid.")}else A="h1, h2, h3, h4, h5, h6";if(e=document.querySelectorAll(A),0===e.length)return!1;for(this._addBaselineStyles(),t=document.querySelectorAll("[id]"),o=[].map.call(t,function(A){return A.id}),i=0;i',B=document.createElement("div"),B.innerHTML=g,Q=B.childNodes,"always"===this.options.visible&&(Q[0].style.opacity="1"),""===this.options.icon&&(Q[0].style.fontFamily="anchorjs-icons",Q[0].style.fontStyle="normal",Q[0].style.fontVariant="normal",Q[0].style.fontWeight="normal"),"left"===this.options.placement?(Q[0].style.position="absolute",Q[0].style.marginLeft="-1em",Q[0].style.paddingRight="0.5em",e[i].insertBefore(Q[0],e[i].firstChild)):(Q[0].style.paddingLeft="0.375em",e[i].appendChild(Q[0]))}return this},this.remove=function(A){for(var e,t=document.querySelectorAll(A),o=0;o 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. -# Algorithm Guides - -There are now several algorithms in the Pipelines API which are not in the lower-level MLlib API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines, and ensembles, which fit naturally into the `Estimator` abstraction in the Pipelines. - -**Pipelines API Algorithm Guides** - -* [Feature Extraction, Transformation, and Selection](ml-features.html) -* [Ensembles](ml-ensembles.html) - - # Code Examples This section gives code examples illustrating the functionality discussed above. @@ -783,6 +779,16 @@ Spark ML also depends upon Spark SQL, but the relevant parts of Spark SQL do not # Migration Guide +## From 1.3 to 1.4 + +Several major API changes occurred, including: +* `Param` and other APIs for specifying parameters +* `uid` unique IDs for Pipeline components +* Reorganization of certain classes +Since the `spark.ml` API was an Alpha Component in Spark 1.3, we do not list all changes here. + +However, now that `spark.ml` is no longer an Alpha Component, we will provide details on any API changes for future releases. + ## From 1.2 to 1.3 The main API changes are from Spark SQL. We list the most important changes here: diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 1b088969ddc25..dcaa3784be874 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -592,15 +592,55 @@ model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() ssc.start() ssc.awaitTermination() +{% endhighlight %} +
+ +
+First we import the neccessary classes. + +{% highlight python %} +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.clustering import StreamingKMeans {% endhighlight %} +Then we make an input stream of vectors for training, as well as a stream of labeled data +points for testing. We assume a StreamingContext `ssc` has been created, see +[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info. + +{% highlight python %} +def parse(lp): + label = float(lp[lp.find('(') + 1: lp.find(',')]) + vec = Vectors.dense(lp[lp.find('[') + 1: lp.find(']')].split(',')) + return LabeledPoint(label, vec) + +trainingData = ssc.textFileStream("/training/data/dir").map(Vectors.parse) +testData = ssc.textFileStream("/testing/data/dir").map(parse) +{% endhighlight %} + +We create a model with random clusters and specify the number of clusters to find + +{% highlight python %} +model = StreamingKMeans(k=2, decayFactor=1.0).setRandomCenters(3, 1.0, 0) +{% endhighlight %} + +Now register the streams for training and testing and start the job, printing +the predicted cluster assignments on new data points as they arrive. + +{% highlight python %} +model.trainOn(trainingData) +print(model.predictOnValues(testData.map(lambda lp: (lp.label, lp.features)))) + +ssc.start() +ssc.awaitTermination() +{% endhighlight %} +
+ +
+ As you add new text files with data the cluster centers will update. Each training point should be formatted as `[x1, x2, x3]`, and each test data point should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or identifier (e.g. a true category assignment). Anytime a text file is placed in `/training/data/dir` the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. With new data, the cluster centers will change! - - - - diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 4fe470a8de810..83e937635a55b 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -558,6 +558,29 @@ JavaRDD transformedData2 = data.map( } ); +{% endhighlight %} + + +
+{% highlight python %} +from pyspark import SparkContext +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.feature import ElementwiseProduct + +# Load and parse the data +sc = SparkContext() +data = sc.textFile("data/mllib/kmeans_data.txt") +parsedData = data.map(lambda x: [float(t) for t in x.split(" ")]) + +# Create weight vector. +transformingVector = Vectors.dense([0.0, 1.0, 2.0]) +transformer = ElementwiseProduct(transformingVector) + +# Batch transform +transformedData = transformer.transform(parsedData) +# Single-row transform +transformedData2 = transformer.transform(parsedData.first()) + {% endhighlight %}
diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index 9fd9be0dd01b1..bcc066a185526 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -39,11 +39,11 @@ MLlib's FP-growth implementation takes the following (hyper-)parameters:
-[`FPGrowth`](api/java/org/apache/spark/mllib/fpm/FPGrowth.html) implements the +[`FPGrowth`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowth) implements the FP-growth algorithm. It take a `JavaRDD` of transactions, where each transaction is an `Iterable` of items of a generic type. Calling `FPGrowth.run` with transactions returns an -[`FPGrowthModel`](api/java/org/apache/spark/mllib/fpm/FPGrowthModel.html) +[`FPGrowthModel`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowthModel) that stores the frequent itemsets with their frequencies. {% highlight scala %} diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index de7d66fb2dedf..d2d1cc93fe006 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -7,7 +7,19 @@ description: MLlib machine learning library overview for Spark SPARK_VERSION_SHO MLlib is Spark's scalable machine learning library consisting of common learning algorithms and utilities, including classification, regression, clustering, collaborative -filtering, dimensionality reduction, as well as underlying optimization primitives, as outlined below: +filtering, dimensionality reduction, as well as underlying optimization primitives. +Guides for individual algorithms are listed below. + +The API is divided into 2 parts: + +* [The original `spark.mllib` API](mllib-guide.html#mllib-types-algorithms-and-utilities) is the primary API. +* [The "Pipelines" `spark.ml` API](mllib-guide.html#sparkml-high-level-apis-for-ml-pipelines) is a higher-level API for constructing ML workflows. + +We list major functionality from both below, with links to detailed guides. + +# MLlib types, algorithms and utilities + +This lists functionality included in `spark.mllib`, the main MLlib API. * [Data types](mllib-data-types.html) * [Basic statistics](mllib-statistics.html) @@ -49,8 +61,8 @@ and the migration guide below will explain all changes between releases. Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of high-level APIs that help users create and tune practical machine learning pipelines. -It is currently an alpha component, and we would like to hear back from the community about -how it fits real-world use cases and how it could be improved. + +*Graduated from Alpha!* The Pipelines API is no longer an alpha component, although many elements of it are still `Experimental` or `DeveloperApi`. Note that we will keep supporting and adding features to `spark.mllib` along with the development of `spark.ml`. @@ -58,7 +70,11 @@ Users should be comfortable using `spark.mllib` features and expect more feature Developers should contribute new algorithms to `spark.mllib` and can optionally contribute to `spark.ml`. -See the **[spark.ml programming guide](ml-guide.html)** for more information on this package. +More detailed guides for `spark.ml` include: + +* **[spark.ml programming guide](ml-guide.html)**: overview of the Pipelines API and major concepts +* [Feature transformers](ml-features.html): Details on transformers supported in the Pipelines API, including a few not in the lower-level `spark.mllib` API +* [Ensembles](ml-ensembles.html): Details on ensemble learning methods in the Pipelines API # Dependencies @@ -90,21 +106,14 @@ version 1.4 or newer. For the `spark.ml` package, please see the [spark.ml Migration Guide](ml-guide.html#migration-guide). -## From 1.2 to 1.3 - -In the `spark.mllib` package, there were several breaking changes. The first change (in `ALS`) is the only one in a component not marked as Alpha or Experimental. - -* *(Breaking change)* In [`ALS`](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS), the extraneous method `solveLeastSquares` has been removed. The `DeveloperApi` method `analyzeBlocks` was also removed. -* *(Breaking change)* [`StandardScalerModel`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScalerModel) remains an Alpha component. In it, the `variance` method has been replaced with the `std` method. To compute the column variance values returned by the original `variance` method, simply square the standard deviation values returned by `std`. -* *(Breaking change)* [`StreamingLinearRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD) remains an Experimental component. In it, there were two changes: - * The constructor taking arguments was removed in favor of a builder patten using the default constructor plus parameter setter methods. - * Variable `model` is no longer public. -* *(Breaking change)* [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) remains an Experimental component. In it and its associated classes, there were several changes: - * In `DecisionTree`, the deprecated class method `train` has been removed. (The object/static `train` methods remain.) - * In `Strategy`, the `checkpointDir` parameter has been removed. Checkpointing is still supported, but the checkpoint directory must be set before calling tree and tree ensemble training. -* `PythonMLlibAPI` (the interface between Scala/Java and Python for MLlib) was a public API but is now private, declared `private[python]`. This was never meant for external use. -* In linear regression (including Lasso and ridge regression), the squared loss is now divided by 2. - So in order to produce the same result as in 1.2, the regularization parameter needs to be divided by 2 and the step size needs to be multiplied by 2. +## From 1.3 to 1.4 + +In the `spark.mllib` package, there were several breaking changes, but all in `DeveloperApi` or `Experimental` APIs: + +* Gradient-Boosted Trees + * *(Breaking change)* The signature of the [`Loss.gradient`](api/scala/index.html#org.apache.spark.mllib.tree.loss.Loss) method was changed. This is only an issues for users who wrote their own losses for GBTs. + * *(Breaking change)* The `apply` and `copy` methods for the case class [`BoostingStrategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.BoostingStrategy) have been changed because of a modification to the case class fields. This could be an issue for users who use `BoostingStrategy` to set GBT parameters. +* *(Breaking change)* The return value of [`LDA.run`](api/scala/index.html#org.apache.spark.mllib.clustering.LDA) has changed. It now returns an abstract class `LDAModel` instead of the concrete class `DistributedLDAModel`. The object of type `LDAModel` can still be cast to the appropriate concrete type, which depends on the optimization algorithm. ## Previous Spark Versions diff --git a/docs/mllib-migration-guides.md b/docs/mllib-migration-guides.md index 4de2d9491ac2b..8df68d81f3c78 100644 --- a/docs/mllib-migration-guides.md +++ b/docs/mllib-migration-guides.md @@ -7,6 +7,22 @@ description: MLlib migration guides from before Spark SPARK_VERSION_SHORT The migration guide for the current Spark version is kept on the [MLlib Programming Guide main page](mllib-guide.html#migration-guide). +## From 1.2 to 1.3 + +In the `spark.mllib` package, there were several breaking changes. The first change (in `ALS`) is the only one in a component not marked as Alpha or Experimental. + +* *(Breaking change)* In [`ALS`](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS), the extraneous method `solveLeastSquares` has been removed. The `DeveloperApi` method `analyzeBlocks` was also removed. +* *(Breaking change)* [`StandardScalerModel`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScalerModel) remains an Alpha component. In it, the `variance` method has been replaced with the `std` method. To compute the column variance values returned by the original `variance` method, simply square the standard deviation values returned by `std`. +* *(Breaking change)* [`StreamingLinearRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD) remains an Experimental component. In it, there were two changes: + * The constructor taking arguments was removed in favor of a builder pattern using the default constructor plus parameter setter methods. + * Variable `model` is no longer public. +* *(Breaking change)* [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) remains an Experimental component. In it and its associated classes, there were several changes: + * In `DecisionTree`, the deprecated class method `train` has been removed. (The object/static `train` methods remain.) + * In `Strategy`, the `checkpointDir` parameter has been removed. Checkpointing is still supported, but the checkpoint directory must be set before calling tree and tree ensemble training. +* `PythonMLlibAPI` (the interface between Scala/Java and Python for MLlib) was a public API but is now private, declared `private[python]`. This was never meant for external use. +* In linear regression (including Lasso and ridge regression), the squared loss is now divided by 2. + So in order to produce the same result as in 1.2, the regularization parameter needs to be divided by 2 and the step size needs to be multiplied by 2. + ## From 1.1 to 1.2 The only API changes in MLlib v1.2 are in diff --git a/docs/programming-guide.md b/docs/programming-guide.md index d5ff416fe89a4..ae712d62746f6 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -1144,9 +1144,11 @@ generate these on the reduce side. When data does not fit in memory Spark will s to disk, incurring the additional overhead of disk I/O and increased garbage collection. Shuffle also generates a large number of intermediate files on disk. As of Spark 1.3, these files -are not cleaned up from Spark's temporary storage until Spark is stopped, which means that -long-running Spark jobs may consume available disk space. This is done so the shuffle doesn't need -to be re-computed if the lineage is re-computed. The temporary storage directory is specified by the +are preserved until the corresponding RDDs are no longer used and are garbage collected. +This is done so the shuffle files don't need to be re-created if the lineage is re-computed. +Garbage collection may happen only after a long period time, if the application retains references +to these RDDs or if GC does not kick in frequently. This means that long-running Spark jobs may +consume a large amount of disk space. The temporary storage directory is specified by the `spark.local.dir` configuration parameter when configuring the Spark context. Shuffle behavior can be tuned by adjusting a variety of configuration parameters. See the diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 96cf612c54fdd..de22ab557cacf 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -7,6 +7,51 @@ Support for running on [YARN (Hadoop NextGen)](http://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/YARN.html) was added to Spark in version 0.6.0, and improved in subsequent releases. +# Launching Spark on YARN + +Ensure that `HADOOP_CONF_DIR` or `YARN_CONF_DIR` points to the directory which contains the (client side) configuration files for the Hadoop cluster. +These configs are used to write to HDFS and connect to the YARN ResourceManager. The +configuration contained in this directory will be distributed to the YARN cluster so that all +containers used by the application use the same configuration. If the configuration references +Java system properties or environment variables not managed by YARN, they should also be set in the +Spark application's configuration (driver, executors, and the AM when running in client mode). + +There are two deploy modes that can be used to launch Spark applications on YARN. In `yarn-cluster` mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In `yarn-client` mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. + +Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the `--master` parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the `--master` parameter is `yarn-client` or `yarn-cluster`. +To launch a Spark application in `yarn-cluster` mode: + + `$ ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options]` + +For example: + + $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ + --master yarn-cluster \ + --num-executors 3 \ + --driver-memory 4g \ + --executor-memory 2g \ + --executor-cores 1 \ + --queue thequeue \ + lib/spark-examples*.jar \ + 10 + +The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. + +To launch a Spark application in `yarn-client` mode, do the same, but replace `yarn-cluster` with `yarn-client`. To run spark-shell: + + $ ./bin/spark-shell --master yarn-client + +## Adding Other JARs + +In `yarn-cluster` mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. + + $ ./bin/spark-submit --class my.main.Class \ + --master yarn-cluster \ + --jars my-other-jar.jar,my-other-other-jar.jar + my-main-jar.jar + app_arg1 app_arg2 + + # Preparations Running Spark-on-YARN requires a binary distribution of Spark which is built with YARN support. @@ -17,6 +62,38 @@ To build Spark yourself, refer to [Building Spark](building-spark.html). Most of the configs are the same for Spark on YARN as for other deployment modes. See the [configuration page](configuration.html) for more information on those. These are configs that are specific to Spark on YARN. +# Debugging your Application + +In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command. + + yarn logs -applicationId + +will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). + +When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. + +To review per-container launch environment, increase `yarn.nodemanager.delete.debug-delay-sec` to a +large value (e.g. 36000), and then access the application cache through `yarn.nodemanager.local-dirs` +on the nodes on which containers are launched. This directory contains the launch script, JARs, and +all environment variables used for launching each container. This process is useful for debugging +classpath problems in particular. (Note that enabling this requires admin privileges on cluster +settings and a restart of all node managers. Thus, this is not applicable to hosted clusters). + +To use a custom log4j configuration for the application master or executors, there are two options: + +- upload a custom `log4j.properties` using `spark-submit`, by adding it to the `--files` list of files + to be uploaded with the application. +- add `-Dlog4j.configuration=` to `spark.driver.extraJavaOptions` + (for the driver) or `spark.executor.extraJavaOptions` (for executors). Note that if using a file, + the `file:` protocol should be explicitly provided, and the file needs to exist locally on all + the nodes. + +Note that for the first option, both executors and the application master will share the same +log4j configuration, which may cause issues when they run on the same node (e.g. trying to write +to the same log file). + +If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your log4j.properties. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming application, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log file, and logs can be accessed using YARN's log utility. + #### Spark Properties @@ -50,8 +127,8 @@ Most of the configs are the same for Spark on YARN as for other deployment modes @@ -189,8 +266,8 @@ Most of the configs are the same for Spark on YARN as for other deployment modes @@ -206,7 +283,7 @@ Most of the configs are the same for Spark on YARN as for other deployment modes @@ -258,85 +335,34 @@ Most of the configs are the same for Spark on YARN as for other deployment modes Principal to be used to login to KDC, while running on secure HDFS. + + + + + + + + + +
spark.yarn.am.waitTime 100s - In yarn-cluster mode, time for the application master to wait for the - SparkContext to be initialized. In yarn-client mode, time for the application master to wait + In `yarn-cluster` mode, time for the application master to wait for the + SparkContext to be initialized. In `yarn-client` mode, time for the application master to wait for the driver to connect to it.
Add the environment variable specified by EnvironmentVariableName to the Application Master process launched on YARN. The user can specify multiple of - these and to set multiple environment variables. In yarn-cluster mode this controls - the environment of the SPARK driver and in yarn-client mode it only controls + these and to set multiple environment variables. In `yarn-cluster` mode this controls + the environment of the SPARK driver and in `yarn-client` mode it only controls the environment of the executor launcher.
(none) A string of extra JVM options to pass to the YARN Application Master in client mode. - In cluster mode, use spark.driver.extraJavaOptions instead. + In cluster mode, use `spark.driver.extraJavaOptions` instead.
spark.yarn.config.gatewayPath(none) + A path that is valid on the gateway host (the host where a Spark application is started) but may + differ for paths for the same resource in other nodes in the cluster. Coupled with + spark.yarn.config.replacementPath, this is used to support clusters with + heterogeneous configurations, so that Spark can correctly launch remote processes. +

+ The replacement path normally will contain a reference to some environment variable exported by + YARN (and, thus, visible to Spark containers). +

+ For example, if the gateway node has Hadoop libraries installed on /disk1/hadoop, and + the location of the Hadoop install is exported by YARN as the HADOOP_HOME + environment variable, setting this value to /disk1/hadoop and the replacement path to + $HADOOP_HOME will make sure that paths used to launch remote processes properly + reference the local YARN configuration. +

spark.yarn.config.replacementPath(none) + See spark.yarn.config.gatewayPath. +
-# Launching Spark on YARN - -Ensure that `HADOOP_CONF_DIR` or `YARN_CONF_DIR` points to the directory which contains the (client side) configuration files for the Hadoop cluster. -These configs are used to write to the dfs and connect to the YARN ResourceManager. The -configuration contained in this directory will be distributed to the YARN cluster so that all -containers used by the application use the same configuration. If the configuration references -Java system properties or environment variables not managed by YARN, they should also be set in the -Spark application's configuration (driver, executors, and the AM when running in client mode). - -There are two deploy modes that can be used to launch Spark applications on YARN. In yarn-cluster mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In yarn-client mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. - -Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the "master" parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the master parameter is simply "yarn-client" or "yarn-cluster". - -To launch a Spark application in yarn-cluster mode: - - ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options] - -For example: - - $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ - --master yarn-cluster \ - --num-executors 3 \ - --driver-memory 4g \ - --executor-memory 2g \ - --executor-cores 1 \ - --queue thequeue \ - lib/spark-examples*.jar \ - 10 - -The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. - -To launch a Spark application in yarn-client mode, do the same, but replace "yarn-cluster" with "yarn-client". To run spark-shell: - - $ ./bin/spark-shell --master yarn-client - -## Adding Other JARs - -In yarn-cluster mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. - - $ ./bin/spark-submit --class my.main.Class \ - --master yarn-cluster \ - --jars my-other-jar.jar,my-other-other-jar.jar - my-main-jar.jar - app_arg1 app_arg2 - -# Debugging your Application - -In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command. - - yarn logs -applicationId - -will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). - -When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. - -To review per-container launch environment, increase `yarn.nodemanager.delete.debug-delay-sec` to a -large value (e.g. 36000), and then access the application cache through `yarn.nodemanager.local-dirs` -on the nodes on which containers are launched. This directory contains the launch script, JARs, and -all environment variables used for launching each container. This process is useful for debugging -classpath problems in particular. (Note that enabling this requires admin privileges on cluster -settings and a restart of all node managers. Thus, this is not applicable to hosted clusters). - -To use a custom log4j configuration for the application master or executors, there are two options: - -- upload a custom `log4j.properties` using `spark-submit`, by adding it to the `--files` list of files - to be uploaded with the application. -- add `-Dlog4j.configuration=` to `spark.driver.extraJavaOptions` - (for the driver) or `spark.executor.extraJavaOptions` (for executors). Note that if using a file, - the `file:` protocol should be explicitly provided, and the file needs to exist locally on all - the nodes. - -Note that for the first option, both executors and the application master will share the same -log4j configuration, which may cause issues when they run on the same node (e.g. trying to write -to the same log file). - -If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your log4j.properties. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming application, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log file, and logs can be accessed using YARN's log utility. - # Important notes - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 12d7d6e159bea..4f71fbc086cd0 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -24,7 +24,7 @@ the master's web UI, which is [http://localhost:8080](http://localhost:8080) by Similarly, you can start one or more workers and connect them to the master via: - ./sbin/start-slave.sh + ./sbin/start-slave.sh Once you have started a worker, look at the master's web UI ([http://localhost:8080](http://localhost:8080) by default). You should see the new node listed there, along with its number of CPUs and memory (minus one gigabyte left for the OS). diff --git a/docs/sparkr.md b/docs/sparkr.md index 4d82129921a37..095ea4308cfeb 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -27,9 +27,9 @@ All of the examples on this page use sample data included in R or the Spark dist
The entry point into SparkR is the `SparkContext` which connects your R program to a Spark cluster. You can create a `SparkContext` using `sparkR.init` and pass in options such as the application name -etc. Further, to work with DataFrames we will need a `SQLContext`, which can be created from the -SparkContext. If you are working from the SparkR shell, the `SQLContext` and `SparkContext` should -already be created for you. +, any spark packages depended on, etc. Further, to work with DataFrames we will need a `SQLContext`, +which can be created from the SparkContext. If you are working from the SparkR shell, the +`SQLContext` and `SparkContext` should already be created for you. {% highlight r %} sc <- sparkR.init() @@ -62,7 +62,16 @@ head(df) SparkR supports operating on a variety of data sources through the `DataFrame` interface. This section describes the general methods for loading and saving data using Data Sources. You can check the Spark SQL programming guide for more [specific options](sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. -The general method for creating DataFrames from data sources is `read.df`. This method takes in the `SQLContext`, the path for the file to load and the type of data source. SparkR supports reading JSON and Parquet files natively and through [Spark Packages](http://spark-packages.org/) you can find data source connectors for popular file formats like [CSV](http://spark-packages.org/package/databricks/spark-csv) and [Avro](http://spark-packages.org/package/databricks/spark-avro). +The general method for creating DataFrames from data sources is `read.df`. This method takes in the `SQLContext`, the path for the file to load and the type of data source. SparkR supports reading JSON and Parquet files natively and through [Spark Packages](http://spark-packages.org/) you can find data source connectors for popular file formats like [CSV](http://spark-packages.org/package/databricks/spark-csv) and [Avro](http://spark-packages.org/package/databricks/spark-avro). These packages can either be added by +specifying `--packages` with `spark-submit` or `sparkR` commands, or if creating context through `init` +you can specify the packages with the `packages` argument. + +
+{% highlight r %} +sc <- sparkR.init(packages="com.databricks:spark-csv_2.11:1.0.3") +sqlContext <- sparkRSQL.init(sc) +{% endhighlight %} +
We can see how to use data sources using an example JSON input file. Note that the file that is used here is _not_ a typical JSON file. Each line in the file must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index cde5830c733e0..2786e3d2cd6bf 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -22,7 +22,7 @@ The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark. All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`, `pyspark` shell, or `sparkR` shell. -## Starting Point: `SQLContext` +## Starting Point: SQLContext
@@ -819,8 +819,8 @@ saveDF(select(df, "name", "age"), "namesAndAges.parquet") You can also manually specify the data source that will be used along with any extra options that you would like to pass to the data source. Data sources are specified by their fully qualified -name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can also use the shorted -name (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types +name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can also use their short +names (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types using this syntax.
@@ -828,7 +828,7 @@ using this syntax. {% highlight scala %} val df = sqlContext.read.format("json").load("examples/src/main/resources/people.json") -df.select("name", "age").write.format("json").save("namesAndAges.parquet") +df.select("name", "age").write.format("json").save("namesAndAges.json") {% endhighlight %}
@@ -975,7 +975,7 @@ schemaPeople.write().parquet("people.parquet"); // The result of loading a parquet file is also a DataFrame. DataFrame parquetFile = sqlContext.read().parquet("people.parquet"); -//Parquet files can also be registered as tables and then used in SQL statements. +// Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); DataFrame teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); List teenagerNames = teenagers.javaRDD().map(new Function() { @@ -995,11 +995,11 @@ List teenagerNames = teenagers.javaRDD().map(new Function() schemaPeople # The DataFrame from the previous example. # DataFrames can be saved as Parquet files, maintaining the schema information. -schemaPeople.read.parquet("people.parquet") +schemaPeople.write.parquet("people.parquet") # Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. # The result of loading a parquet file is also a DataFrame. -parquetFile = sqlContext.write.parquet("people.parquet") +parquetFile = sqlContext.read.parquet("people.parquet") # Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); @@ -1036,6 +1036,15 @@ for (teenName in collect(teenNames)) {
+
+ +{% highlight python %} +# sqlContext is an existing HiveContext +sqlContext.sql("REFRESH TABLE my_table") +{% endhighlight %} + +
+
{% highlight sql %} @@ -1054,12 +1063,12 @@ SELECT * FROM parquetTable
-### Partition discovery +### Partition Discovery Table partitioning is a common optimization approach used in systems like Hive. In a partitioned table, data are usually stored in different directories, with partitioning column values encoded in the path of each partition directory. The Parquet data source is now able to discover and infer -partitioning information automatically. For exmaple, we can store all our previously used +partitioning information automatically. For example, we can store all our previously used population data into a partitioned table using the following directory structure, with two extra columns, `gender` and `country` as partitioning columns: @@ -1102,9 +1111,13 @@ root {% endhighlight %} Notice that the data types of the partitioning columns are automatically inferred. Currently, -numeric data types and string type are supported. +numeric data types and string type are supported. Sometimes users may not want to automatically +infer the data types of the partitioning columns. For these use cases, the automatic type inference +can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, which is default to +`true`. When type inference is disabled, string type will be used for the partitioning columns. -### Schema merging + +### Schema Merging Like ProtocolBuffer, Avro, and Thrift, Parquet also supports schema evolution. Users can start with a simple schema, and gradually add more columns to the schema as needed. In this way, users may end @@ -1121,12 +1134,12 @@ source is now able to automatically detect this case and merge schemas of all th import sqlContext.implicits._ // Create a simple DataFrame, stored into a partition directory -val df1 = sparkContext.makeRDD(1 to 5).map(i => (i, i * 2)).toDF("single", "double") +val df1 = sc.makeRDD(1 to 5).map(i => (i, i * 2)).toDF("single", "double") df1.write.parquet("data/test_table/key=1") // Create another DataFrame in a new partition directory, // adding a new column and dropping an existing column -val df2 = sparkContext.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple") +val df2 = sc.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple") df2.write.parquet("data/test_table/key=2") // Read the partitioned table @@ -1134,7 +1147,7 @@ val df3 = sqlContext.read.parquet("data/test_table") df3.printSchema() // The final schema consists of all 3 columns in the Parquet files together -// with the partiioning column appeared in the partition directory paths. +// with the partitioning column appeared in the partition directory paths. // root // |-- single: int (nullable = true) // |-- double: int (nullable = true) @@ -1165,7 +1178,7 @@ df3 = sqlContext.load("data/test_table", "parquet") df3.printSchema() # The final schema consists of all 3 columns in the Parquet files together -# with the partiioning column appeared in the partition directory paths. +# with the partitioning column appeared in the partition directory paths. # root # |-- single: int (nullable = true) # |-- double: int (nullable = true) @@ -1192,7 +1205,7 @@ df3 <- loadDF(sqlContext, "data/test_table", "parquet") printSchema(df3) # The final schema consists of all 3 columns in the Parquet files together -# with the partiioning column appeared in the partition directory paths. +# with the partitioning column appeared in the partition directory paths. # root # |-- single: int (nullable = true) # |-- double: int (nullable = true) @@ -1204,6 +1217,79 @@ printSchema(df3)
+### Hive metastore Parquet table conversion + +When reading from and writing to Hive metastore Parquet tables, Spark SQL will try to use its own +Parquet support instead of Hive SerDe for better performance. This behavior is controlled by the +`spark.sql.hive.convertMetastoreParquet` configuration, and is turned on by default. + +#### Hive/Parquet Schema Reconciliation + +There are two key differences between Hive and Parquet from the perspective of table schema +processing. + +1. Hive is case insensitive, while Parquet is not +1. Hive considers all columns nullable, while nullability in Parquet is significant + +Due to this reason, we must reconcile Hive metastore schema with Parquet schema when converting a +Hive metastore Parquet table to a Spark SQL Parquet table. The reconciliation rules are: + +1. Fields that have the same name in both schema must have the same data type regardless of + nullability. The reconciled field should have the data type of the Parquet side, so that + nullability is respected. + +1. The reconciled schema contains exactly those fields defined in Hive metastore schema. + + - Any fields that only appear in the Parquet schema are dropped in the reconciled schema. + - Any fileds that only appear in the Hive metastore schema are added as nullable field in the + reconciled schema. + +#### Metadata Refreshing + +Spark SQL caches Parquet metadata for better performance. When Hive metastore Parquet table +conversion is enabled, metadata of those converted tables are also cached. If these tables are +updated by Hive or other external tools, you need to refresh them manually to ensure consistent +metadata. + +
+ +
+ +{% highlight scala %} +// sqlContext is an existing HiveContext +sqlContext.refreshTable("my_table") +{% endhighlight %} + +
+ +
+ +{% highlight java %} +// sqlContext is an existing HiveContext +sqlContext.refreshTable("my_table") +{% endhighlight %} + +
+ +
+ +{% highlight python %} +# sqlContext is an existing HiveContext +sqlContext.refreshTable("my_table") +{% endhighlight %} + +
+ +
+ +{% highlight sql %} +REFRESH TABLE my_table; +{% endhighlight %} + +
+ +
+ ### Configuration Configuration of Parquet can be done using the `setConf` method on `SQLContext` or by running @@ -1216,7 +1302,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` false Some other Parquet-producing systems, in particular Impala and older versions of Spark SQL, do - not differentiate between binary data and strings when writing out the Parquet schema. This + not differentiate between binary data and strings when writing out the Parquet schema. This flag tells Spark SQL to interpret binary data as a string to provide compatibility with these systems. @@ -1233,7 +1319,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` spark.sql.parquet.cacheMetadata true - Turns on caching of Parquet schema metadata. Can speed up querying of static data. + Turns on caching of Parquet schema metadata. Can speed up querying of static data. @@ -1249,7 +1335,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` false Turn on Parquet filter pushdown optimization. This feature is turned off by default because of a known - bug in Paruet 1.6.0rc3 (PARQUET-136). + bug in Parquet 1.6.0rc3 (PARQUET-136). However, if your table doesn't contain any nullable string or binary columns, it's still safe to turn this feature on. @@ -1262,6 +1348,34 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` support. + + spark.sql.parquet.output.committer.class + org.apache.parquet.hadoop.
ParquetOutputCommitter
+ +

+ The output committer class used by Parquet. The specified class needs to be a subclass of + org.apache.hadoop.
mapreduce.OutputCommitter
. Typically, it's also a + subclass of org.apache.parquet.hadoop.ParquetOutputCommitter. +

+

+ Note: +

    +
  • + This option must be set via Hadoop Configuration rather than Spark + SQLConf. +
  • +
  • + This option overrides spark.sql.sources.
    outputCommitterClass
    . +
  • +
+

+

+ Spark SQL comes with a builtin + org.apache.spark.sql.
parquet.DirectParquetOutputCommitter
, which can be more + efficient then the default Parquet output committer when writing data to S3. +

+ + ## JSON Datasets @@ -1398,7 +1512,7 @@ sqlContext <- sparkRSQL.init(sc) # The path can be either a single text file or a directory storing text files. path <- "examples/src/main/resources/people.json" # Create a DataFrame from the file(s) pointed to by path -people <- jsonFile(sqlContex,t path) +people <- jsonFile(sqlContext, path) # The inferred schema can be visualized using the printSchema() method. printSchema(people) @@ -1441,7 +1555,12 @@ This command builds a new assembly jar that includes Hive. Note that this Hive a on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries (SerDes) in order to access data stored in Hive. -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. Please note when running +the query on a YARN cluster (`yarn-cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory +and `hive-site.xml` under `conf/` directory need to be available on the driver and all executors launched by the +YARN cluster. The convenient way to do this is adding them through the `--jars` option and `--file` option of the +`spark-submit` command. +
@@ -1470,12 +1589,12 @@ sqlContext.sql("FROM src SELECT key, value").collect().foreach(println) When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and adds support for finding tables in the MetaStore and writing queries using HiveQL. In addition to -the `sql` method a `HiveContext` also provides an `hql` methods, which allows queries to be +the `sql` method a `HiveContext` also provides an `hql` method, which allows queries to be expressed in HiveQL. {% highlight java %} // sc is an existing JavaSparkContext. -HiveContext sqlContext = new org.apache.spark.sql.hive.HiveContext(sc); +HiveContext sqlContext = new org.apache.spark.sql.hive.HiveContext(sc.sc); sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); @@ -1785,7 +1904,7 @@ that these options will be deprecated in future release as more optimizations ar Configures the number of partitions to use when shuffling data for joins or aggregations. - + spark.sql.planner.externalSort false @@ -1812,7 +1931,7 @@ To start the JDBC/ODBC server, run the following in the Spark directory: This script accepts all `bin/spark-submit` command line options, plus a `--hiveconf` option to specify Hive properties. You may run `./sbin/start-thriftserver.sh --help` for a complete list of all available options. By default, the server listens on localhost:10000. You may override this -bahaviour via either environment variables, i.e.: +behaviour via either environment variables, i.e.: {% highlight bash %} export HIVE_SERVER2_THRIFT_PORT= @@ -1880,7 +1999,7 @@ options. #### DataFrame data reader/writer interface Based on user feedback, we created a new, more fluid API for reading data in (`SQLContext.read`) -and writing data out (`DataFrame.write`), +and writing data out (`DataFrame.write`), and deprecated the old APIs (e.g. `SQLContext.parquetFile`, `SQLContext.jsonFile`). See the API docs for `SQLContext.read` ( @@ -2766,7 +2885,7 @@ from pyspark.sql.types import * MapType - enviroment + environment list(type="map", keyType=keyType, valueType=valueType, valueContainsNull=[valueContainsNull])
Note: The default value of valueContainsNull is True. diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md index 6a2048121f8bf..a75587a92adc7 100644 --- a/docs/streaming-custom-receivers.md +++ b/docs/streaming-custom-receivers.md @@ -4,7 +4,7 @@ title: Spark Streaming Custom Receivers --- Spark Streaming can receive streaming data from any arbitrary data source beyond -the one's for which it has in-built support (that is, beyond Flume, Kafka, Kinesis, files, sockets, etc.). +the ones for which it has built-in support (that is, beyond Flume, Kafka, Kinesis, files, sockets, etc.). This requires the developer to implement a *receiver* that is customized for receiving data from the concerned data source. This guide walks through the process of implementing a custom receiver and using it in a Spark Streaming application. Note that custom receivers can be implemented @@ -21,15 +21,15 @@ A custom receiver must extend this abstract class by implementing two methods - `onStop()`: Things to do to stop receiving data. Both `onStart()` and `onStop()` must not block indefinitely. Typically, `onStart()` would start the threads -that responsible for receiving the data and `onStop()` would ensure that the receiving by those threads +that are responsible for receiving the data, and `onStop()` would ensure that these threads receiving the data are stopped. The receiving threads can also use `isStopped()`, a `Receiver` method, to check whether they should stop receiving data. Once the data is received, that data can be stored inside Spark by calling `store(data)`, which is a method provided by the Receiver class. -There are number of flavours of `store()` which allow you store the received data -record-at-a-time or as whole collection of objects / serialized bytes. Note that the flavour of -`store()` used to implemented a receiver affects its reliability and fault-tolerance semantics. +There are a number of flavors of `store()` which allow one to store the received data +record-at-a-time or as whole collection of objects / serialized bytes. Note that the flavor of +`store()` used to implement a receiver affects its reliability and fault-tolerance semantics. This is discussed [later](#receiver-reliability) in more detail. Any exception in the receiving threads should be caught and handled properly to avoid silent @@ -60,7 +60,7 @@ class CustomReceiver(host: String, port: Int) def onStop() { // There is nothing much to do as the thread calling receive() - // is designed to stop by itself isStopped() returns false + // is designed to stop by itself if isStopped() returns false } /** Create a socket connection and receive data until receiver is stopped */ @@ -123,7 +123,7 @@ public class JavaCustomReceiver extends Receiver { public void onStop() { // There is nothing much to do as the thread calling receive() - // is designed to stop by itself isStopped() returns false + // is designed to stop by itself if isStopped() returns false } /** Create a socket connection and receive data until receiver is stopped */ @@ -167,7 +167,7 @@ public class JavaCustomReceiver extends Receiver { The custom receiver can be used in a Spark Streaming application by using `streamingContext.receiverStream()`. This will create -input DStream using data received by the instance of custom receiver, as shown below +an input DStream using data received by the instance of custom receiver, as shown below:
@@ -206,22 +206,20 @@ there are two kinds of receivers based on their reliability and fault-tolerance and stored in Spark reliably (that is, replicated successfully). Usually, implementing this receiver involves careful consideration of the semantics of source acknowledgements. -1. *Unreliable Receiver* - These are receivers for unreliable sources that do not support - acknowledging. Even for reliable sources, one may implement an unreliable receiver that - do not go into the complexity of acknowledging correctly. +1. *Unreliable Receiver* - An *unreliable receiver* does *not* send acknowledgement to a source. This can be used for sources that do not support acknowledgement, or even for reliable sources when one does not want or need to go into the complexity of acknowledgement. To implement a *reliable receiver*, you have to use `store(multiple-records)` to store data. -This flavour of `store` is a blocking call which returns only after all the given records have +This flavor of `store` is a blocking call which returns only after all the given records have been stored inside Spark. If the receiver's configured storage level uses replication (enabled by default), then this call returns after replication has completed. Thus it ensures that the data is reliably stored, and the receiver can now acknowledge the -source appropriately. This ensures that no data is caused when the receiver fails in the middle +source appropriately. This ensures that no data is lost when the receiver fails in the middle of replicating data -- the buffered data will not be acknowledged and hence will be later resent by the source. An *unreliable receiver* does not have to implement any of this logic. It can simply receive records from the source and insert them one-at-a-time using `store(single-record)`. While it does -not get the reliability guarantees of `store(multiple-records)`, it has the following advantages. +not get the reliability guarantees of `store(multiple-records)`, it has the following advantages: - The system takes care of chunking that data into appropriate sized blocks (look for block interval in the [Spark Streaming Programming Guide](streaming-programming-guide.html)). diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index c8ab146bcae0a..8d6e74370918f 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -99,6 +99,12 @@ Configuring Flume on the chosen machine requires the following two steps. artifactId = scala-library version = {{site.SCALA_VERSION}} + (iii) *Commons Lang 3 JAR*: Download the Commons Lang 3 JAR. It can be found with the following artifact detail (or, [direct link](http://search.maven.org/remotecontent?filepath=org/apache/commons/commons-lang3/3.3.2/commons-lang3-3.3.2.jar)). + + groupId = org.apache.commons + artifactId = commons-lang3 + version = 3.3.2 + 2. **Configuration file**: On that machine, configure Flume agent to send data to an Avro sink by having the following in the configuration file. agent.sinks = spark diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index d6d5605948a5a..775d508d4879b 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -2,12 +2,12 @@ layout: global title: Spark Streaming + Kafka Integration Guide --- -[Apache Kafka](http://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Here we explain how to configure Spark Streaming to receive data from Kafka. There are two approaches to this - the old approach using Receivers and Kafka's high-level API, and a new experimental approach (introduced in Spark 1.3) without using Receivers. They have different programming models, performance characteristics, and semantics guarantees, so read on for more details. +[Apache Kafka](http://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Here we explain how to configure Spark Streaming to receive data from Kafka. There are two approaches to this - the old approach using Receivers and Kafka's high-level API, and a new experimental approach (introduced in Spark 1.3) without using Receivers. They have different programming models, performance characteristics, and semantics guarantees, so read on for more details. ## Approach 1: Receiver-based Approach This approach uses a Receiver to receive the data. The Received is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming processes the data. -However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming. To ensure zero data loss, enable the Write Ahead Logs (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs. +However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs. Next, we discuss how to use this approach in your streaming application. @@ -74,15 +74,15 @@ Next, we discuss how to use this approach in your streaming application. [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kafka-assembly_2.10%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. ## Approach 2: Direct Approach (No Receivers) -This is a new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature in Spark 1.3 and is only available in the Scala and Java API. +This new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature introduced in Spark 1.3 for the Scala and Java API. Spark 1.4 added a Python API, but it is not yet at full feature parity. -This approach has the following advantages over the received-based approach (i.e. Approach 1). +This approach has the following advantages over the receiver-based approach (i.e. Approach 1). -- *Simplified Parallelism:* No need to create multiple input Kafka streams and union-ing them. With `directStream`, Spark Streaming will create as many RDD partitions as there is Kafka partitions to consume, which will all read data from Kafka in parallel. So there is one-to-one mapping between Kafka and RDD partitions, which is easier to understand and tune. +- *Simplified Parallelism:* No need to create multiple input Kafka streams and union them. With `directStream`, Spark Streaming will create as many RDD partitions as there are Kafka partitions to consume, which will all read data from Kafka in parallel. So there is a one-to-one mapping between Kafka and RDD partitions, which is easier to understand and tune. -- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminate the problem as there is no receiver, and hence no need for Write Ahead Logs. +- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka. -- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper and offsets tracked only by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. +- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semanitcs of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). Note that one disadvantage of this approach is that it does not update offsets in Zookeeper, hence Zookeeper-based Kafka monitoring tools will not show progress. However, you can access the offsets processed by this approach in each batch and update Zookeeper yourself (see below). @@ -118,6 +118,13 @@ Next, we discuss how to use this approach in your streaming application. See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java). +
+
+ from pyspark.streaming.kafka import KafkaUtils + directKafkaStream = KafkaUtils.createDirectStream(ssc, [topic], {"metadata.broker.list": brokers}) + + By default, the Python API will decode Kafka data as UTF8 encoded strings. You can specify your custom decoding function to decode the byte arrays in Kafka records to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/direct_kafka_wordcount.py).
@@ -128,29 +135,60 @@ Next, we discuss how to use this approach in your streaming application.
- directKafkaStream.foreachRDD { rdd => - val offsetRanges = rdd.asInstanceOf[HasOffsetRanges] - // offsetRanges.length = # of Kafka partitions being consumed - ... + // Hold a reference to the current offset ranges, so it can be used downstream + var offsetRanges = Array[OffsetRange]() + + directKafkaStream.transform { rdd => + offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + rdd + }.map { + ... + }.foreachRDD { rdd => + for (o <- offsetRanges) { + println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + } + ... }
- directKafkaStream.foreachRDD( - new Function, Void>() { - @Override - public Void call(JavaPairRDD rdd) throws IOException { - OffsetRange[] offsetRanges = ((HasOffsetRanges)rdd).offsetRanges - // offsetRanges.length = # of Kafka partitions being consumed - ... - return null; - } + // Hold a reference to the current offset ranges, so it can be used downstream + final AtomicReference offsetRanges = new AtomicReference(); + + directKafkaStream.transformToPair( + new Function, JavaPairRDD>() { + @Override + public JavaPairRDD call(JavaPairRDD rdd) throws Exception { + OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + offsetRanges.set(offsets); + return rdd; + } + } + ).map( + ... + ).foreachRDD( + new Function, Void>() { + @Override + public Void call(JavaPairRDD rdd) throws IOException { + for (OffsetRange o : offsetRanges.get()) { + System.out.println( + o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset() + ); + } + ... + return null; } + } );
+
+ Not supported yet
+
You can use this to update Zookeeper yourself if you want Zookeeper-based Kafka monitoring tools to show progress of the streaming application. - Another thing to note is that since this approach does not use Receivers, the standard receiver-related (that is, [configurations](configuration.html) of the form `spark.streaming.receiver.*` ) will not apply to the input DStreams created by this approach (will apply to other input DStreams though). Instead, use the [configurations](configuration.html) `spark.streaming.kafka.*`. An important one is `spark.streaming.kafka.maxRatePerPartition` which is the maximum rate at which each Kafka partition will be read by this direct API. + Note that the typecast to HasOffsetRanges will only succeed if it is done in the first method called on the directKafkaStream, not later down a chain of methods. You can use transform() instead of foreachRDD() as your first method call in order to access offsets, then call further Spark methods. However, be aware that the one-to-one mapping between RDD partition and Kafka partition does not remain after any methods that shuffle or repartition, e.g. reduceByKey() or window(). + + Another thing to note is that since this approach does not use Receivers, the standard receiver-related (that is, [configurations](configuration.html) of the form `spark.streaming.receiver.*` ) will not apply to the input DStreams created by this approach (will apply to other input DStreams though). Instead, use the [configurations](configuration.html) `spark.streaming.kafka.*`. An important one is `spark.streaming.kafka.maxRatePerPartition` which is the maximum rate (in messages per second) at which each Kafka partition will be read by this direct API. -3. **Deploying:** Similar to the first approach, you can package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR and the launch the application using `spark-submit`. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. +3. **Deploying:** This is same as the first approach, for Scala, Java and Python. diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index 379eb513d521e..aa9749afbc867 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -32,7 +32,8 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream val kinesisStream = KinesisUtils.createStream( - streamingContext, [Kinesis stream name], [endpoint URL], [checkpoint interval], [initial position]) + streamingContext, [Kinesis app name], [Kinesis stream name], [endpoint URL], + [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2) See the [API docs](api/scala/index.html#org.apache.spark.streaming.kinesis.KinesisUtils$) and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala). Refer to the Running the Example section for instructions on how to run the example. @@ -44,7 +45,8 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; JavaReceiverInputDStream kinesisStream = KinesisUtils.createStream( - streamingContext, [Kinesis stream name], [endpoint URL], [checkpoint interval], [initial position]); + streamingContext, [Kinesis app name], [Kinesis stream name], [endpoint URL], + [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2); See the [API docs](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html) and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java). Refer to the next subsection for instructions to run the example. @@ -54,19 +56,23 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m - `streamingContext`: StreamingContext containg an application name used by Kinesis to tie this Kinesis application to the Kinesis stream - - `[Kinesis stream name]`: The Kinesis stream that this streaming application receives from - - The application name used in the streaming context becomes the Kinesis application name + - `[Kineiss app name]`: The application name that will be used to checkpoint the Kinesis + sequence numbers in DynamoDB table. - The application name must be unique for a given account and region. - - The Kinesis backend automatically associates the application name to the Kinesis stream using a DynamoDB table (always in the us-east-1 region) created during Kinesis Client Library initialization. - - Changing the application name or stream name can lead to Kinesis errors in some cases. If you see errors, you may need to manually delete the DynamoDB table. + - If the table exists but has incorrect checkpoint information (for a different stream, or + old expired sequenced numbers), then there may be temporary errors. + - `[Kinesis stream name]`: The Kinesis stream that this streaming application will pull data from. - `[endpoint URL]`: Valid Kinesis endpoints URL can be found [here](http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region). + - `[region name]`: Valid Kinesis region names can be found [here](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/using-regions-availability-zones.html). + - `[checkpoint interval]`: The interval (e.g., Duration(2000) = 2 seconds) at which the Kinesis Client Library saves its position in the stream. For starters, set it to the same as the batch interval of the streaming application. - `[initial position]`: Can be either `InitialPositionInStream.TRIM_HORIZON` or `InitialPositionInStream.LATEST` (see Kinesis Checkpointing section and Amazon Kinesis API documentation for more details). + In other versions of the API, you can also specify the AWS access key and secret key directly. 3. **Deploying:** Package `spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). @@ -122,12 +128,12 @@ To run the example,
- bin/run-example streaming.KinesisWordCountASL [Kinesis stream name] [endpoint URL] + bin/run-example streaming.KinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL]
- bin/run-example streaming.JavaKinesisWordCountASL [Kinesis stream name] [endpoint URL] + bin/run-example streaming.JavaKinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL]
@@ -136,7 +142,7 @@ To run the example, - To generate random string data to put onto the Kinesis stream, in another terminal, run the associated Kinesis data producer. - bin/run-example streaming.KinesisWordCountProducerASL [Kinesis stream name] [endpoint URL] 1000 10 + bin/run-example streaming.KinesisWordProducerASL [Kinesis stream name] [endpoint URL] 1000 10 This will push 1000 lines per second of 10 random numbers per line to the Kinesis stream. This data should then be received and processed by the running example. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 42b33947873b0..b784d59666fec 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -11,7 +11,7 @@ description: Spark Streaming programming guide and tutorial for Spark SPARK_VERS # Overview Spark Streaming is an extension of the core Spark API that enables scalable, high-throughput, fault-tolerant stream processing of live data streams. Data can be ingested from many sources -like Kafka, Flume, Twitter, ZeroMQ, Kinesis or TCP sockets can be processed using complex +like Kafka, Flume, Twitter, ZeroMQ, Kinesis, or TCP sockets, and can be processed using complex algorithms expressed with high-level functions like `map`, `reduce`, `join` and `window`. Finally, processed data can be pushed out to filesystems, databases, and live dashboards. In fact, you can apply Spark's @@ -52,7 +52,7 @@ different languages. **Note:** Python API for Spark Streaming has been introduced in Spark 1.2. It has all the DStream transformations and almost all the output operations available in Scala and Java interfaces. -However, it has only support for basic sources like text files and text data over sockets. +However, it only has support for basic sources like text files and text data over sockets. APIs for additional sources, like Kafka and Flume, will be available in the future. Further information about available features in the Python API are mentioned throughout this document; look out for the tag @@ -69,15 +69,15 @@ do is as follows.
-First, we import the names of the Spark Streaming classes, and some implicit -conversions from StreamingContext into our environment, to add useful methods to +First, we import the names of the Spark Streaming classes and some implicit +conversions from StreamingContext into our environment in order to add useful methods to other classes we need (like DStream). [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) is the -main entry point for all streaming functionality. We create a local StreamingContext with two execution threads, and batch interval of 1 second. +main entry point for all streaming functionality. We create a local StreamingContext with two execution threads, and a batch interval of 1 second. {% highlight scala %} import org.apache.spark._ import org.apache.spark.streaming._ -import org.apache.spark.streaming.StreamingContext._ // not necessary in Spark 1.3+ +import org.apache.spark.streaming.StreamingContext._ // not necessary since Spark 1.3 // Create a local StreamingContext with two working thread and batch interval of 1 second. // The master requires 2 cores to prevent from a starvation scenario. @@ -96,7 +96,7 @@ val lines = ssc.socketTextStream("localhost", 9999) This `lines` DStream represents the stream of data that will be received from the data server. Each record in this DStream is a line of text. Next, we want to split the lines by -space into words. +space characters into words. {% highlight scala %} // Split each line into words @@ -109,7 +109,7 @@ each line will be split into multiple words and the stream of words is represent `words` DStream. Next, we want to count these words. {% highlight scala %} -import org.apache.spark.streaming.StreamingContext._ // not necessary in Spark 1.3+ +import org.apache.spark.streaming.StreamingContext._ // not necessary since Spark 1.3 // Count each word in each batch val pairs = words.map(word => (word, 1)) val wordCounts = pairs.reduceByKey(_ + _) @@ -463,7 +463,7 @@ receive it there. However, for local testing and unit tests, you can pass "local in-process (detects the number of cores in the local system). Note that this internally creates a [SparkContext](api/scala/index.html#org.apache.spark.SparkContext) (starting point of all Spark functionality) which can be accessed as `ssc.sparkContext`. The batch interval must be set based on the latency requirements of your application -and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size) +and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-interval) section for more details. A `StreamingContext` object can also be created from an existing `SparkContext` object. @@ -498,7 +498,7 @@ receive it there. However, for local testing and unit tests, you can pass "local in-process. Note that this internally creates a [JavaSparkContext](api/java/index.html?org/apache/spark/api/java/JavaSparkContext.html) (starting point of all Spark functionality) which can be accessed as `ssc.sparkContext`. The batch interval must be set based on the latency requirements of your application -and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size) +and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-interval) section for more details. A `JavaStreamingContext` object can also be created from an existing `JavaSparkContext`. @@ -531,7 +531,7 @@ receive it there. However, for local testing and unit tests, you can pass "local in-process (detects the number of cores in the local system). The batch interval must be set based on the latency requirements of your application -and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size) +and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-interval) section for more details.
@@ -549,7 +549,7 @@ After a context is defined, you have to do the following. - Once a context has been started, no new streaming computations can be set up or added to it. - Once a context has been stopped, it cannot be restarted. - Only one StreamingContext can be active in a JVM at the same time. -- stop() on StreamingContext also stops the SparkContext. To stop only the StreamingContext, set optional parameter of `stop()` called `stopSparkContext` to false. +- stop() on StreamingContext also stops the SparkContext. To stop only the StreamingContext, set the optional parameter of `stop()` called `stopSparkContext` to false. - A SparkContext can be re-used to create multiple StreamingContexts, as long as the previous StreamingContext is stopped (without stopping the SparkContext) before the next StreamingContext is created. *** @@ -583,7 +583,7 @@ the `flatMap` operation is applied on each RDD in the `lines` DStream to generat These underlying RDD transformations are computed by the Spark engine. The DStream operations -hide most of these details and provide the developer with higher-level API for convenience. +hide most of these details and provide the developer with a higher-level API for convenience. These operations are discussed in detail in later sections. *** @@ -600,7 +600,7 @@ data from a source and stores it in Spark's memory for processing. Spark Streaming provides two categories of built-in streaming sources. - *Basic sources*: Sources directly available in the StreamingContext API. - Example: file systems, socket connections, and Akka actors. + Examples: file systems, socket connections, and Akka actors. - *Advanced sources*: Sources like Kafka, Flume, Kinesis, Twitter, etc. are available through extra utility classes. These require linking against extra dependencies as discussed in the [linking](#linking) section. @@ -610,11 +610,11 @@ We are going to discuss some of the sources present in each category later in th Note that, if you want to receive multiple streams of data in parallel in your streaming application, you can create multiple input DStreams (discussed further in the [Performance Tuning](#level-of-parallelism-in-data-receiving) section). This will -create multiple receivers which will simultaneously receive multiple data streams. But note that -Spark worker/executor as a long-running task, hence it occupies one of the cores allocated to the -Spark Streaming application. Hence, it is important to remember that Spark Streaming application +create multiple receivers which will simultaneously receive multiple data streams. But note that a +Spark worker/executor is a long-running task, hence it occupies one of the cores allocated to the +Spark Streaming application. Therefore, it is important to remember that a Spark Streaming application needs to be allocated enough cores (or threads, if running locally) to process the received data, -as well as, to run the receiver(s). +as well as to run the receiver(s). ##### Points to remember {:.no_toc} @@ -623,13 +623,13 @@ as well as, to run the receiver(s). Either of these means that only one thread will be used for running tasks locally. If you are using a input DStream based on a receiver (e.g. sockets, Kafka, Flume, etc.), then the single thread will be used to run the receiver, leaving no thread for processing the received data. Hence, when - running locally, always use "local[*n*]" as the master URL where *n* > number of receivers to run - (see [Spark Properties](configuration.html#spark-properties.html) for information on how to set + running locally, always use "local[*n*]" as the master URL, where *n* > number of receivers to run + (see [Spark Properties](configuration.html#spark-properties) for information on how to set the master). - Extending the logic to running on a cluster, the number of cores allocated to the Spark Streaming - application must be more than the number of receivers. Otherwise the system will receive data, but - not be able to process them. + application must be more than the number of receivers. Otherwise the system will receive data, but + not be able to process it. ### Basic Sources {:.no_toc} @@ -639,7 +639,7 @@ which creates a DStream from text data received over a TCP socket connection. Besides sockets, the StreamingContext API provides methods for creating DStreams from files and Akka actors as input sources. -- **File Streams:** For reading data from files on any file system compatible with the HDFS API (that is, HDFS, S3, NFS, etc.), a DStream can be created as +- **File Streams:** For reading data from files on any file system compatible with the HDFS API (that is, HDFS, S3, NFS, etc.), a DStream can be created as:
@@ -682,14 +682,14 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea ### Advanced Sources {:.no_toc} -Python API As of Spark 1.3, +Python API As of Spark {{site.SPARK_VERSION_SHORT}}, out of these sources, *only* Kafka is available in the Python API. We will add more advanced sources in the Python API in future. This category of sources require interfacing with external non-Spark libraries, some of them with complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts -of dependencies, the functionality to create DStreams from these sources have been moved to separate -libraries, that can be [linked](#linking) to explicitly when necessary. For example, if you want to -create a DStream using data from Twitter's stream of tweets, you have to do the following. +of dependencies, the functionality to create DStreams from these sources has been moved to separate +libraries that can be [linked](#linking) to explicitly when necessary. For example, if you want to +create a DStream using data from Twitter's stream of tweets, you have to do the following: 1. *Linking*: Add the artifact `spark-streaming-twitter_{{site.SCALA_BINARY_VERSION}}` to the SBT/Maven project dependencies. @@ -719,11 +719,11 @@ TwitterUtils.createStream(jssc); Note that these advanced sources are not available in the Spark shell, hence applications based on these advanced sources cannot be tested in the shell. If you really want to use them in the Spark shell you will have to download the corresponding Maven artifact's JAR along with its dependencies -and it in the classpath. +and add it to the classpath. Some of these advanced sources are as follows. -- **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kafka 0.8.1.1. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details. +- **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kafka 0.8.2.1. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details. - **Flume:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Flume 1.4.0. See the [Flume Integration Guide](streaming-flume-integration.html) for more details. @@ -743,7 +743,7 @@ Some of these advanced sources are as follows. Python API This is not yet supported in Python. -Input DStreams can also be created out of custom data sources. All you have to do is implement an +Input DStreams can also be created out of custom data sources. All you have to do is implement a user-defined **receiver** (see next section to understand what that is) that can receive data from the custom sources and push it into Spark. See the [Custom Receiver Guide](streaming-custom-receivers.html) for details. @@ -753,14 +753,12 @@ Guide](streaming-custom-receivers.html) for details. There can be two kinds of data sources based on their *reliability*. Sources (like Kafka and Flume) allow the transferred data to be acknowledged. If the system receiving -data from these *reliable* sources acknowledge the received data correctly, it can be ensured -that no data gets lost due to any kind of failure. This leads to two kinds of receivers. +data from these *reliable* sources acknowledges the received data correctly, it can be ensured +that no data will be lost due to any kind of failure. This leads to two kinds of receivers: -1. *Reliable Receiver* - A *reliable receiver* correctly acknowledges a reliable - source that the data has been received and stored in Spark with replication. -1. *Unreliable Receiver* - These are receivers for sources that do not support acknowledging. Even - for reliable sources, one may implement an unreliable receiver that do not go into the complexity - of acknowledging correctly. +1. *Reliable Receiver* - A *reliable receiver* correctly sends acknowledgment to a reliable + source when the data has been received and stored in Spark with replication. +1. *Unreliable Receiver* - An *unreliable receiver* does *not* send acknowledgment to a source. This can be used for sources that do not support acknowledgment, or even for reliable sources when one does not want or need to go into the complexity of acknowledgment. The details of how to write a reliable receiver are discussed in the [Custom Receiver Guide](streaming-custom-receivers.html). @@ -828,7 +826,7 @@ Some of the common ones are as follows. cogroup(otherStream, [numTasks]) - When called on DStream of (K, V) and (K, W) pairs, return a new DStream of + When called on a DStream of (K, V) and (K, W) pairs, return a new DStream of (K, Seq[V], Seq[W]) tuples. @@ -852,13 +850,13 @@ A few of these transformations are worth discussing in more detail. The `updateStateByKey` operation allows you to maintain arbitrary state while continuously updating it with new information. To use this, you will have to do two steps. -1. Define the state - The state can be of arbitrary data type. +1. Define the state - The state can be an arbitrary data type. 1. Define the state update function - Specify with a function how to update the state using the -previous state and the new values from input stream. +previous state and the new values from an input stream. Let's illustrate this with an example. Say you want to maintain a running count of each word seen in a text data stream. Here, the running count is the state and it is an integer. We -define the update function as +define the update function as:
@@ -947,7 +945,7 @@ operation that is not exposed in the DStream API. For example, the functionality of joining every batch in a data stream with another dataset is not directly exposed in the DStream API. However, you can easily use `transform` to do this. This enables very powerful possibilities. For example, -if you want to do real-time data cleaning by joining the input data stream with precomputed +one can do real-time data cleaning by joining the input data stream with precomputed spam information (maybe generated with Spark as well) and then filtering based on it.
@@ -991,13 +989,14 @@ cleanedDStream = wordCounts.transform(lambda rdd: rdd.join(spamInfoRDD).filter(.
-In fact, you can also use [machine learning](mllib-guide.html) and -[graph computation](graphx-programming-guide.html) algorithms in the `transform` method. +Note that the supplied function gets called in every batch interval. This allows you to do +time-varying RDD operations, that is, RDD operations, number of partitions, broadcast variables, +etc. can be changed between batches. #### Window Operations {:.no_toc} Spark Streaming also provides *windowed computations*, which allow you to apply -transformations over a sliding window of data. This following figure illustrates this sliding +transformations over a sliding window of data. The following figure illustrates this sliding window.

@@ -1009,11 +1008,11 @@ window. As shown in the figure, every time the window *slides* over a source DStream, the source RDDs that fall within the window are combined and operated upon to produce the -RDDs of the windowed DStream. In this specific case, the operation is applied over last 3 time +RDDs of the windowed DStream. In this specific case, the operation is applied over the last 3 time units of data, and slides by 2 time units. This shows that any window operation needs to specify two parameters. - * window length - The duration of the window (3 in the figure) + * window length - The duration of the window (3 in the figure). * sliding interval - The interval at which the window operation is performed (2 in the figure). @@ -1021,7 +1020,7 @@ These two parameters must be multiples of the batch interval of the source DStre figure). Let's illustrate the window operations with an example. Say, you want to extend the -[earlier example](#a-quick-example) by generating word counts over last 30 seconds of data, +[earlier example](#a-quick-example) by generating word counts over the last 30 seconds of data, every 10 seconds. To do this, we have to apply the `reduceByKey` operation on the `pairs` DStream of `(word, 1)` pairs over the last 30 seconds of data. This is done using the operation `reduceByKeyAndWindow`. @@ -1096,13 +1095,13 @@ said two parameters - windowLength and slideInterval. reduceByKeyAndWindow(func, invFunc, windowLength, slideInterval, [numTasks]) - A more efficient version of the above reduceByKeyAndWindow() where the reduce + A more efficient version of the above reduceByKeyAndWindow() where the reduce value of each window is calculated incrementally using the reduce values of the previous window. - This is done by reducing the new data that enter the sliding window, and "inverse reducing" the - old data that leave the window. An example would be that of "adding" and "subtracting" counts - of keys as the window slides. However, it is applicable to only "invertible reduce functions", + This is done by reducing the new data that enters the sliding window, and "inverse reducing" the + old data that leaves the window. An example would be that of "adding" and "subtracting" counts + of keys as the window slides. However, it is applicable only to "invertible reduce functions", that is, those reduce functions which have a corresponding "inverse reduce" function (taken as - parameter invFunc. Like in reduceByKeyAndWindow, the number of reduce tasks + parameter invFunc). Like in reduceByKeyAndWindow, the number of reduce tasks is configurable through an optional argument. Note that [checkpointing](#checkpointing) must be enabled for using this operation. @@ -1224,7 +1223,7 @@ For the Python API, see [DStream](api/python/pyspark.streaming.html#pyspark.stre *** ## Output Operations on DStreams -Output operations allow DStream's data to be pushed out external systems like a database or a file systems. +Output operations allow DStream's data to be pushed out to external systems like a database or a file systems. Since the output operations actually allow the transformed data to be consumed by external systems, they trigger the actual execution of all the DStream transformations (similar to actions for RDDs). Currently, the following output operations are defined: @@ -1233,7 +1232,7 @@ Currently, the following output operations are defined: Output OperationMeaning print() - Prints first ten elements of every batch of data in a DStream on the driver node running + Prints the first ten elements of every batch of data in a DStream on the driver node running the streaming application. This is useful for development and debugging.
Python API This is called @@ -1242,12 +1241,12 @@ Currently, the following output operations are defined: saveAsTextFiles(prefix, [suffix]) - Save this DStream's contents as a text files. The file name at each batch interval is + Save this DStream's contents as text files. The file name at each batch interval is generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]". saveAsObjectFiles(prefix, [suffix]) - Save this DStream's contents as a SequenceFile of serialized Java objects. The file + Save this DStream's contents as SequenceFiles of serialized Java objects. The file name at each batch interval is generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
@@ -1257,7 +1256,7 @@ Currently, the following output operations are defined: saveAsHadoopFiles(prefix, [suffix]) - Save this DStream's contents as a Hadoop file. The file name at each batch interval is + Save this DStream's contents as Hadoop files. The file name at each batch interval is generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
Python API This is not available in @@ -1267,7 +1266,7 @@ Currently, the following output operations are defined: foreachRDD(func) The most generic output operator that applies a function, func, to each RDD generated from - the stream. This function should push the data in each RDD to a external system, like saving the RDD to + the stream. This function should push the data in each RDD to an external system, such as saving the RDD to files, or writing it over the network to a database. Note that the function func is executed in the driver process running the streaming application, and will usually have RDD actions in it that will force the computation of the streaming RDDs. @@ -1277,14 +1276,14 @@ Currently, the following output operations are defined: ### Design Patterns for using foreachRDD {:.no_toc} -`dstream.foreachRDD` is a powerful primitive that allows data to sent out to external systems. +`dstream.foreachRDD` is a powerful primitive that allows data to be sent out to external systems. However, it is important to understand how to use this primitive correctly and efficiently. Some of the common mistakes to avoid are as follows. Often writing data to external system requires creating a connection object (e.g. TCP connection to a remote server) and using it to send data to a remote system. For this purpose, a developer may inadvertently try creating a connection object at -the Spark driver, but try to use it in a Spark worker to save records in the RDDs. +the Spark driver, and then try to use it in a Spark worker to save records in the RDDs. For example (in Scala),

@@ -1346,7 +1345,7 @@ dstream.foreachRDD(lambda rdd: rdd.foreach(sendRecord)) Typically, creating a connection object has time and resource overheads. Therefore, creating and destroying a connection object for each record can incur unnecessarily high overheads and can significantly reduce the overall throughput of the system. A better solution is to use -`rdd.foreachPartition` - create a single connection object and send all the records in a RDD +`rdd.foreachPartition` - create a single connection object and send all the records in a RDD partition using that connection.
@@ -1427,26 +1426,6 @@ You can easily use [DataFrames and SQL](sql-programming-guide.html) operations o
{% highlight scala %} -/** Lazily instantiated singleton instance of SQLContext */ -object SQLContextSingleton { - @transient private var instance: SQLContext = null - - // Instantiate SQLContext on demand - def getInstance(sparkContext: SparkContext): SQLContext = synchronized { - if (instance == null) { - instance = new SQLContext(sparkContext) - } - instance - } -} - -... - -/** Case class for converting RDD to DataFrame */ -case class Row(word: String) - -... - /** DataFrame operations inside your streaming program */ val words: DStream[String] = ... @@ -1454,11 +1433,11 @@ val words: DStream[String] = ... words.foreachRDD { rdd => // Get the singleton instance of SQLContext - val sqlContext = SQLContextSingleton.getInstance(rdd.sparkContext) + val sqlContext = SQLContext.getOrCreate(rdd.sparkContext) import sqlContext.implicits._ - // Convert RDD[String] to RDD[case class] to DataFrame - val wordsDataFrame = rdd.map(w => Row(w)).toDF() + // Convert RDD[String] to DataFrame + val wordsDataFrame = rdd.toDF("word") // Register as table wordsDataFrame.registerTempTable("words") @@ -1476,19 +1455,6 @@ See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/ma
{% highlight java %} -/** Lazily instantiated singleton instance of SQLContext */ -class JavaSQLContextSingleton { - static private transient SQLContext instance = null; - static public SQLContext getInstance(SparkContext sparkContext) { - if (instance == null) { - instance = new SQLContext(sparkContext); - } - return instance; - } -} - -... - /** Java Bean class for converting RDD to DataFrame */ public class JavaRow implements java.io.Serializable { private String word; @@ -1512,7 +1478,9 @@ words.foreachRDD( new Function2, Time, Void>() { @Override public Void call(JavaRDD rdd, Time time) { - SQLContext sqlContext = JavaSQLContextSingleton.getInstance(rdd.context()); + + // Get the singleton instance of SQLContext + SQLContext sqlContext = SQLContext.getOrCreate(rdd.context()); // Convert RDD[String] to RDD[case class] to DataFrame JavaRDD rowRDD = rdd.map(new Function() { @@ -1581,7 +1549,7 @@ See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/ma
-You can also run SQL queries on tables defined on streaming data from a different thread (that is, asynchronous to the running StreamingContext). Just make sure that you set the StreamingContext to remember sufficient amount of streaming data such that query can run. Otherwise the StreamingContext, which is unaware of the any asynchronous SQL queries, will delete off old streaming data before the query can complete. For example, if you want to query the last batch, but your query can take 5 minutes to run, then call `streamingContext.remember(Minutes(5))` (in Scala, or equivalent in other languages). +You can also run SQL queries on tables defined on streaming data from a different thread (that is, asynchronous to the running StreamingContext). Just make sure that you set the StreamingContext to remember a sufficient amount of streaming data such that the query can run. Otherwise the StreamingContext, which is unaware of the any asynchronous SQL queries, will delete off old streaming data before the query can complete. For example, if you want to query the last batch, but your query can take 5 minutes to run, then call `streamingContext.remember(Minutes(5))` (in Scala, or equivalent in other languages). See the [DataFrames and SQL](sql-programming-guide.html) guide to learn more about DataFrames. @@ -1594,7 +1562,7 @@ You can also easily use machine learning algorithms provided by [MLlib](mllib-gu ## Caching / Persistence Similar to RDDs, DStreams also allow developers to persist the stream's data in memory. That is, -using `persist()` method on a DStream will automatically persist every RDD of that DStream in +using the `persist()` method on a DStream will automatically persist every RDD of that DStream in memory. This is useful if the data in the DStream will be computed multiple times (e.g., multiple operations on the same data). For window-based operations like `reduceByWindow` and `reduceByKeyAndWindow` and state-based operations like `updateStateByKey`, this is implicitly true. @@ -1606,28 +1574,27 @@ default persistence level is set to replicate the data to two nodes for fault-to Note that, unlike RDDs, the default persistence level of DStreams keeps the data serialized in memory. This is further discussed in the [Performance Tuning](#memory-tuning) section. More -information on different persistence levels can be found in -[Spark Programming Guide](programming-guide.html#rdd-persistence). +information on different persistence levels can be found in the [Spark Programming Guide](programming-guide.html#rdd-persistence). *** ## Checkpointing A streaming application must operate 24/7 and hence must be resilient to failures unrelated to the application logic (e.g., system failures, JVM crashes, etc.). For this to be possible, -Spark Streaming needs to *checkpoints* enough information to a fault- +Spark Streaming needs to *checkpoint* enough information to a fault- tolerant storage system such that it can recover from failures. There are two types of data that are checkpointed. - *Metadata checkpointing* - Saving of the information defining the streaming computation to fault-tolerant storage like HDFS. This is used to recover from failure of the node running the driver of the streaming application (discussed in detail later). Metadata includes: - + *Configuration* - The configuration that were used to create the streaming application. + + *Configuration* - The configuration that was used to create the streaming application. + *DStream operations* - The set of DStream operations that define the streaming application. + *Incomplete batches* - Batches whose jobs are queued but have not completed yet. - *Data checkpointing* - Saving of the generated RDDs to reliable storage. This is necessary in some *stateful* transformations that combine data across multiple batches. In such - transformations, the generated RDDs depends on RDDs of previous batches, which causes the length - of the dependency chain to keep increasing with time. To avoid such unbounded increase in recovery + transformations, the generated RDDs depend on RDDs of previous batches, which causes the length + of the dependency chain to keep increasing with time. To avoid such unbounded increases in recovery time (proportional to dependency chain), intermediate RDDs of stateful transformations are periodically *checkpointed* to reliable storage (e.g. HDFS) to cut off the dependency chains. @@ -1641,10 +1608,10 @@ transformations are used. Checkpointing must be enabled for applications with any of the following requirements: - *Usage of stateful transformations* - If either `updateStateByKey` or `reduceByKeyAndWindow` (with - inverse function) is used in the application, then the checkpoint directory must be provided for - allowing periodic RDD checkpointing. + inverse function) is used in the application, then the checkpoint directory must be provided to + allow for periodic RDD checkpointing. - *Recovering from failures of the driver running the application* - Metadata checkpoints are used - for to recover with progress information. + to recover with progress information. Note that simple streaming applications without the aforementioned stateful transformations can be run without enabling checkpointing. The recovery from driver failures will also be partial in @@ -1659,7 +1626,7 @@ Checkpointing can be enabled by setting a directory in a fault-tolerant, reliable file system (e.g., HDFS, S3, etc.) to which the checkpoint information will be saved. This is done by using `streamingContext.checkpoint(checkpointDirectory)`. This will allow you to use the aforementioned stateful transformations. Additionally, -if you want make the application recover from driver failures, you should rewrite your +if you want to make the application recover from driver failures, you should rewrite your streaming application to have the following behavior. + When the program is being started for the first time, it will create a new StreamingContext, @@ -1780,18 +1747,17 @@ You can also explicitly create a `StreamingContext` from the checkpoint data and In addition to using `getOrCreate` one also needs to ensure that the driver process gets restarted automatically on failure. This can only be done by the deployment infrastructure that is used to run the application. This is further discussed in the -[Deployment](#deploying-applications.html) section. +[Deployment](#deploying-applications) section. Note that checkpointing of RDDs incurs the cost of saving to reliable storage. This may cause an increase in the processing time of those batches where RDDs get checkpointed. Hence, the interval of checkpointing needs to be set carefully. At small batch sizes (say 1 second), checkpointing every batch may significantly reduce operation throughput. Conversely, checkpointing too infrequently -causes the lineage and task sizes to grow which may have detrimental effects. For stateful +causes the lineage and task sizes to grow, which may have detrimental effects. For stateful transformations that require RDD checkpointing, the default interval is a multiple of the batch interval that is at least 10 seconds. It can be set by using -`dstream.checkpoint(checkpointInterval)`. Typically, a checkpoint interval of 5 - 10 times of -sliding interval of a DStream is good setting to try. +`dstream.checkpoint(checkpointInterval)`. Typically, a checkpoint interval of 5 - 10 sliding intervals of a DStream is a good setting to try. *** @@ -1864,17 +1830,17 @@ To run a Spark Streaming applications, you need to have the following. {:.no_toc} If a running Spark Streaming application needs to be upgraded with new -application code, then there are two possible mechanism. +application code, then there are two possible mechanisms. - The upgraded Spark Streaming application is started and run in parallel to the existing application. -Once the new one (receiving the same data as the old one) has been warmed up and ready +Once the new one (receiving the same data as the old one) has been warmed up and is ready for prime time, the old one be can be brought down. Note that this can be done for data sources that support sending the data to two destinations (i.e., the earlier and upgraded applications). - The existing application is shutdown gracefully (see [`StreamingContext.stop(...)`](api/scala/index.html#org.apache.spark.streaming.StreamingContext) or [`JavaStreamingContext.stop(...)`](api/java/index.html?org/apache/spark/streaming/api/java/JavaStreamingContext.html) -for graceful shutdown options) which ensure data that have been received is completely +for graceful shutdown options) which ensure data that has been received is completely processed before shutdown. Then the upgraded application can be started, which will start processing from the same point where the earlier application left off. Note that this can be done only with input sources that support source-side buffering @@ -1909,10 +1875,10 @@ The following two metrics in web UI are particularly important: to finish. If the batch processing time is consistently more than the batch interval and/or the queueing -delay keeps increasing, then it indicates the system is -not able to process the batches as fast they are being generated and falling behind. +delay keeps increasing, then it indicates that the system is +not able to process the batches as fast they are being generated and is falling behind. In that case, consider -[reducing](#reducing-the-processing-time-of-each-batch) the batch processing time. +[reducing](#reducing-the-batch-processing-times) the batch processing time. The progress of a Spark Streaming program can also be monitored using the [StreamingListener](api/scala/index.html#org.apache.spark.streaming.scheduler.StreamingListener) interface, @@ -1923,8 +1889,8 @@ and it is likely to be improved upon (i.e., more information reported) in the fu *************************************************************************************************** # Performance Tuning -Getting the best performance of a Spark Streaming application on a cluster requires a bit of -tuning. This section explains a number of the parameters and configurations that can tuned to +Getting the best performance out of a Spark Streaming application on a cluster requires a bit of +tuning. This section explains a number of the parameters and configurations that can be tuned to improve the performance of you application. At a high level, you need to consider two things: 1. Reducing the processing time of each batch of data by efficiently using cluster resources. @@ -1934,12 +1900,12 @@ improve the performance of you application. At a high level, you need to conside ## Reducing the Batch Processing Times There are a number of optimizations that can be done in Spark to minimize the processing time of -each batch. These have been discussed in detail in [Tuning Guide](tuning.html). This section +each batch. These have been discussed in detail in the [Tuning Guide](tuning.html). This section highlights some of the most important ones. ### Level of Parallelism in Data Receiving {:.no_toc} -Receiving data over the network (like Kafka, Flume, socket, etc.) requires the data to deserialized +Receiving data over the network (like Kafka, Flume, socket, etc.) requires the data to be deserialized and stored in Spark. If the data receiving becomes a bottleneck in the system, then consider parallelizing the data receiving. Note that each input DStream creates a single receiver (running on a worker machine) that receives a single stream of data. @@ -1947,7 +1913,7 @@ Receiving multiple data streams can therefore be achieved by creating multiple i and configuring them to receive different partitions of the data stream from the source(s). For example, a single Kafka input DStream receiving two topics of data can be split into two Kafka input streams, each receiving only one topic. This would run two receivers, -allowing data to be received in parallel, and increasing overall throughput. These multiple +allowing data to be received in parallel, thus increasing overall throughput. These multiple DStreams can be unioned together to create a single DStream. Then the transformations that were being applied on a single input DStream can be applied on the unified stream. This is done as follows. @@ -1971,16 +1937,24 @@ JavaPairDStream unifiedStream = streamingContext.union(kafkaStre unifiedStream.print(); {% endhighlight %}
+
+{% highlight python %} +numStreams = 5 +kafkaStreams = [KafkaUtils.createStream(...) for _ in range (numStreams)] +unifiedStream = streamingContext.union(kafkaStreams) +unifiedStream.print() +{% endhighlight %} +
Another parameter that should be considered is the receiver's blocking interval, which is determined by the [configuration parameter](configuration.html#spark-streaming) `spark.streaming.blockInterval`. For most receivers, the received data is coalesced together into blocks of data before storing inside Spark's memory. The number of blocks in each batch -determines the number of tasks that will be used to process those +determines the number of tasks that will be used to process the received data in a map-like transformation. The number of tasks per receiver per batch will be approximately (batch interval / block interval). For example, block interval of 200 ms will -create 10 tasks per 2 second batches. Too low the number of tasks (that is, less than the number +create 10 tasks per 2 second batches. If the number of tasks is too low (that is, less than the number of cores per machine), then it will be inefficient as all available cores will not be used to process the data. To increase the number of tasks for a given batch interval, reduce the block interval. However, the recommended minimum value of block interval is about 50 ms, @@ -1988,7 +1962,7 @@ below which the task launching overheads may be a problem. An alternative to receiving data with multiple input streams / receivers is to explicitly repartition the input data stream (using `inputStream.repartition()`). -This distributes the received batches of data across specified number of machines in the cluster +This distributes the received batches of data across the specified number of machines in the cluster before further processing. ### Level of Parallelism in Data Processing @@ -1996,7 +1970,7 @@ before further processing. Cluster resources can be under-utilized if the number of parallel tasks used in any stage of the computation is not high enough. For example, for distributed reduce operations like `reduceByKey` and `reduceByKeyAndWindow`, the default number of parallel tasks is controlled by -the`spark.default.parallelism` [configuration property](configuration.html#spark-properties). You +the `spark.default.parallelism` [configuration property](configuration.html#spark-properties). You can pass the level of parallelism as an argument (see [`PairDStreamFunctions`](api/scala/index.html#org.apache.spark.streaming.dstream.PairDStreamFunctions) documentation), or set the `spark.default.parallelism` @@ -2004,20 +1978,20 @@ documentation), or set the `spark.default.parallelism` ### Data Serialization {:.no_toc} -The overheads of data serialization can be reduce by tuning the serialization formats. In case of streaming, there are two types of data that are being serialized. +The overheads of data serialization can be reduced by tuning the serialization formats. In the case of streaming, there are two types of data that are being serialized. -* **Input data**: By default, the input data received through Receivers is stored in the executors' memory with [StorageLevel.MEMORY_AND_DISK_SER_2](api/scala/index.html#org.apache.spark.storage.StorageLevel$). That is, the data is serialized into bytes to reduce GC overheads, and replicated for tolerating executor failures. Also, the data is kept first in memory, and spilled over to disk only if the memory is unsufficient to hold all the input data necessary for the streaming computation. This serialization obviously has overheads -- the receiver must deserialize the received data and re-serialize it using Spark's serialization format. +* **Input data**: By default, the input data received through Receivers is stored in the executors' memory with [StorageLevel.MEMORY_AND_DISK_SER_2](api/scala/index.html#org.apache.spark.storage.StorageLevel$). That is, the data is serialized into bytes to reduce GC overheads, and replicated for tolerating executor failures. Also, the data is kept first in memory, and spilled over to disk only if the memory is insufficient to hold all of the input data necessary for the streaming computation. This serialization obviously has overheads -- the receiver must deserialize the received data and re-serialize it using Spark's serialization format. -* **Persisted RDDs generated by Streaming Operations**: RDDs generated by streaming computations may be persisted in memory. For example, window operation persist data in memory as they would be processed multiple times. However, unlike Spark, by default RDDs are persisted with [StorageLevel.MEMORY_ONLY_SER](api/scala/index.html#org.apache.spark.storage.StorageLevel$) (i.e. serialized) to minimize GC overheads. +* **Persisted RDDs generated by Streaming Operations**: RDDs generated by streaming computations may be persisted in memory. For example, window operations persist data in memory as they would be processed multiple times. However, unlike the Spark Core default of [StorageLevel.MEMORY_ONLY](api/scala/index.html#org.apache.spark.storage.StorageLevel$), persisted RDDs generated by streaming computations are persisted with [StorageLevel.MEMORY_ONLY_SER](api/scala/index.html#org.apache.spark.storage.StorageLevel$) (i.e. serialized) by default to minimize GC overheads. -In both cases, using Kryo serialization can reduce both CPU and memory overheads. See the [Spark Tuning Guide](tuning.html#data-serialization)) for more details. Consider registering custom classes, and disabling object reference tracking for Kryo (see Kryo-related configurations in the [Configuration Guide](configuration.html#compression-and-serialization)). +In both cases, using Kryo serialization can reduce both CPU and memory overheads. See the [Spark Tuning Guide](tuning.html#data-serialization) for more details. For Kryo, consider registering custom classes, and disabling object reference tracking (see Kryo-related configurations in the [Configuration Guide](configuration.html#compression-and-serialization)). -In specific cases where the amount of data that needs to be retained for the streaming application is not large, it may be feasible to persist data (both types) as deserialized objects without incurring excessive GC overheads. For example, if you are using batch intervals of few seconds and no window operations, then you can try disabling serialization in persisted data by explicitly setting the storage level accordingly. This would reduce the CPU overheads due to serialization, potentially improving performance without too much GC overheads. +In specific cases where the amount of data that needs to be retained for the streaming application is not large, it may be feasible to persist data (both types) as deserialized objects without incurring excessive GC overheads. For example, if you are using batch intervals of a few seconds and no window operations, then you can try disabling serialization in persisted data by explicitly setting the storage level accordingly. This would reduce the CPU overheads due to serialization, potentially improving performance without too much GC overheads. ### Task Launching Overheads {:.no_toc} If the number of tasks launched per second is high (say, 50 or more per second), then the overhead -of sending out tasks to the slaves maybe significant and will make it hard to achieve sub-second +of sending out tasks to the slaves may be significant and will make it hard to achieve sub-second latencies. The overhead can be reduced by the following changes: * **Task Serialization**: Using Kryo serialization for serializing tasks can reduce the task @@ -2036,7 +2010,7 @@ thus allowing sub-second batch size to be viable. For a Spark Streaming application running on a cluster to be stable, the system should be able to process data as fast as it is being received. In other words, batches of data should be processed as fast as they are being generated. Whether this is true for an application can be found by -[monitoring](#monitoring) the processing times in the streaming web UI, where the batch +[monitoring](#monitoring-applications) the processing times in the streaming web UI, where the batch processing time should be less than the batch interval. Depending on the nature of the streaming @@ -2049,35 +2023,35 @@ production can be sustained. A good approach to figure out the right batch size for your application is to test it with a conservative batch interval (say, 5-10 seconds) and a low data rate. To verify whether the system -is able to keep up with data rate, you can check the value of the end-to-end delay experienced +is able to keep up with the data rate, you can check the value of the end-to-end delay experienced by each processed batch (either look for "Total delay" in Spark driver log4j logs, or use the [StreamingListener](api/scala/index.html#org.apache.spark.streaming.scheduler.StreamingListener) interface). If the delay is maintained to be comparable to the batch size, then system is stable. Otherwise, if the delay is continuously increasing, it means that the system is unable to keep up and it therefore unstable. Once you have an idea of a stable configuration, you can try increasing the -data rate and/or reducing the batch size. Note that momentary increase in the delay due to -temporary data rate increases maybe fine as long as the delay reduces back to a low value +data rate and/or reducing the batch size. Note that a momentary increase in the delay due to +temporary data rate increases may be fine as long as the delay reduces back to a low value (i.e., less than batch size). *** ## Memory Tuning -Tuning the memory usage and GC behavior of Spark applications have been discussed in great detail +Tuning the memory usage and GC behavior of Spark applications has been discussed in great detail in the [Tuning Guide](tuning.html#memory-tuning). It is strongly recommended that you read that. In this section, we discuss a few tuning parameters specifically in the context of Spark Streaming applications. -The amount of cluster memory required by a Spark Streaming application depends heavily on the type of transformations used. For example, if you want to use a window operation on last 10 minutes of data, then your cluster should have sufficient memory to hold 10 minutes of worth of data in memory. Or if you want to use `updateStateByKey` with a large number of keys, then the necessary memory will be high. On the contrary, if you want to do a simple map-filter-store operation, then necessary memory will be low. +The amount of cluster memory required by a Spark Streaming application depends heavily on the type of transformations used. For example, if you want to use a window operation on the last 10 minutes of data, then your cluster should have sufficient memory to hold 10 minutes worth of data in memory. Or if you want to use `updateStateByKey` with a large number of keys, then the necessary memory will be high. On the contrary, if you want to do a simple map-filter-store operation, then the necessary memory will be low. -In general, since the data received through receivers are stored with StorageLevel.MEMORY_AND_DISK_SER_2, the data that does not fit in memory will spill over to the disk. This may reduce the performance of the streaming application, and hence it is advised to provide sufficient memory as required by your streaming application. Its best to try and see the memory usage on a small scale and estimate accordingly. +In general, since the data received through receivers is stored with StorageLevel.MEMORY_AND_DISK_SER_2, the data that does not fit in memory will spill over to the disk. This may reduce the performance of the streaming application, and hence it is advised to provide sufficient memory as required by your streaming application. Its best to try and see the memory usage on a small scale and estimate accordingly. -Another aspect of memory tuning is garbage collection. For a streaming application that require low latency, it is undesirable to have large pauses caused by JVM Garbage Collection. +Another aspect of memory tuning is garbage collection. For a streaming application that requires low latency, it is undesirable to have large pauses caused by JVM Garbage Collection. -There are a few parameters that can help you tune the memory usage and GC overheads. +There are a few parameters that can help you tune the memory usage and GC overheads: -* **Persistence Level of DStreams**: As mentioned earlier in the [Data Serialization](#data-serialization) section, the input data and RDDs are by default persisted as serialized bytes. This reduces both, the memory usage and GC overheads, compared to deserialized persistence. Enabling Kryo serialization further reduces serialized sizes and memory usage. Further reduction in memory usage can be achieved with compression (see the Spark configuration `spark.rdd.compress`), at the cost of CPU time. +* **Persistence Level of DStreams**: As mentioned earlier in the [Data Serialization](#data-serialization) section, the input data and RDDs are by default persisted as serialized bytes. This reduces both the memory usage and GC overheads, compared to deserialized persistence. Enabling Kryo serialization further reduces serialized sizes and memory usage. Further reduction in memory usage can be achieved with compression (see the Spark configuration `spark.rdd.compress`), at the cost of CPU time. -* **Clearing old data**: By default, all input data and persisted RDDs generated by DStream transformations are automatically cleared. Spark Streaming decides when to clear the data based on the transformations that are used. For example, if you are using window operation of 10 minutes, then Spark Streaming will keep around last 10 minutes of data, and actively throw away older data. -Data can be retained for longer duration (e.g. interactively querying older data) by setting `streamingContext.remember`. +* **Clearing old data**: By default, all input data and persisted RDDs generated by DStream transformations are automatically cleared. Spark Streaming decides when to clear the data based on the transformations that are used. For example, if you are using a window operation of 10 minutes, then Spark Streaming will keep around the last 10 minutes of data, and actively throw away older data. +Data can be retained for a longer duration (e.g. interactively querying older data) by setting `streamingContext.remember`. * **CMS Garbage Collector**: Use of the concurrent mark-and-sweep GC is strongly recommended for keeping GC-related pauses consistently low. Even though concurrent GC is known to reduce the overall processing throughput of the system, its use is still recommended to achieve more @@ -2107,18 +2081,18 @@ re-computed from the original fault-tolerant dataset using the lineage of operat 1. Assuming that all of the RDD transformations are deterministic, the data in the final transformed RDD will always be the same irrespective of failures in the Spark cluster. -Spark operates on data on fault-tolerant file systems like HDFS or S3. Hence, +Spark operates on data in fault-tolerant file systems like HDFS or S3. Hence, all of the RDDs generated from the fault-tolerant data are also fault-tolerant. However, this is not the case for Spark Streaming as the data in most cases is received over the network (except when `fileStream` is used). To achieve the same fault-tolerance properties for all of the generated RDDs, the received data is replicated among multiple Spark executors in worker nodes in the cluster (default replication factor is 2). This leads to two kinds of data in the -system that needs to recovered in the event of failures: +system that need to recovered in the event of failures: 1. *Data received and replicated* - This data survives failure of a single worker node as a copy - of it exists on one of the nodes. + of it exists on one of the other nodes. 1. *Data received but buffered for replication* - Since this is not replicated, - the only way to recover that data is to get it again from the source. + the only way to recover this data is to get it again from the source. Furthermore, there are two kinds of failures that we should be concerned about: @@ -2145,13 +2119,13 @@ In any stream processing system, broadly speaking, there are three steps in proc 1. *Receiving the data*: The data is received from sources using Receivers or otherwise. -1. *Transforming the data*: The data received data is transformed using DStream and RDD transformations. +1. *Transforming the data*: The received data is transformed using DStream and RDD transformations. 1. *Pushing out the data*: The final transformed data is pushed out to external systems like file systems, databases, dashboards, etc. -If a streaming application has to achieve end-to-end exactly-once guarantees, then each step has to provide exactly-once guarantee. That is, each record must be received exactly once, transformed exactly once, and pushed to downstream systems exactly once. Let's understand the semantics of these steps in the context of Spark Streaming. +If a streaming application has to achieve end-to-end exactly-once guarantees, then each step has to provide an exactly-once guarantee. That is, each record must be received exactly once, transformed exactly once, and pushed to downstream systems exactly once. Let's understand the semantics of these steps in the context of Spark Streaming. -1. *Receiving the data*: Different input sources provided different guarantees. This is discussed in detail in the next subsection. +1. *Receiving the data*: Different input sources provide different guarantees. This is discussed in detail in the next subsection. 1. *Transforming the data*: All data that has been received will be processed _exactly once_, thanks to the guarantees that RDDs provide. Even if there are failures, as long as the received input data is accessible, the final transformed RDDs will always have the same contents. @@ -2163,9 +2137,9 @@ Different input sources provide different guarantees, ranging from _at-least onc ### With Files {:.no_toc} -If all of the input data is already present in a fault-tolerant files system like -HDFS, Spark Streaming can always recover from any failure and process all the data. This gives -*exactly-once* semantics, that all the data will be processed exactly once no matter what fails. +If all of the input data is already present in a fault-tolerant file system like +HDFS, Spark Streaming can always recover from any failure and process all of the data. This gives +*exactly-once* semantics, meaning all of the data will be processed exactly once no matter what fails. ### With Receiver-based Sources {:.no_toc} @@ -2174,21 +2148,21 @@ scenario and the type of receiver. As we discussed [earlier](#receiver-reliability), there are two types of receivers: 1. *Reliable Receiver* - These receivers acknowledge reliable sources only after ensuring that - the received data has been replicated. If such a receiver fails, - the buffered (unreplicated) data does not get acknowledged to the source. If the receiver is - restarted, the source will resend the data, and therefore no data will be lost due to the failure. -1. *Unreliable Receiver* - Such receivers can lose data when they fail due to worker - or driver failures. + the received data has been replicated. If such a receiver fails, the source will not receive + acknowledgment for the buffered (unreplicated) data. Therefore, if the receiver is + restarted, the source will resend the data, and no data will be lost due to the failure. +1. *Unreliable Receiver* - Such receivers do *not* send acknowledgment and therefore *can* lose + data when they fail due to worker or driver failures. Depending on what type of receivers are used we achieve the following semantics. If a worker node fails, then there is no data loss with reliable receivers. With unreliable receivers, data received but not replicated can get lost. If the driver node fails, -then besides these losses, all the past data that was received and replicated in memory will be +then besides these losses, all of the past data that was received and replicated in memory will be lost. This will affect the results of the stateful transformations. To avoid this loss of past received data, Spark 1.2 introduced _write -ahead logs_ which saves the received data to fault-tolerant storage. With the [write ahead logs -enabled](#deploying-applications) and reliable receivers, there is zero data loss. In terms of semantics, it provides at-least once guarantee. +ahead logs_ which save the received data to fault-tolerant storage. With the [write ahead logs +enabled](#deploying-applications) and reliable receivers, there is zero data loss. In terms of semantics, it provides an at-least once guarantee. The following table summarizes the semantics under failures: @@ -2234,7 +2208,7 @@ The following table summarizes the semantics under failures: ### With Kafka Direct API {:.no_toc} -In Spark 1.3, we have introduced a new Kafka Direct API, which can ensure that all the Kafka data is received by Spark Streaming exactly once. Along with this, if you implement exactly-once output operation, you can achieve end-to-end exactly-once guarantees. This approach (experimental as of Spark 1.3) is further discussed in the [Kafka Integration Guide](streaming-kafka-integration.html). +In Spark 1.3, we have introduced a new Kafka Direct API, which can ensure that all the Kafka data is received by Spark Streaming exactly once. Along with this, if you implement exactly-once output operation, you can achieve end-to-end exactly-once guarantees. This approach (experimental as of Spark {{site.SPARK_VERSION_SHORT}}) is further discussed in the [Kafka Integration Guide](streaming-kafka-integration.html). ## Semantics of output operations {:.no_toc} @@ -2248,9 +2222,16 @@ additional effort may be necessary to achieve exactly-once semantics. There are - *Transactional updates*: All updates are made transactionally so that updates are made exactly once atomically. One way to do this would be the following. - - Use the batch time (available in `foreachRDD`) and the partition index of the transformed RDD to create an identifier. This identifier uniquely identifies a blob data in the streaming application. - - Update external system with this blob transactionally (that is, exactly once, atomically) using the identifier. That is, if the identifier is not already committed, commit the partition data and the identifier atomically. Else if this was already committed, skip the update. + - Use the batch time (available in `foreachRDD`) and the partition index of the RDD to create an identifier. This identifier uniquely identifies a blob data in the streaming application. + - Update external system with this blob transactionally (that is, exactly once, atomically) using the identifier. That is, if the identifier is not already committed, commit the partition data and the identifier atomically. Else, if this was already committed, skip the update. + dstream.foreachRDD { (rdd, time) => + rdd.foreachPartition { partitionIterator => + val partitionId = TaskContext.get.partitionId() + val uniqueId = generateUniqueId(time.milliseconds, partitionId) + // use this uniqueId to transactionally commit the data in partitionIterator + } + } *************************************************************************************************** *************************************************************************************************** @@ -2325,7 +2306,7 @@ package and renamed for better clarity. - Java docs * [JavaStreamingContext](api/java/index.html?org/apache/spark/streaming/api/java/JavaStreamingContext.html), [JavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaDStream.html) and - [PairJavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/PairJavaDStream.html) + [JavaPairDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaPairDStream.html) * [KafkaUtils](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html), [FlumeUtils](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html), [KinesisUtils](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 84629cb9a0ca0..18ccbc0a3edd0 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -51,7 +51,7 @@ raw_input = input xrange = range -SPARK_EC2_VERSION = "1.3.1" +SPARK_EC2_VERSION = "1.4.0" SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) VALID_SPARK_VERSIONS = set([ @@ -70,6 +70,7 @@ "1.2.1", "1.3.0", "1.3.1", + "1.4.0", ]) SPARK_TACHYON_MAP = { @@ -82,6 +83,7 @@ "1.2.1": "0.5.0", "1.3.0": "0.5.0", "1.3.1": "0.5.0", + "1.4.0": "0.6.4", } DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION @@ -89,7 +91,7 @@ # Default location to get the spark-ec2 scripts (and ami-list) from DEFAULT_SPARK_EC2_GITHUB_REPO = "https://github.com/mesos/spark-ec2" -DEFAULT_SPARK_EC2_BRANCH = "branch-1.3" +DEFAULT_SPARK_EC2_BRANCH = "branch-1.4" def setup_external_libs(libs): @@ -287,6 +289,10 @@ def parse_args(): parser.add_option( "--additional-security-group", type="string", default="", help="Additional security group to place the machines in") + parser.add_option( + "--additional-tags", type="string", default="", + help="Additional tags to set on the machines; tags are comma-separated, while name and " + + "value are colon separated; ex: \"Task:MySparkProject,Env:production\"") parser.add_option( "--copy-aws-credentials", action="store_true", default=False, help="Add AWS credentials to hadoop configuration to allow Spark to access S3") @@ -300,6 +306,13 @@ def parse_args(): "--private-ips", action="store_true", default=False, help="Use private IPs for instances rather than public if VPC/subnet " + "requires that.") + parser.add_option( + "--instance-initiated-shutdown-behavior", default="stop", + choices=["stop", "terminate"], + help="Whether instances should terminate when shut down or just stop") + parser.add_option( + "--instance-profile-name", default=None, + help="IAM profile name to launch instances under") (opts, args) = parser.parse_args() if len(args) != 2: @@ -356,7 +369,7 @@ def get_validate_spark_version(version, repo): # Source: http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/ -# Last Updated: 2015-05-08 +# Last Updated: 2015-06-19 # For easy maintainability, please keep this manually-inputted dictionary sorted by key. EC2_INSTANCE_TYPES = { "c1.medium": "pvm", @@ -398,6 +411,11 @@ def get_validate_spark_version(version, repo): "m3.large": "hvm", "m3.xlarge": "hvm", "m3.2xlarge": "hvm", + "m4.large": "hvm", + "m4.xlarge": "hvm", + "m4.2xlarge": "hvm", + "m4.4xlarge": "hvm", + "m4.10xlarge": "hvm", "r3.large": "hvm", "r3.xlarge": "hvm", "r3.2xlarge": "hvm", @@ -407,6 +425,7 @@ def get_validate_spark_version(version, repo): "t2.micro": "hvm", "t2.small": "hvm", "t2.medium": "hvm", + "t2.large": "hvm", } @@ -486,6 +505,8 @@ def launch_cluster(conn, opts, cluster_name): master_group.authorize('tcp', 50070, 50070, authorized_address) master_group.authorize('tcp', 60070, 60070, authorized_address) master_group.authorize('tcp', 4040, 4045, authorized_address) + # Rstudio (GUI for R) needs port 8787 for web access + master_group.authorize('tcp', 8787, 8787, authorized_address) # HDFS NFS gateway requires 111,2049,4242 for tcp & udp master_group.authorize('tcp', 111, 111, authorized_address) master_group.authorize('udp', 111, 111, authorized_address) @@ -590,7 +611,8 @@ def launch_cluster(conn, opts, cluster_name): block_device_map=block_map, subnet_id=opts.subnet_id, placement_group=opts.placement_group, - user_data=user_data_content) + user_data=user_data_content, + instance_profile_name=opts.instance_profile_name) my_req_ids += [req.id for req in slave_reqs] i += 1 @@ -635,16 +657,19 @@ def launch_cluster(conn, opts, cluster_name): for zone in zones: num_slaves_this_zone = get_partition(opts.slaves, num_zones, i) if num_slaves_this_zone > 0: - slave_res = image.run(key_name=opts.key_pair, - security_group_ids=[slave_group.id] + additional_group_ids, - instance_type=opts.instance_type, - placement=zone, - min_count=num_slaves_this_zone, - max_count=num_slaves_this_zone, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content) + slave_res = image.run( + key_name=opts.key_pair, + security_group_ids=[slave_group.id] + additional_group_ids, + instance_type=opts.instance_type, + placement=zone, + min_count=num_slaves_this_zone, + max_count=num_slaves_this_zone, + block_device_map=block_map, + subnet_id=opts.subnet_id, + placement_group=opts.placement_group, + user_data=user_data_content, + instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior, + instance_profile_name=opts.instance_profile_name) slave_nodes += slave_res.instances print("Launched {s} slave{plural_s} in {z}, regid = {r}".format( s=num_slaves_this_zone, @@ -666,32 +691,43 @@ def launch_cluster(conn, opts, cluster_name): master_type = opts.instance_type if opts.zone == 'all': opts.zone = random.choice(conn.get_all_zones()).name - master_res = image.run(key_name=opts.key_pair, - security_group_ids=[master_group.id] + additional_group_ids, - instance_type=master_type, - placement=opts.zone, - min_count=1, - max_count=1, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content) + master_res = image.run( + key_name=opts.key_pair, + security_group_ids=[master_group.id] + additional_group_ids, + instance_type=master_type, + placement=opts.zone, + min_count=1, + max_count=1, + block_device_map=block_map, + subnet_id=opts.subnet_id, + placement_group=opts.placement_group, + user_data=user_data_content, + instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior, + instance_profile_name=opts.instance_profile_name) master_nodes = master_res.instances print("Launched master in %s, regid = %s" % (zone, master_res.id)) # This wait time corresponds to SPARK-4983 print("Waiting for AWS to propagate instance metadata...") - time.sleep(5) - # Give the instances descriptive names + time.sleep(15) + + # Give the instances descriptive names and set additional tags + additional_tags = {} + if opts.additional_tags.strip(): + additional_tags = dict( + map(str.strip, tag.split(':', 1)) for tag in opts.additional_tags.split(',') + ) + for master in master_nodes: - master.add_tag( - key='Name', - value='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) + master.add_tags( + dict(additional_tags, Name='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) + ) + for slave in slave_nodes: - slave.add_tag( - key='Name', - value='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) + slave.add_tags( + dict(additional_tags, Name='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) + ) # Return all the instances return (master_nodes, slave_nodes) @@ -909,7 +945,7 @@ def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state): # Get number of local disks available for a given EC2 instance type. def get_num_disks(instance_type): # Source: http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html - # Last Updated: 2015-05-08 + # Last Updated: 2015-06-19 # For easy maintainability, please keep this manually-inputted dictionary sorted by key. disks_by_instance = { "c1.medium": 1, @@ -951,6 +987,11 @@ def get_num_disks(instance_type): "m3.large": 1, "m3.xlarge": 2, "m3.2xlarge": 2, + "m4.large": 0, + "m4.xlarge": 0, + "m4.2xlarge": 0, + "m4.4xlarge": 0, + "m4.10xlarge": 0, "r3.large": 1, "r3.xlarge": 1, "r3.2xlarge": 1, @@ -960,6 +1001,7 @@ def get_num_disks(instance_type): "t2.micro": 0, "t2.small": 0, "t2.medium": 0, + "t2.large": 0, } if instance_type in disks_by_instance: return disks_by_instance[instance_type] diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index ec533d174ebdc..9df26ffca5775 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -156,6 +156,11 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset) { // Create a model, and return it. return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this); } + + @Override + public MyJavaLogisticRegression copy(ParamMap extra) { + return defaultCopy(extra); + } } /** diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py index 1456c87312841..0ea7cfb7025a0 100755 --- a/examples/src/main/python/kmeans.py +++ b/examples/src/main/python/kmeans.py @@ -68,7 +68,7 @@ def closestPoint(p, centers): closest = data.map( lambda p: (closestPoint(p, kPoints), (p, 1))) pointStats = closest.reduceByKey( - lambda (p1, c1), (p2, c2): (p1 + p2, c1 + c2)) + lambda p1_c1, p2_c2: (p1_c1[0] + p2_c2[0], p1_c1[1] + p2_c2[1])) newPoints = pointStats.map( lambda st: (st[0], st[1][0] / st[1][1])).collect() diff --git a/examples/src/main/python/ml/logistic_regression.py b/examples/src/main/python/ml/logistic_regression.py new file mode 100644 index 0000000000000..55afe1b207fe0 --- /dev/null +++ b/examples/src/main/python/ml/logistic_regression.py @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import LogisticRegression +from pyspark.mllib.evaluation import MulticlassMetrics +from pyspark.ml.feature import StringIndexer +from pyspark.mllib.util import MLUtils +from pyspark.sql import SQLContext + +""" +A simple example demonstrating a logistic regression with elastic net regularization Pipeline. +Run with: + bin/spark-submit examples/src/main/python/ml/logistic_regression.py +""" + +if __name__ == "__main__": + + if len(sys.argv) > 1: + print("Usage: logistic_regression", file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonLogisticRegressionExample") + sqlContext = SQLContext(sc) + + # Load and parse the data file into a dataframe. + df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Map labels into an indexed column of labels in [0, numLabels) + stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") + si_model = stringIndexer.fit(df) + td = si_model.transform(df) + [training, test] = td.randomSplit([0.7, 0.3]) + + lr = LogisticRegression(maxIter=100, regParam=0.3).setLabelCol("indexedLabel") + lr.setElasticNetParam(0.8) + + # Fit the model + lrModel = lr.fit(training) + + predictionAndLabels = lrModel.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = MulticlassMetrics(predictionAndLabels) + print("weighted f-measure %.3f" % metrics.weightedFMeasure()) + print("precision %s" % metrics.precision()) + print("recall %s" % metrics.recall()) + + sc.stop() diff --git a/examples/src/main/python/streaming/queue_stream.py b/examples/src/main/python/streaming/queue_stream.py new file mode 100644 index 0000000000000..dcd6a0fc6ff91 --- /dev/null +++ b/examples/src/main/python/streaming/queue_stream.py @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Create a queue of RDDs that will be mapped/reduced one at a time in + 1 second intervals. + + To run this example use + `$ bin/spark-submit examples/src/main/python/streaming/queue_stream.py +""" +import sys +import time + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonStreamingQueueStream") + ssc = StreamingContext(sc, 1) + + # Create the queue through which RDDs can be pushed to + # a QueueInputDStream + rddQueue = [] + for i in xrange(5): + rddQueue += [ssc.sparkContext.parallelize([j for j in xrange(1, 1001)], 10)] + + # Create the QueueInputDStream and use it do some processing + inputStream = ssc.queueStream(rddQueue) + mappedStream = inputStream.map(lambda x: (x % 10, 1)) + reducedStream = mappedStream.reduceByKey(lambda a, b: a + b) + reducedStream.pprint() + + ssc.start() + time.sleep(6) + ssc.stop(stopSparkContext=True, stopGraceFully=True) diff --git a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala new file mode 100644 index 0000000000000..1f12034ce0f57 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import java.io.File + +import scala.io.Source._ + +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.SparkContext._ + +/** + * Simple test for reading and writing to a distributed + * file system. This example does the following: + * + * 1. Reads local file + * 2. Computes word count on local file + * 3. Writes local file to a DFS + * 4. Reads the file back from the DFS + * 5. Computes word count on the file using Spark + * 6. Compares the word count results + */ +object DFSReadWriteTest { + + private var localFilePath: File = new File(".") + private var dfsDirPath: String = "" + + private val NPARAMS = 2 + + private def readFile(filename: String): List[String] = { + val lineIter: Iterator[String] = fromFile(filename).getLines() + val lineList: List[String] = lineIter.toList + lineList + } + + private def printUsage(): Unit = { + val usage: String = "DFS Read-Write Test\n" + + "\n" + + "Usage: localFile dfsDir\n" + + "\n" + + "localFile - (string) local file to use in test\n" + + "dfsDir - (string) DFS directory for read/write tests\n" + + println(usage) + } + + private def parseArgs(args: Array[String]): Unit = { + if (args.length != NPARAMS) { + printUsage() + System.exit(1) + } + + var i = 0 + + localFilePath = new File(args(i)) + if (!localFilePath.exists) { + System.err.println("Given path (" + args(i) + ") does not exist.\n") + printUsage() + System.exit(1) + } + + if (!localFilePath.isFile) { + System.err.println("Given path (" + args(i) + ") is not a file.\n") + printUsage() + System.exit(1) + } + + i += 1 + dfsDirPath = args(i) + } + + def runLocalWordCount(fileContents: List[String]): Int = { + fileContents.flatMap(_.split(" ")) + .flatMap(_.split("\t")) + .filter(_.size > 0) + .groupBy(w => w) + .mapValues(_.size) + .values + .sum + } + + def main(args: Array[String]): Unit = { + parseArgs(args) + + println("Performing local word count") + val fileContents = readFile(localFilePath.toString()) + val localWordCount = runLocalWordCount(fileContents) + + println("Creating SparkConf") + val conf = new SparkConf().setAppName("DFS Read Write Test") + + println("Creating SparkContext") + val sc = new SparkContext(conf) + + println("Writing local file to DFS") + val dfsFilename = dfsDirPath + "/dfs_read_write_test" + val fileRDD = sc.parallelize(fileContents) + fileRDD.saveAsTextFile(dfsFilename) + + println("Reading file from DFS and running Word Count") + val readFileRDD = sc.textFile(dfsFilename) + + val dfsWordCount = readFileRDD + .flatMap(_.split(" ")) + .flatMap(_.split("\t")) + .filter(_.size > 0) + .map(w => (w, 1)) + .countByKey() + .values + .sum + + sc.stop() + + if (localWordCount == dfsWordCount) { + println(s"Success! Local Word Count ($localWordCount) " + + s"and DFS Word Count ($dfsWordCount) agree.") + } else { + println(s"Failure! Local Word Count ($localWordCount) " + + s"and DFS Word Count ($dfsWordCount) disagree.") + } + + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 3ee456edbe01e..7b8cc21ed8982 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -130,6 +130,8 @@ private class MyLogisticRegression(override val uid: String) // Create a model, and return it. new MyLogisticRegressionModel(uid, weights).setParent(this) } + + override def copy(extra: ParamMap): MyLogisticRegression = defaultCopy(extra) } /** diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala index b12f833ce94c8..3cf193f353fbc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala @@ -145,9 +145,9 @@ object LogisticRegressionExample { val elapsedTime = (System.nanoTime() - startTime) / 1e9 println(s"Training time: $elapsedTime seconds") - val lirModel = pipelineModel.stages.last.asInstanceOf[LogisticRegressionModel] + val lorModel = pipelineModel.stages.last.asInstanceOf[LogisticRegressionModel] // Print the weights and intercept for logistic regression. - println(s"Weights: ${lirModel.weights} Intercept: ${lirModel.intercept}") + println(s"Weights: ${lorModel.weights} Intercept: ${lorModel.intercept}") println("Training data results:") DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, "indexedLabel") diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 7a7dccc3d0922..0664cfb2021e1 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -35,10 +35,6 @@ http://spark.apache.org/ - - org.apache.commons - commons-lang3 - org.apache.flume flume-ng-sdk diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala index dc2a4ab138e18..719fca0938b3a 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -16,13 +16,13 @@ */ package org.apache.spark.streaming.flume.sink +import java.util.UUID import java.util.concurrent.{CountDownLatch, Executors} import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable import org.apache.flume.Channel -import org.apache.commons.lang3.RandomStringUtils /** * Class that implements the SparkFlumeProtocol, that is used by the Avro Netty Server to process @@ -53,7 +53,7 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha // Since the new txn may not have the same sequence number we must guard against accidentally // committing a new transaction. To reduce the probability of that happening a random string is // prepended to the sequence number. Does not change for life of sink - private val seqBase = RandomStringUtils.randomAlphanumeric(8) + private val seqBase = UUID.randomUUID().toString.substring(0, 8) private val seqCounter = new AtomicLong(0) // Protected by `sequenceNumberToProcessor` diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 060c2f23eded8..876456c964770 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -120,8 +120,7 @@ class DirectKafkaInputDStream[ context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler) // Report the record number of this batch interval to InputInfoTracker. - val numRecords = rdd.offsetRanges.map(r => r.untilOffset - r.fromOffset).sum - val inputInfo = InputInfo(id, numRecords) + val inputInfo = InputInfo(id, rdd.count) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) currentOffsets = untilOffsets.map(kv => kv._1 -> kv._2.offset) @@ -153,10 +152,7 @@ class DirectKafkaInputDStream[ override def restore() { // this is assuming that the topics don't change during execution, which is true currently val topics = fromOffsets.keySet - val leaders = kc.findLeaders(topics).fold( - errs => throw new SparkException(errs.mkString("\n")), - ok => ok - ) + val leaders = KafkaCluster.checkErrors(kc.findLeaders(topics)) batchForTime.toSeq.sortBy(_._1)(Time.ordering).foreach { case (t, b) => logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}") diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index 65d51d87f8486..3e6b937af57b0 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -360,6 +360,14 @@ private[spark] object KafkaCluster { type Err = ArrayBuffer[Throwable] + /** If the result is right, return it, otherwise throw SparkException */ + def checkErrors[T](result: Either[Err, T]): T = { + result.fold( + errs => throw new SparkException(errs.mkString("\n")), + ok => ok + ) + } + private[spark] case class LeaderOffset(host: String, port: Int, offset: Long) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index a1b4a12e5d6a0..c5cd2154772ac 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -17,9 +17,11 @@ package org.apache.spark.streaming.kafka +import scala.collection.mutable.ArrayBuffer import scala.reflect.{classTag, ClassTag} import org.apache.spark.{Logging, Partition, SparkContext, SparkException, TaskContext} +import org.apache.spark.partial.{PartialResult, BoundedDouble} import org.apache.spark.rdd.RDD import org.apache.spark.util.NextIterator @@ -60,6 +62,48 @@ class KafkaRDD[ }.toArray } + override def count(): Long = offsetRanges.map(_.count).sum + + override def countApprox( + timeout: Long, + confidence: Double = 0.95 + ): PartialResult[BoundedDouble] = { + val c = count + new PartialResult(new BoundedDouble(c, 1.0, c, c), true) + } + + override def isEmpty(): Boolean = count == 0L + + override def take(num: Int): Array[R] = { + val nonEmptyPartitions = this.partitions + .map(_.asInstanceOf[KafkaRDDPartition]) + .filter(_.count > 0) + + if (num < 1 || nonEmptyPartitions.size < 1) { + return new Array[R](0) + } + + // Determine in advance how many messages need to be taken from each partition + val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) => + val remain = num - result.values.sum + if (remain > 0) { + val taken = Math.min(remain, part.count) + result + (part.index -> taken.toInt) + } else { + result + } + } + + val buf = new ArrayBuffer[R] + val res = context.runJob( + this, + (tc: TaskContext, it: Iterator[R]) => it.take(parts(tc.partitionId)).toArray, + parts.keys.toArray, + allowLocal = true) + res.foreach(buf ++= _) + buf.toArray + } + override def getPreferredLocations(thePart: Partition): Seq[String] = { val part = thePart.asInstanceOf[KafkaRDDPartition] // TODO is additional hostname resolution necessary here diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala index a842a6f17766f..a660d2a00c35d 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala @@ -35,4 +35,7 @@ class KafkaRDDPartition( val untilOffset: Long, val host: String, val port: Int -) extends Partition +) extends Partition { + /** Number of messages this partition refers to */ + def count(): Long = untilOffset - fromOffset +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 0b8a391a2c569..0e33362d34acd 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -158,15 +158,31 @@ object KafkaUtils { /** get leaders for the given offset ranges, or throw an exception */ private def leadersForRanges( - kafkaParams: Map[String, String], + kc: KafkaCluster, offsetRanges: Array[OffsetRange]): Map[TopicAndPartition, (String, Int)] = { - val kc = new KafkaCluster(kafkaParams) val topics = offsetRanges.map(o => TopicAndPartition(o.topic, o.partition)).toSet - val leaders = kc.findLeaders(topics).fold( - errs => throw new SparkException(errs.mkString("\n")), - ok => ok - ) - leaders + val leaders = kc.findLeaders(topics) + KafkaCluster.checkErrors(leaders) + } + + /** Make sure offsets are available in kafka, or throw an exception */ + private def checkOffsets( + kc: KafkaCluster, + offsetRanges: Array[OffsetRange]): Unit = { + val topics = offsetRanges.map(_.topicAndPartition).toSet + val result = for { + low <- kc.getEarliestLeaderOffsets(topics).right + high <- kc.getLatestLeaderOffsets(topics).right + } yield { + offsetRanges.filterNot { o => + low(o.topicAndPartition).offset <= o.fromOffset && + o.untilOffset <= high(o.topicAndPartition).offset + } + } + val badRanges = KafkaCluster.checkErrors(result) + if (!badRanges.isEmpty) { + throw new SparkException("Offsets not available on leader: " + badRanges.mkString(",")) + } } /** @@ -191,7 +207,9 @@ object KafkaUtils { offsetRanges: Array[OffsetRange] ): RDD[(K, V)] = sc.withScope { val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message) - val leaders = leadersForRanges(kafkaParams, offsetRanges) + val kc = new KafkaCluster(kafkaParams) + val leaders = leadersForRanges(kc, offsetRanges) + checkOffsets(kc, offsetRanges) new KafkaRDD[K, V, KD, VD, (K, V)](sc, kafkaParams, offsetRanges, leaders, messageHandler) } @@ -225,8 +243,9 @@ object KafkaUtils { leaders: Map[TopicAndPartition, Broker], messageHandler: MessageAndMetadata[K, V] => R ): RDD[R] = sc.withScope { + val kc = new KafkaCluster(kafkaParams) val leaderMap = if (leaders.isEmpty) { - leadersForRanges(kafkaParams, offsetRanges) + leadersForRanges(kc, offsetRanges) } else { // This could be avoided by refactoring KafkaRDD.leaders and KafkaCluster to use Broker leaders.map { @@ -234,6 +253,7 @@ object KafkaUtils { }.toMap } val cleanedHandler = sc.clean(messageHandler) + checkOffsets(kc, offsetRanges) new KafkaRDD[K, V, KD, VD, R](sc, kafkaParams, offsetRanges, leaderMap, cleanedHandler) } @@ -399,7 +419,7 @@ object KafkaUtils { val kc = new KafkaCluster(kafkaParams) val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase) - (for { + val result = for { topicPartitions <- kc.getPartitions(topics).right leaderOffsets <- (if (reset == Some("smallest")) { kc.getEarliestLeaderOffsets(topicPartitions) @@ -412,10 +432,8 @@ object KafkaUtils { } new DirectKafkaInputDStream[K, V, KD, VD, (K, V)]( ssc, kafkaParams, fromOffsets, messageHandler) - }).fold( - errs => throw new SparkException(errs.mkString("\n")), - ok => ok - ) + } + KafkaCluster.checkErrors(result) } /** diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala index 9c3dfeb8f5928..2675042666304 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala @@ -55,6 +55,12 @@ final class OffsetRange private( val untilOffset: Long) extends Serializable { import OffsetRange.OffsetRangeTuple + /** Kafka TopicAndPartition object, for convenience */ + def topicAndPartition(): TopicAndPartition = TopicAndPartition(topic, partition) + + /** Number of messages this OffsetRange refers to */ + def count(): Long = untilOffset - fromOffset + override def equals(obj: Any): Boolean = obj match { case that: OffsetRange => this.topic == that.topic && diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java index 4c1d6a03eb2b8..02cd24a35906f 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java @@ -18,9 +18,8 @@ package org.apache.spark.streaming.kafka; import java.io.Serializable; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Arrays; +import java.util.*; +import java.util.concurrent.atomic.AtomicReference; import scala.Tuple2; @@ -34,6 +33,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.Function; import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.api.java.JavaDStream; @@ -67,8 +67,10 @@ public void tearDown() { @Test public void testKafkaStream() throws InterruptedException { - String topic1 = "topic1"; - String topic2 = "topic2"; + final String topic1 = "topic1"; + final String topic2 = "topic2"; + // hold a reference to the current offset ranges, so it can be used downstream + final AtomicReference offsetRanges = new AtomicReference(); String[] topic1data = createTopicAndSendData(topic1); String[] topic2data = createTopicAndSendData(topic2); @@ -89,6 +91,17 @@ public void testKafkaStream() throws InterruptedException { StringDecoder.class, kafkaParams, topicToSet(topic1) + ).transformToPair( + // Make sure you can get offset ranges from the rdd + new Function, JavaPairRDD>() { + @Override + public JavaPairRDD call(JavaPairRDD rdd) throws Exception { + OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + offsetRanges.set(offsets); + Assert.assertEquals(offsets[0].topic(), topic1); + return rdd; + } + } ).map( new Function, String>() { @Override @@ -116,12 +129,17 @@ public String call(MessageAndMetadata msgAndMd) throws Exception ); JavaDStream unifiedStream = stream1.union(stream2); - final HashSet result = new HashSet(); + final Set result = Collections.synchronizedSet(new HashSet()); unifiedStream.foreachRDD( new Function, Void>() { @Override public Void call(JavaRDD rdd) throws Exception { result.addAll(rdd.collect()); + for (OffsetRange o : offsetRanges.get()) { + System.out.println( + o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset() + ); + } return null; } } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index 540f4ceabab47..e4c659215b767 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -18,9 +18,7 @@ package org.apache.spark.streaming.kafka; import java.io.Serializable; -import java.util.HashMap; -import java.util.List; -import java.util.Random; +import java.util.*; import scala.Tuple2; @@ -94,7 +92,7 @@ public void testKafkaStream() throws InterruptedException { topics, StorageLevel.MEMORY_ONLY_SER()); - final HashMap result = new HashMap(); + final Map result = Collections.synchronizedMap(new HashMap()); JavaDStream words = stream.map( new Function, String>() { diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 47bbfb605850a..8e1715f6dbb95 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -99,15 +99,24 @@ class DirectKafkaStreamSuite ssc, kafkaParams, topics) } - val allReceived = new ArrayBuffer[(String, String)] - - stream.foreachRDD { rdd => - // Get the offset ranges in the RDD - val offsets = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + val allReceived = + new ArrayBuffer[(String, String)] with mutable.SynchronizedBuffer[(String, String)] + + // hold a reference to the current offset ranges, so it can be used downstream + var offsetRanges = Array[OffsetRange]() + + stream.transform { rdd => + // Get the offset ranges in the RDD + offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + rdd + }.foreachRDD { rdd => + for (o <- offsetRanges) { + println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + } val collected = rdd.mapPartitionsWithIndex { (i, iter) => // For each partition, get size of the range in the partition, // and the number of items in the partition - val off = offsets(i) + val off = offsetRanges(i) val all = iter.toSeq val partSize = all.size val rangeSize = off.untilOffset - off.fromOffset @@ -162,7 +171,7 @@ class DirectKafkaStreamSuite "Start offset not from latest" ) - val collectedData = new mutable.ArrayBuffer[String]() + val collectedData = new mutable.ArrayBuffer[String]() with mutable.SynchronizedBuffer[String] stream.map { _._2 }.foreachRDD { rdd => collectedData ++= rdd.collect() } ssc.start() val newData = Map("b" -> 10) @@ -208,7 +217,7 @@ class DirectKafkaStreamSuite "Start offset not from latest" ) - val collectedData = new mutable.ArrayBuffer[String]() + val collectedData = new mutable.ArrayBuffer[String]() with mutable.SynchronizedBuffer[String] stream.foreachRDD { rdd => collectedData ++= rdd.collect() } ssc.start() val newData = Map("b" -> 10) @@ -324,7 +333,8 @@ class DirectKafkaStreamSuite ssc, kafkaParams, Set(topic)) } - val allReceived = new ArrayBuffer[(String, String)] + val allReceived = + new ArrayBuffer[(String, String)] with mutable.SynchronizedBuffer[(String, String)] stream.foreachRDD { rdd => allReceived ++= rdd.collect() } ssc.start() @@ -350,8 +360,8 @@ class DirectKafkaStreamSuite } object DirectKafkaStreamSuite { - val collectedData = new mutable.ArrayBuffer[String]() - var total = -1L + val collectedData = new mutable.ArrayBuffer[String]() with mutable.SynchronizedBuffer[String] + @volatile var total = -1L class InputInfoCollector extends StreamingListener { val numRecordsSubmitted = new AtomicLong(0L) diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala index d5baf5fd89994..f52a738afd65b 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -55,8 +55,8 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { test("basic usage") { val topic = s"topicbasic-${Random.nextInt}" kafkaTestUtils.createTopic(topic) - val messages = Set("the", "quick", "brown", "fox") - kafkaTestUtils.sendMessages(topic, messages.toArray) + val messages = Array("the", "quick", "brown", "fox") + kafkaTestUtils.sendMessages(topic, messages) val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress, "group.id" -> s"test-consumer-${Random.nextInt}") @@ -67,7 +67,27 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { sc, kafkaParams, offsetRanges) val received = rdd.map(_._2).collect.toSet - assert(received === messages) + assert(received === messages.toSet) + + // size-related method optimizations return sane results + assert(rdd.count === messages.size) + assert(rdd.countApprox(0).getFinalValue.mean === messages.size) + assert(!rdd.isEmpty) + assert(rdd.take(1).size === 1) + assert(rdd.take(1).head._2 === messages.head) + assert(rdd.take(messages.size + 10).size === messages.size) + + val emptyRdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( + sc, kafkaParams, Array(OffsetRange(topic, 0, 0, 0))) + + assert(emptyRdd.isEmpty) + + // invalid offset ranges throw exceptions + val badRanges = Array(OffsetRange(topic, 0, 0, messages.size + 1)) + intercept[SparkException] { + KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( + sc, kafkaParams, badRanges) + } } test("iterator boundary conditions") { diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index 8ee2cc660f849..797b07f80d8ee 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -65,7 +65,7 @@ class KafkaStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( ssc, kafkaParams, Map(topic -> 1), StorageLevel.MEMORY_ONLY) - val result = new mutable.HashMap[String, Long]() + val result = new mutable.HashMap[String, Long]() with mutable.SynchronizedMap[String, Long] stream.map(_._2).countByValue().foreachRDD { r => val ret = r.collect() ret.toMap.foreach { kv => @@ -77,10 +77,7 @@ class KafkaStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter ssc.start() eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { - assert(sent.size === result.size) - sent.keys.foreach { k => - assert(sent(k) === result(k).toInt) - } + assert(sent === result) } } } diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index c6f60bc907438..c242e7a57b9ab 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -66,7 +66,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/launcher/pom.xml b/launcher/pom.xml index 48dd0d5f9106b..2fd768d8119c4 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -49,7 +49,7 @@ org.mockito - mockito-all + mockito-core test @@ -68,12 +68,6 @@ org.apache.hadoop hadoop-client test - - - org.codehaus.jackson - jackson-mapper-asl - - diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index d80abf2a8676e..de85720febf23 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -93,6 +93,9 @@ public List buildCommand(Map env) throws IOException { toolsDir.getAbsolutePath(), className); javaOptsKeys.add("SPARK_JAVA_OPTS"); + } else { + javaOptsKeys.add("SPARK_JAVA_OPTS"); + memKey = "SPARK_DRIVER_MEMORY"; } List cmd = buildJavaCommand(extraClassPath); diff --git a/mllib/pom.xml b/mllib/pom.xml index b16058ddc203a..a5db14407b4fc 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -106,7 +106,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index e9a5d7c0e7988..57e416591de69 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -78,7 +78,5 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage { paramMaps.map(fit(dataset, _)) } - override def copy(extra: ParamMap): Estimator[M] = { - super.copy(extra).asInstanceOf[Estimator[M]] - } + override def copy(extra: ParamMap): Estimator[M] } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala index 186bf7ae7a2f6..252acc156583f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala @@ -45,8 +45,5 @@ abstract class Model[M <: Model[M]] extends Transformer { /** Indicates whether this [[Model]] has a corresponding parent. */ def hasParent: Boolean = parent != null - override def copy(extra: ParamMap): M = { - // The default implementation of Params.copy doesn't work for models. - throw new NotImplementedError(s"${this.getClass} doesn't implement copy(extra: ParamMap)") - } + override def copy(extra: ParamMap): M } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 11a4722722ea1..a1f3851d804ff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -17,6 +17,9 @@ package org.apache.spark.ml +import java.{util => ju} + +import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import org.apache.spark.Logging @@ -63,9 +66,7 @@ abstract class PipelineStage extends Params with Logging { outputSchema } - override def copy(extra: ParamMap): PipelineStage = { - super.copy(extra).asInstanceOf[PipelineStage] - } + override def copy(extra: ParamMap): PipelineStage } /** @@ -175,6 +176,11 @@ class PipelineModel private[ml] ( val stages: Array[Transformer]) extends Model[PipelineModel] with Logging { + /** A Java/Python-friendly auxiliary constructor. */ + private[ml] def this(uid: String, stages: ju.List[Transformer]) = { + this(uid, stages.asScala.toArray) + } + override def validateParams(): Unit = { super.validateParams() stages.foreach(_.validateParams()) @@ -190,6 +196,6 @@ class PipelineModel private[ml] ( } override def copy(extra: ParamMap): PipelineModel = { - new PipelineModel(uid, stages) + new PipelineModel(uid, stages.map(_.copy(extra))) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index e752b81a14282..333b42711ec52 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -90,9 +90,7 @@ abstract class Predictor[ copyValues(train(dataset).setParent(this)) } - override def copy(extra: ParamMap): Learner = { - super.copy(extra).asInstanceOf[Learner] - } + override def copy(extra: ParamMap): Learner /** * Train a model using the given dataset and parameters. @@ -124,9 +122,7 @@ abstract class Predictor[ */ protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = { dataset.select($(labelCol), $(featuresCol)) - .map { case Row(label: Double, features: Vector) => - LabeledPoint(label, features) - } + .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } } } @@ -173,7 +169,10 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) if ($(predictionCol).nonEmpty) { - dataset.withColumn($(predictionCol), callUDF(predict _, DoubleType, col($(featuresCol)))) + val predictUDF = udf { (features: Any) => + predict(features.asInstanceOf[FeaturesType]) + } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } else { this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + " since no output columns were set.") diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index f07f733a5ddb5..3c7bcf7590e6d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -67,9 +67,7 @@ abstract class Transformer extends PipelineStage { */ def transform(dataset: DataFrame): DataFrame - override def copy(extra: ParamMap): Transformer = { - super.copy(extra).asInstanceOf[Transformer] - } + override def copy(extra: ParamMap): Transformer } /** @@ -120,4 +118,6 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] dataset.withColumn($(outputCol), callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol)))) } + + override def copy(extra: ParamMap): T = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala index ce43a450daad0..e479f169021d8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.attribute import scala.annotation.varargs import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, StructField} +import org.apache.spark.sql.types.{DoubleType, NumericType, Metadata, MetadataBuilder, StructField} /** * :: DeveloperApi :: @@ -127,7 +127,7 @@ private[attribute] trait AttributeFactory { * Creates an [[Attribute]] from a [[StructField]] instance. */ def fromStructField(field: StructField): Attribute = { - require(field.dataType == DoubleType) + require(field.dataType.isInstanceOf[NumericType]) val metadata = field.metadata val mlAttr = AttributeKeys.ML_ATTR if (metadata.contains(mlAttr)) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 263d580fe2dd3..85c097bc64a4f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} import org.apache.spark.ml.param.shared.HasRawPredictionCol import org.apache.spark.ml.util.SchemaUtils @@ -101,15 +102,20 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur var outputData = dataset var numColsOutput = 0 if (getRawPredictionCol != "") { - outputData = outputData.withColumn(getRawPredictionCol, - callUDF(predictRaw _, new VectorUDT, col(getFeaturesCol))) + val predictRawUDF = udf { (features: Any) => + predictRaw(features.asInstanceOf[FeaturesType]) + } + outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol))) numColsOutput += 1 } if (getPredictionCol != "") { val predUDF = if (getRawPredictionCol != "") { - callUDF(raw2prediction _, DoubleType, col(getRawPredictionCol)) + udf(raw2prediction _).apply(col(getRawPredictionCol)) } else { - callUDF(predict _, DoubleType, col(getFeaturesCol)) + val predictUDF = udf { (features: Any) => + predict(features.asInstanceOf[FeaturesType]) + } + predictUDF(col(getFeaturesCol)) } outputData = outputData.withColumn(getPredictionCol, predUDF) numColsOutput += 1 diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 8030e0728a56c..2dc1824964a42 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -86,6 +86,8 @@ final class DecisionTreeClassifier(override val uid: String) super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity, subsamplingRate = 1.0) } + + override def copy(extra: ParamMap): DecisionTreeClassifier = defaultCopy(extra) } @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 62f4b51f770e9..554e3b8e052b2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -141,6 +141,8 @@ final class GBTClassifier(override val uid: String) val oldModel = oldGBT.run(oldDataset) GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures) } + + override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra) } @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index f136bcee9cf2b..2e6eedd45ab07 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -220,6 +220,8 @@ class LogisticRegression(override val uid: String) new LogisticRegressionModel(uid, weights.compressed, intercept) } + + override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 825f9ed1b54b2..ea757c5e40c76 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -24,7 +24,7 @@ import scala.language.existentials import org.apache.spark.annotation.Experimental import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.param.Param +import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{DataFrame, Row} @@ -88,9 +88,9 @@ final class OneVsRestModel private[ml] ( // add an accumulator column to store predictions of all the models val accColName = "mbc$acc" + UUID.randomUUID().toString - val init: () => Map[Int, Double] = () => {Map()} + val initUDF = udf { () => Map[Int, Double]() } val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false) - val newDataset = dataset.withColumn(accColName, callUDF(init, mapType)) + val newDataset = dataset.withColumn(accColName, initUDF()) // persist if underlying dataset is not persistent. val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE @@ -106,13 +106,12 @@ final class OneVsRestModel private[ml] ( // add temporary column to store intermediate scores and update val tmpColName = "mbc$tmp" + UUID.randomUUID().toString - val update: (Map[Int, Double], Vector) => Map[Int, Double] = - (predictions: Map[Int, Double], prediction: Vector) => { - predictions + ((index, prediction(1))) - } - val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol)) + val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) => + predictions + ((index, prediction(1))) + } val transformedDataset = model.transform(df).select(columns : _*) - val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf) + val updatedDataset = transformedDataset + .withColumn(tmpColName, updateUDF(col(accColName), col(rawPredictionCol))) val newColumns = origCols ++ List(col(tmpColName)) // switch out the intermediate column with the accumulator column @@ -124,15 +123,21 @@ final class OneVsRestModel private[ml] ( } // output the index of the classifier with highest confidence as prediction - val label: Map[Int, Double] => Double = (predictions: Map[Int, Double]) => { + val labelUDF = udf { (predictions: Map[Int, Double]) => predictions.maxBy(_._2)._1.toDouble } // output label and label metadata as prediction - val labelUdf = callUDF(label, DoubleType, col(accColName)) - aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata)) + aggregatedDataset + .withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata)) .drop(accColName) } + + override def copy(extra: ParamMap): OneVsRestModel = { + val copied = new OneVsRestModel( + uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]])) + copyValues(copied, extra) + } } /** @@ -179,17 +184,15 @@ final class OneVsRest(override val uid: String) // create k columns, one for each binary classifier. val models = Range(0, numClasses).par.map { index => - - val label: Double => Double = (label: Double) => { + val labelUDF = udf { (label: Double) => if (label.toInt == index) 1.0 else 0.0 } // generate new label metadata for the binary problem. // TODO: use when ... otherwise after SPARK-7321 is merged - val labelUDF = callUDF(label, DoubleType, col($(labelCol))) val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata() val labelColName = "mc2b$" + index - val labelUDFWithNewMeta = labelUDF.as(labelColName, newLabelMeta) + val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta) val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta) val classifier = getClassifier classifier.fit(trainingDataset, classifier.labelCol -> labelColName) @@ -209,4 +212,12 @@ final class OneVsRest(override val uid: String) val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this) copyValues(model) } + + override def copy(extra: ParamMap): OneVsRest = { + val copied = defaultCopy(extra).asInstanceOf[OneVsRest] + if (isDefined(classifier)) { + copied.setClassifier($(classifier).copy(extra)) + } + copied + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 330ae2938f4e0..38e832372698c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -98,26 +98,34 @@ private[spark] abstract class ProbabilisticClassificationModel[ var outputData = dataset var numColsOutput = 0 if ($(rawPredictionCol).nonEmpty) { - outputData = outputData.withColumn(getRawPredictionCol, - callUDF(predictRaw _, new VectorUDT, col(getFeaturesCol))) + val predictRawUDF = udf { (features: Any) => + predictRaw(features.asInstanceOf[FeaturesType]) + } + outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol))) numColsOutput += 1 } if ($(probabilityCol).nonEmpty) { val probUDF = if ($(rawPredictionCol).nonEmpty) { - callUDF(raw2probability _, new VectorUDT, col($(rawPredictionCol))) + udf(raw2probability _).apply(col($(rawPredictionCol))) } else { - callUDF(predictProbability _, new VectorUDT, col($(featuresCol))) + val probabilityUDF = udf { (features: Any) => + predictProbability(features.asInstanceOf[FeaturesType]) + } + probabilityUDF(col($(featuresCol))) } outputData = outputData.withColumn($(probabilityCol), probUDF) numColsOutput += 1 } if ($(predictionCol).nonEmpty) { val predUDF = if ($(rawPredictionCol).nonEmpty) { - callUDF(raw2prediction _, DoubleType, col($(rawPredictionCol))) + udf(raw2prediction _).apply(col($(rawPredictionCol))) } else if ($(probabilityCol).nonEmpty) { - callUDF(probability2prediction _, DoubleType, col($(probabilityCol))) + udf(probability2prediction _).apply(col($(probabilityCol))) } else { - callUDF(predict _, DoubleType, col($(featuresCol))) + val predictUDF = udf { (features: Any) => + predict(features.asInstanceOf[FeaturesType]) + } + predictUDF(col($(featuresCol))) } outputData = outputData.withColumn($(predictionCol), predUDF) numColsOutput += 1 diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 852a67e066322..d3c67494a31e4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -97,6 +97,8 @@ final class RandomForestClassifier(override val uid: String) oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt) RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures) } + + override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra) } @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index f695ddaeefc72..4a82b77f0edcb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -79,4 +79,6 @@ class BinaryClassificationEvaluator(override val uid: String) metrics.unpersist() metric } + + override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala index 61e937e693699..e56c946a063e8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala @@ -46,7 +46,5 @@ abstract class Evaluator extends Params { */ def evaluate(dataset: DataFrame): Double - override def copy(extra: ParamMap): Evaluator = { - super.copy(extra).asInstanceOf[Evaluator] - } + override def copy(extra: ParamMap): Evaluator } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index abb1b35bedea5..01c000b47514c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.param.{Param, ParamValidators} +import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.RegressionMetrics @@ -37,6 +37,10 @@ final class RegressionEvaluator(override val uid: String) /** * param for metric name in evaluation (supports `"rmse"` (default), `"mse"`, `"r2"`, and `"mae"`) + * + * Because we will maximize evaluation value (ref: `CrossValidator`), + * when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`), + * we take and output the negative of this metric. * @group param */ val metricName: Param[String] = { @@ -70,14 +74,16 @@ final class RegressionEvaluator(override val uid: String) val metrics = new RegressionMetrics(predictionAndLabels) val metric = $(metricName) match { case "rmse" => - metrics.rootMeanSquaredError + -metrics.rootMeanSquaredError case "mse" => - metrics.meanSquaredError + -metrics.meanSquaredError case "r2" => metrics.r2 case "mae" => - metrics.meanAbsoluteError + -metrics.meanAbsoluteError } metric } + + override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index b06122d733853..46314854d5e3a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -83,4 +83,6 @@ final class Binarizer(override val uid: String) val outputFields = inputFields :+ attr.toStructField() StructType(outputFields) } + + override def copy(extra: ParamMap): Binarizer = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index a3d1f6f65ccaf..67e4785bc3553 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -89,6 +89,8 @@ final class Bucketizer(override val uid: String) SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) SchemaUtils.appendColumn(schema, prepOutputField(schema)) } + + override def copy(extra: ParamMap): Bucketizer = defaultCopy(extra) } private[feature] object Bucketizer { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index 1e758cb775de7..a359cb8f37ec3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.Param +import org.apache.spark.ml.param.{ParamMap, Param} import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index f936aef80f8af..319d23e46cef4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.AttributeGroup -import org.apache.spark.ml.param.{IntParam, ParamValidators} +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.feature @@ -74,4 +74,6 @@ class HashingTF(override val uid: String) extends Transformer with HasInputCol w val attrGroup = new AttributeGroup($(outputCol), $(numFeatures)) SchemaUtils.appendColumn(schema, attrGroup.toStructField()) } + + override def copy(extra: ParamMap): HashingTF = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 376b84530cd57..ecde80810580c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -45,9 +45,6 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol /** @group getParam */ def getMinDocFreq: Int = $(minDocFreq) - /** @group setParam */ - def setMinDocFreq(value: Int): this.type = set(minDocFreq, value) - /** * Validate and transform the input schema. */ @@ -72,6 +69,9 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + def setMinDocFreq(value: Int): this.type = set(minDocFreq, value) + override def fit(dataset: DataFrame): IDFModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } @@ -82,6 +82,8 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): IDF = defaultCopy(extra) } /** @@ -109,4 +111,9 @@ class IDFModel private[ml] ( override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): IDFModel = { + val copied = new IDFModel(uid, idfModel) + copyValues(copied, extra) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala new file mode 100644 index 0000000000000..8de10eb51f923 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.types.{ArrayType, DataType, StringType} + +/** + * :: Experimental :: + * A feature transformer that converts the input array of strings into an array of n-grams. Null + * values in the input array are ignored. + * It returns an array of n-grams where each n-gram is represented by a space-separated string of + * words. + * + * When the input is empty, an empty array is returned. + * When the input array length is less than n (number of elements per n-gram), no n-grams are + * returned. + */ +@Experimental +class NGram(override val uid: String) + extends UnaryTransformer[Seq[String], Seq[String], NGram] { + + def this() = this(Identifiable.randomUID("ngram")) + + /** + * Minimum n-gram length, >= 1. + * Default: 2, bigram features + * @group param + */ + val n: IntParam = new IntParam(this, "n", "number elements per n-gram (>=1)", + ParamValidators.gtEq(1)) + + /** @group setParam */ + def setN(value: Int): this.type = set(n, value) + + /** @group getParam */ + def getN: Int = $(n) + + setDefault(n -> 2) + + override protected def createTransformFunc: Seq[String] => Seq[String] = { + _.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq + } + + override protected def validateInputType(inputType: DataType): Unit = { + require(inputType.sameType(ArrayType(StringType)), + s"Input type must be ArrayType(StringType) but got $inputType.") + } + + override protected def outputDataType: DataType = new ArrayType(StringType, false) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 8f34878c8d329..3825942795645 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -165,4 +165,6 @@ class OneHotEncoder(override val uid: String) extends Transformer dataset.select(col("*"), encode(col(inputColName).cast(DoubleType)).as(outputColName, metadata)) } + + override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 442e95820217a..d85e468562d4a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{IntParam, ParamValidators} +import org.apache.spark.ml.param.{ParamMap, IntParam, ParamValidators} import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.types.DataType @@ -61,6 +61,8 @@ class PolynomialExpansion(override val uid: String) } override protected def outputDataType: DataType = new VectorUDT() + + override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index b0fd06d84fdb3..ca3c1cfb56b7f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -92,6 +92,8 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) StructType(outputFields) } + + override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra) } /** @@ -125,4 +127,9 @@ class StandardScalerModel private[ml] ( val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) StructType(outputFields) } + + override def copy(extra: ParamMap): StandardScalerModel = { + val copied = new StandardScalerModel(uid, scaler) + copyValues(copied, extra) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index f4e250757560a..bf7be363b8224 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -83,6 +83,8 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra) } /** @@ -144,4 +146,9 @@ class StringIndexerModel private[ml] ( schema } } + + override def copy(extra: ParamMap): StringIndexerModel = { + val copied = new StringIndexerModel(uid, labels) + copyValues(copied, extra) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 21c15b6c33f6c..5f9f57a2ebcfa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -43,6 +43,8 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S } override protected def outputDataType: DataType = new ArrayType(StringType, false) + + override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra) } /** @@ -112,4 +114,6 @@ class RegexTokenizer(override val uid: String) } override protected def outputDataType: DataType = new ArrayType(StringType, false) + + override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 229ee27ec5942..9f83c2ee16178 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute} +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} @@ -117,6 +118,8 @@ class VectorAssembler(override val uid: String) } StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false)) } + + override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra) } private object VectorAssembler { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 1d0f23b4fb3db..c73bdccdef5fa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -25,12 +25,12 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.param.{IntParam, ParamValidators, Params} +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, Params} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.functions.callUDF +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.util.collection.OpenHashSet @@ -131,6 +131,8 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod SchemaUtils.checkColumnType(schema, $(inputCol), dataType) SchemaUtils.appendColumn(schema, $(outputCol), dataType) } + + override def copy(extra: ParamMap): VectorIndexer = defaultCopy(extra) } private object VectorIndexer { @@ -337,7 +339,8 @@ class VectorIndexerModel private[ml] ( override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) val newField = prepOutputField(dataset.schema) - val newCol = callUDF(transformFunc, new VectorUDT, dataset($(inputCol))) + val transformUDF = udf { (vector: Vector) => transformFunc(vector) } + val newCol = transformUDF(dataset($(inputCol))) dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata)) } @@ -399,4 +402,9 @@ class VectorIndexerModel private[ml] ( val newAttributeGroup = new AttributeGroup($(outputCol), featureAttributes) newAttributeGroup.toStructField() } + + override def copy(extra: ParamMap): VectorIndexerModel = { + val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps) + copyValues(copied, extra) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 36f19509f0cfb..6ea6590956300 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -132,6 +132,8 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): Word2Vec = defaultCopy(extra) } /** @@ -180,4 +182,9 @@ class Word2VecModel private[ml] ( override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): Word2VecModel = { + val copied = new Word2VecModel(uid, wordVectors) + copyValues(copied, extra) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index ba94d6a3a80a9..50c0d855066f8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -297,7 +297,7 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array /** * :: Experimental :: - * A param amd its value. + * A param and its value. */ @Experimental case class ParamPair[T](param: Param[T], value: T) { @@ -492,13 +492,20 @@ trait Params extends Identifiable with Serializable { /** * Creates a copy of this instance with the same UID and some extra params. - * The default implementation tries to create a new instance with the same UID. + * Subclasses should implement this method and set the return type properly. + * + * @see [[defaultCopy()]] + */ + def copy(extra: ParamMap): Params + + /** + * Default implementation of copy with extra params. + * It tries to create a new instance with the same UID. * Then it copies the embedded and extra parameters over and returns the new instance. - * Subclasses should override this method if the default approach is not sufficient. */ - def copy(extra: ParamMap): Params = { + protected final def defaultCopy[T <: Params](extra: ParamMap): T = { val that = this.getClass.getConstructor(classOf[String]).newInstance(uid) - copyValues(that, extra) + copyValues(that, extra).asInstanceOf[T] } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 8ffbcf0d8bc71..b0a6af171c01f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -53,6 +53,9 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)", isValid = "ParamValidators.gtEq(1)"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), + ParamDesc[Boolean]("standardization", "whether to standardize the training features" + + " prior to fitting the model sequence. Note that the coefficients of models are" + + " always returned on the original scale.", Some("true")), ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")), ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." + " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index a0c8ccdac9ad9..bbe08939b6d75 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -233,6 +233,23 @@ private[ml] trait HasFitIntercept extends Params { final def getFitIntercept: Boolean = $(fitIntercept) } +/** + * (private[ml]) Trait for shared param standardization (default: true). + */ +private[ml] trait HasStandardization extends Params { + + /** + * Param for whether to standardize the training features prior to fitting the model sequence. Note that the coefficients of models are always returned on the original scale.. + * @group param + */ + final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features prior to fitting the model sequence. Note that the coefficients of models are always returned on the original scale.") + + setDefault(standardization, true) + + /** @group getParam */ + final def getStandardization: Boolean = $(standardization) +} + /** * (private[ml]) Trait for shared param seed (default: this.getClass.getName.hashCode.toLong). */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index df009d855ecbb..2e44cd4cc6a22 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -216,6 +216,11 @@ class ALSModel private[ml] ( SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) } + + override def copy(extra: ParamMap): ALSModel = { + val copied = new ALSModel(uid, rank, userFactors, itemFactors) + copyValues(copied, extra) + } } @@ -330,6 +335,8 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): ALS = defaultCopy(extra) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 43b68e7bb20fa..be1f8063d41d8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -76,6 +76,8 @@ final class DecisionTreeRegressor(override val uid: String) super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity, subsamplingRate = 1.0) } + + override def copy(extra: ParamMap): DecisionTreeRegressor = defaultCopy(extra) } @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index b7e374bb6cb49..036e3acb07412 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -131,6 +131,8 @@ final class GBTRegressor(override val uid: String) val oldModel = oldGBT.run(oldDataset) GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures) } + + override def copy(extra: ParamMap): GBTRegressor = defaultCopy(extra) } @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 70cd8e9e87fae..1b1d7299fb496 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -26,7 +26,7 @@ import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol} +import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ @@ -41,7 +41,8 @@ import org.apache.spark.util.StatCounter * Params for linear regression. */ private[regression] trait LinearRegressionParams extends PredictorParams - with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol + with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol + with HasFitIntercept /** * :: Experimental :: @@ -72,6 +73,14 @@ class LinearRegression(override val uid: String) def setRegParam(value: Double): this.type = set(regParam, value) setDefault(regParam -> 0.0) + /** + * Set if we should fit the intercept + * Default is true. + * @group setParam + */ + def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) + setDefault(fitIntercept -> true) + /** * Set the ElasticNet mixing parameter. * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. @@ -123,6 +132,7 @@ class LinearRegression(override val uid: String) val numFeatures = summarizer.mean.size val yMean = statCounter.mean val yStd = math.sqrt(statCounter.variance) + // look at glmnet5.m L761 maaaybe that has info // If the yStd is zero, then the intercept is yMean with zero weights; // as a result, training is not needed. @@ -142,7 +152,7 @@ class LinearRegression(override val uid: String) val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam - val costFun = new LeastSquaresCostFun(instances, yStd, yMean, + val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept), featuresStd, featuresMean, effectiveL2RegParam) val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) { @@ -180,12 +190,14 @@ class LinearRegression(override val uid: String) // The intercept in R's GLMNET is computed using closed form after the coefficients are // converged. See the following discussion for detail. // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet - val intercept = yMean - dot(weights, Vectors.dense(featuresMean)) + val intercept = if ($(fitIntercept)) yMean - dot(weights, Vectors.dense(featuresMean)) else 0.0 if (handlePersistence) instances.unpersist() // TODO: Converts to sparse format based on the storage, but may base on the scoring speed. copyValues(new LinearRegressionModel(uid, weights.compressed, intercept)) } + + override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) } /** @@ -232,6 +244,7 @@ class LinearRegressionModel private[ml] ( * See this discussion for detail. * http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet * + * When training with intercept enabled, * The objective function in the scaled space is given by * {{{ * L = 1/2n ||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2, @@ -239,6 +252,10 @@ class LinearRegressionModel private[ml] ( * where \bar{x_i} is the mean of x_i, \hat{x_i} is the standard deviation of x_i, * \bar{y} is the mean of label, and \hat{y} is the standard deviation of label. * + * If we fitting the intercept disabled (that is forced through 0.0), + * we can use the same equation except we set \bar{y} and \bar{x_i} to 0 instead + * of the respective means. + * * This can be rewritten as * {{{ * L = 1/2n ||\sum_i (w_i/\hat{x_i})x_i - \sum_i (w_i/\hat{x_i})\bar{x_i} - y / \hat{y} @@ -253,6 +270,7 @@ class LinearRegressionModel private[ml] ( * \sum_i w_i^\prime x_i - y / \hat{y} + offset * }}} * + * * Note that the effective weights and offset don't depend on training dataset, * so they can be precomputed. * @@ -299,6 +317,7 @@ private class LeastSquaresAggregator( weights: Vector, labelStd: Double, labelMean: Double, + fitIntercept: Boolean, featuresStd: Array[Double], featuresMean: Array[Double]) extends Serializable { @@ -319,7 +338,7 @@ private class LeastSquaresAggregator( } i += 1 } - (weightsArray, -sum + labelMean / labelStd, weightsArray.length) + (weightsArray, if (fitIntercept) labelMean / labelStd - sum else 0.0, weightsArray.length) } private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray) @@ -402,6 +421,7 @@ private class LeastSquaresCostFun( data: RDD[(Double, Vector)], labelStd: Double, labelMean: Double, + fitIntercept: Boolean, featuresStd: Array[Double], featuresMean: Array[Double], effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] { @@ -410,7 +430,7 @@ private class LeastSquaresCostFun( val w = Vectors.fromBreeze(weights) val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd, - labelMean, featuresStd, featuresMean))( + labelMean, fitIntercept, featuresStd, featuresMean))( seqOp = (c, v) => (c, v) match { case (aggregator, (label, features)) => aggregator.add(label, features) }, diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 49a1f7ce8c995..21c59061a02fa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -86,6 +86,8 @@ final class RandomForestRegressor(override val uid: String) oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt) RandomForestRegressionModel.fromOld(oldModel, this, categoricalFeatures) } + + override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra) } @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index cb29392e8bc63..e2444ab65b43b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -149,6 +149,17 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM est.copy(paramMap).validateParams() } } + + override def copy(extra: ParamMap): CrossValidator = { + val copied = defaultCopy(extra).asInstanceOf[CrossValidator] + if (copied.isDefined(estimator)) { + copied.setEstimator(copied.getEstimator.copy(extra)) + } + if (copied.isDefined(evaluator)) { + copied.setEvaluator(copied.getEvaluator.copy(extra)) + } + copied + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala new file mode 100644 index 0000000000000..bc6041b221732 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python + +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.clustering.PowerIterationClusteringModel + +/** + * A Wrapper of PowerIterationClusteringModel to provide helper method for Python + */ +private[python] class PowerIterationClusteringModelWrapper(model: PowerIterationClusteringModel) + extends PowerIterationClusteringModel(model.k, model.assignments) { + + def getAssignments: RDD[Array[Any]] = { + model.assignments.map(x => Array(x.id, x.cluster)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 8f66bc808a007..a66a404d5c846 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -51,6 +51,7 @@ import org.apache.spark.mllib.tree.loss.Losses import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, RandomForestModel} import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest} import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.storage.StorageLevel @@ -277,7 +278,7 @@ private[python] class PythonMLLibAPI extends Serializable { /** * Java stub for NaiveBayes.train() */ - def trainNaiveBayes( + def trainNaiveBayesModel( data: JavaRDD[LabeledPoint], lambda: Double): JList[Object] = { val model = NaiveBayes.train(data.rdd, lambda) @@ -345,7 +346,7 @@ private[python] class PythonMLLibAPI extends Serializable { * Java stub for Python mllib GaussianMixture.run() * Returns a list containing weights, mean and covariance of each mixture component. */ - def trainGaussianMixture( + def trainGaussianMixtureModel( data: JavaRDD[Vector], k: Int, convergenceTol: Double, @@ -405,6 +406,33 @@ private[python] class PythonMLLibAPI extends Serializable { model.predictSoft(data).map(Vectors.dense) } + /** + * Java stub for Python mllib PowerIterationClustering.run(). This stub returns a + * handle to the Java object instead of the content of the Java object. Extra care + * needs to be taken in the Python code to ensure it gets freed on exit; see the + * Py4J documentation. + * @param data an RDD of (i, j, s,,ij,,) tuples representing the affinity matrix. + * @param k number of clusters. + * @param maxIterations maximum number of iterations of the power iteration loop. + * @param initMode the initialization mode. This can be either "random" to use + * a random vector as vertex properties, or "degree" to use + * normalized sum similarities. Default: random. + */ + def trainPowerIterationClusteringModel( + data: JavaRDD[Vector], + k: Int, + maxIterations: Int, + initMode: String): PowerIterationClusteringModel = { + + val pic = new PowerIterationClustering() + .setK(k) + .setMaxIterations(maxIterations) + .setInitializationMode(initMode) + + val model = pic.run(data.rdd.map(v => (v(0).toLong, v(1).toLong, v(2)))) + new PowerIterationClusteringModelWrapper(model) + } + /** * Java stub for Python mllib ALS.train(). This stub returns a handle * to the Java object instead of the content of the Java object. Extra care @@ -519,6 +547,16 @@ private[python] class PythonMLLibAPI extends Serializable { new ChiSqSelector(numTopFeatures).fit(data.rdd) } + /** + * Java stub for PCA.fit(). This stub returns a + * handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on + * exit; see the Py4J documentation. + */ + def fitPCA(k: Int, data: JavaRDD[Vector]): PCAModel = { + new PCA(k).fit(data.rdd) + } + /** * Java stub for IDF.fit(). This stub returns a * handle to the Java object instead of the content of the Java object. @@ -542,7 +580,7 @@ private[python] class PythonMLLibAPI extends Serializable { * @param seed initial seed for random generator * @return A handle to java Word2VecModelWrapper instance at python side */ - def trainWord2Vec( + def trainWord2VecModel( dataJRDD: JavaRDD[java.util.ArrayList[String]], vectorSize: Int, learningRate: Double, @@ -686,12 +724,14 @@ private[python] class PythonMLLibAPI extends Serializable { lossStr: String, numIterations: Int, learningRate: Double, - maxDepth: Int): GradientBoostedTreesModel = { + maxDepth: Int, + maxBins: Int): GradientBoostedTreesModel = { val boostingStrategy = BoostingStrategy.defaultParams(algoStr) boostingStrategy.setLoss(Losses.fromString(lossStr)) boostingStrategy.setNumIterations(numIterations) boostingStrategy.setLearningRate(learningRate) boostingStrategy.treeStrategy.setMaxDepth(maxDepth) + boostingStrategy.treeStrategy.setMaxBins(maxBins) boostingStrategy.treeStrategy.categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK) @@ -702,6 +742,14 @@ private[python] class PythonMLLibAPI extends Serializable { } } + def elementwiseProductVector(scalingVector: Vector, vector: Vector): Vector = { + new ElementwiseProduct(scalingVector).transform(vector) + } + + def elementwiseProductVector(scalingVector: Vector, vector: JavaRDD[Vector]): JavaRDD[Vector] = { + new ElementwiseProduct(scalingVector).transform(vector) + } + /** * Java stub for mllib Statistics.colStats(X: RDD[Vector]). * TODO figure out return type. @@ -952,10 +1000,54 @@ private[python] class PythonMLLibAPI extends Serializable { def estimateKernelDensity( sample: JavaRDD[Double], bandwidth: Double, points: java.util.ArrayList[Double]): Array[Double] = { - return new KernelDensity().setSample(sample).setBandwidth(bandwidth).estimate( + new KernelDensity().setSample(sample).setBandwidth(bandwidth).estimate( points.asScala.toArray) } + /** + * Java stub for the update method of StreamingKMeansModel. + */ + def updateStreamingKMeansModel( + clusterCenters: JList[Vector], + clusterWeights: JList[Double], + data: JavaRDD[Vector], + decayFactor: Double, + timeUnit: String): JList[Object] = { + val model = new StreamingKMeansModel( + clusterCenters.asScala.toArray, clusterWeights.asScala.toArray) + .update(data, decayFactor, timeUnit) + List[AnyRef](model.clusterCenters, Vectors.dense(model.clusterWeights)).asJava + } + + /** + * Wrapper around the generateLinearInput method of LinearDataGenerator. + */ + def generateLinearInputWrapper( + intercept: Double, + weights: JList[Double], + xMean: JList[Double], + xVariance: JList[Double], + nPoints: Int, + seed: Int, + eps: Double): Array[LabeledPoint] = { + LinearDataGenerator.generateLinearInput( + intercept, weights.asScala.toArray, xMean.asScala.toArray, + xVariance.asScala.toArray, nPoints, seed, eps).toArray + } + + /** + * Wrapper around the generateLinearRDD method of LinearDataGenerator. + */ + def generateLinearRDDWrapper( + sc: JavaSparkContext, + nexamples: Int, + nfeatures: Int, + eps: Double, + nparts: Int, + intercept: Double): JavaRDD[LabeledPoint] = { + LinearDataGenerator.generateLinearRDD( + sc, nexamples, nfeatures, eps, nparts, intercept) + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index efbfeb4059f5a..3fab7ea79befc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -159,7 +159,7 @@ private object IDF { * Represents an IDF model that can transform term frequency vectors. */ @Experimental -class IDFModel private[mllib] (val idf: Vector) extends Serializable { +class IDFModel private[spark] (val idf: Vector) extends Serializable { /** * Transforms term frequency (TF) vectors to TF-IDF vectors. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 51546d41c36a6..f087d06d2a46a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -431,7 +431,7 @@ class Word2Vec extends Serializable with Logging { * Word2Vec model */ @Experimental -class Word2VecModel private[mllib] ( +class Word2VecModel private[spark] ( model: Map[String, Array[Float]]) extends Serializable with Saveable { // wordList: Ordered list of words obtained from model. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index efa8459d3cdba..abac08022ea47 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -36,7 +36,7 @@ import org.apache.spark.storage.StorageLevel * :: Experimental :: * * Model trained by [[FPGrowth]], which holds frequent itemsets. - * @param freqItemsets frequent itemset, which is an RDD of [[FreqItemset]] + * @param freqItemsets frequent itemsets, which is an RDD of [[FreqItemset]] * @tparam Item item type */ @Experimental @@ -62,13 +62,14 @@ class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) ex @Experimental class FPGrowth private ( private var minSupport: Double, - private var numPartitions: Int) extends Logging with Serializable { + private var numPartitions: Int, + private var ordered: Boolean) extends Logging with Serializable { /** * Constructs a default instance with default parameters {minSupport: `0.3`, numPartitions: same - * as the input data}. + * as the input data, ordered: `false`}. */ - def this() = this(0.3, -1) + def this() = this(0.3, -1, false) /** * Sets the minimal support level (default: `0.3`). @@ -86,6 +87,15 @@ class FPGrowth private ( this } + /** + * Indicates whether to mine itemsets (unordered) or sequences (ordered) (default: false, mine + * itemsets). + */ + def setOrdered(ordered: Boolean): this.type = { + this.ordered = ordered + this + } + /** * Computes an FP-Growth model that contains frequent itemsets. * @param data input data set, each element contains a transaction @@ -155,7 +165,7 @@ class FPGrowth private ( .flatMap { case (part, tree) => tree.extract(minCount, x => partitioner.getPartition(x) == part) }.map { case (ranks, count) => - new FreqItemset(ranks.map(i => freqItems(i)).toArray, count) + new FreqItemset(ranks.map(i => freqItems(i)).reverse.toArray, count, ordered) } } @@ -171,9 +181,12 @@ class FPGrowth private ( itemToRank: Map[Item, Int], partitioner: Partitioner): mutable.Map[Int, Array[Int]] = { val output = mutable.Map.empty[Int, Array[Int]] - // Filter the basket by frequent items pattern and sort their ranks. + // Filter the basket by frequent items pattern val filtered = transaction.flatMap(itemToRank.get) - ju.Arrays.sort(filtered) + if (!this.ordered) { + ju.Arrays.sort(filtered) + } + // Generate conditional transactions val n = filtered.length var i = n - 1 while (i >= 0) { @@ -198,9 +211,18 @@ object FPGrowth { * Frequent itemset. * @param items items in this itemset. Java users should call [[FreqItemset#javaItems]] instead. * @param freq frequency + * @param ordered indicates if items represents an itemset (false) or sequence (true) * @tparam Item item type */ - class FreqItemset[Item](val items: Array[Item], val freq: Long) extends Serializable { + class FreqItemset[Item](val items: Array[Item], val freq: Long, val ordered: Boolean) + extends Serializable { + + /** + * Auxillary constructor, assumes unordered by default. + */ + def this(items: Array[Item], freq: Long) { + this(items, freq, false) + } /** * Returns items in a Java List. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 557119f7b1cd1..3523f1804325d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -213,9 +213,9 @@ private[spark] object BLAS extends Serializable with Logging { def scal(a: Double, x: Vector): Unit = { x match { case sx: SparseVector => - f2jBLAS.dscal(sx.values.size, a, sx.values, 1) + f2jBLAS.dscal(sx.values.length, a, sx.values, 1) case dx: DenseVector => - f2jBLAS.dscal(dx.values.size, a, dx.values, 1) + f2jBLAS.dscal(dx.values.length, a, dx.values, 1) case _ => throw new IllegalArgumentException(s"scal doesn't support vector type ${x.getClass}.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 9584da8e3a0f9..85e63b1382b5e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -197,6 +197,8 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { override def typeName: String = "matrix" + override def pyUDT: String = "pyspark.mllib.linalg.MatrixUDT" + private[spark] override def asNullable: MatrixUDT = this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 4b7d0589c973b..06e45e10c5bf4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -179,7 +179,7 @@ object GradientDescent extends Logging { * if it's L2 updater; for L1 updater, the same logic is followed. */ var regVal = updater.compute( - weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2 + weights, Vectors.zeros(weights.size), 0, 1, regParam)._2 for (i <- 1 to numIterations) { val bcWeights = data.context.broadcast(weights) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 26be30ff9d6fd..6709bd79bc820 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -195,11 +195,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ val initialWeights = { if (numOfLinearPredictor == 1) { - Vectors.dense(new Array[Double](numFeatures)) + Vectors.zeros(numFeatures) } else if (addIntercept) { - Vectors.dense(new Array[Double]((numFeatures + 1) * numOfLinearPredictor)) + Vectors.zeros((numFeatures + 1) * numOfLinearPredictor) } else { - Vectors.dense(new Array[Double](numFeatures * numOfLinearPredictor)) + Vectors.zeros(numFeatures * numOfLinearPredictor) } } run(input, initialWeights) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index e0c03d8180c7a..7d28ffad45c92 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -73,7 +73,7 @@ object RidgeRegressionModel extends Loader[RidgeRegressionModel] { /** * Train a regression model with L2-regularization using Stochastic Gradient Descent. - * This solves the l1-regularized least squares regression formulation + * This solves the l2-regularized least squares regression formulation * f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^ * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index cea8f3f47307b..141052ba813ee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -83,21 +83,15 @@ abstract class StreamingLinearAlgorithm[ throw new IllegalArgumentException("Model must be initialized before starting training.") } data.foreachRDD { (rdd, time) => - val initialWeights = - model match { - case Some(m) => - m.weights - case None => - val numFeatures = rdd.first().features.size - Vectors.dense(numFeatures) + if (!rdd.isEmpty) { + model = Some(algorithm.run(rdd, model.get.weights)) + logInfo(s"Model updated at time ${time.toString}") + val display = model.get.weights.size match { + case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...") + case _ => model.get.weights.toArray.mkString("[", ",", "]") } - model = Some(algorithm.run(rdd, initialWeights)) - logInfo("Model updated at time %s".format(time.toString)) - val display = model.get.weights.size match { - case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...") - case _ => model.get.weights.toArray.mkString("[", ",", "]") + logInfo(s"Current model: weights, ${display}") } - logInfo("Current model: weights, %s".format (display)) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index a49153bf73c0d..235e043c7754b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -79,7 +79,7 @@ class StreamingLinearRegressionWithSGD private[mllib] ( this } - /** Set the initial weights. Default: [0.0, 0.0]. */ + /** Set the initial weights. */ def setInitialWeights(initialWeights: Vector): this.type = { this.model = Some(algorithm.createModel(initialWeights, 0.0)) this diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 52d6468a72af7..7c5cfa7bd84ce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -270,12 +270,28 @@ object MLUtils { * Returns a new vector with `1.0` (bias) appended to the input vector. */ def appendBias(vector: Vector): Vector = { - val vector1 = vector.toBreeze match { - case dv: BDV[Double] => BDV.vertcat(dv, new BDV[Double](Array(1.0))) - case sv: BSV[Double] => BSV.vertcat(sv, new BSV[Double](Array(0), Array(1.0), 1)) - case v: Any => throw new IllegalArgumentException("Do not support vector type " + v.getClass) + vector match { + case dv: DenseVector => + val inputValues = dv.values + val inputLength = inputValues.length + val outputValues = Array.ofDim[Double](inputLength + 1) + System.arraycopy(inputValues, 0, outputValues, 0, inputLength) + outputValues(inputLength) = 1.0 + Vectors.dense(outputValues) + case sv: SparseVector => + val inputValues = sv.values + val inputIndices = sv.indices + val inputValuesLength = inputValues.length + val dim = sv.size + val outputValues = Array.ofDim[Double](inputValuesLength + 1) + val outputIndices = Array.ofDim[Int](inputValuesLength + 1) + System.arraycopy(inputValues, 0, outputValues, 0, inputValuesLength) + System.arraycopy(inputIndices, 0, outputIndices, 0, inputValuesLength) + outputValues(inputValuesLength) = 1.0 + outputIndices(inputValuesLength) = dim + Vectors.sparse(dim + 1, outputIndices, outputValues) + case _ => throw new IllegalArgumentException(s"Do not support vector type ${vector.getClass}") } - Vectors.fromBreeze(vector1) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala index 308f7f3578e21..a841c5caf0142 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala @@ -98,6 +98,8 @@ private[mllib] object NumericParser { } } else if (token == ")") { parsing = false + } else if (token.trim.isEmpty){ + // ignore whitespaces between delim chars, e.g. ", [" } else { // expecting a number items.append(parseDouble(token)) diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java index ff5929235ac2c..3ae09d39ef500 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -102,4 +102,9 @@ private void init() { setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0}); setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0})); } + + @Override + public JavaTestParams copy(ParamMap extra) { + return defaultCopy(extra); + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 05bf58e63abaf..63d2fa31c7499 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -17,11 +17,14 @@ package org.apache.spark.ml +import scala.collection.JavaConverters._ + import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.when import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.HashingTF import org.apache.spark.ml.param.ParamMap import org.apache.spark.sql.DataFrame @@ -81,4 +84,28 @@ class PipelineSuite extends SparkFunSuite { pipeline.fit(dataset) } } + + test("PipelineModel.copy") { + val hashingTF = new HashingTF() + .setNumFeatures(100) + val model = new PipelineModel("pipeline", Array[Transformer](hashingTF)) + val copied = model.copy(ParamMap(hashingTF.numFeatures -> 10)) + require(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10, + "copy should handle extra stage params") + } + + test("pipeline model constructors") { + val transform0 = mock[Transformer] + val model1 = mock[MyModel] + + val stages = Array(transform0, model1) + val pipelineModel0 = new PipelineModel("pipeline0", stages) + assert(pipelineModel0.uid === "pipeline0") + assert(pipelineModel0.stages === stages) + + val stagesAsList = stages.toList.asJava + val pipelineModel1 = new PipelineModel("pipeline1", stagesAsList) + assert(pipelineModel1.uid === "pipeline1") + assert(pipelineModel1.stages === stages) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala index 72b575d022547..c5fd2f9d5a22a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala @@ -215,5 +215,10 @@ class AttributeSuite extends SparkFunSuite { assert(Attribute.fromStructField(fldWithoutMeta) == UnresolvedAttribute) val fldWithMeta = new StructField("x", DoubleType, false, metadata) assert(Attribute.fromStructField(fldWithMeta).isNumeric) + // Attribute.fromStructField should accept any NumericType, not just DoubleType + val longFldWithMeta = new StructField("x", LongType, false, metadata) + assert(Attribute.fromStructField(longFldWithMeta).isNumeric) + val decimalFldWithMeta = new StructField("x", DecimalType(None), false, metadata) + assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index ae40b0b8ff854..73b4805c4c597 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -19,15 +19,15 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.tree.LeafNode import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, - DecisionTreeSuite => OldDecisionTreeSuite} +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame - class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { import DecisionTreeClassifierSuite.compareAPIs @@ -55,6 +55,12 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()) } + test("params") { + ParamsSuite.checkParams(new DecisionTreeClassifier) + val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0)) + ParamsSuite.checkParams(model) + } + ///////////////////////////////////////////////////////////////////////////// // Tests calling train() ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 1302da3c373ff..82c345491bb3c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -19,6 +19,9 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.tree.LeafNode import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -51,6 +54,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2) } + test("params") { + ParamsSuite.checkParams(new GBTClassifier) + val model = new GBTClassificationModel("gbtc", + Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0))), + Array(1.0)) + ParamsSuite.checkParams(model) + } + test("Binary classification with continuous features: Log Loss") { val categoricalFeatures = Map.empty[Int, Int] testCombinations.foreach { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index a755cac3ea76e..bc6eeac1db5da 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -18,8 +18,9 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.classification.LogisticRegressionSuite._ -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} @@ -35,19 +36,19 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { dataset = sqlContext.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42)) - /** - * Here is the instruction describing how to export the test data into CSV format - * so we can validate the training accuracy compared with R's glmnet package. - * - * import org.apache.spark.mllib.classification.LogisticRegressionSuite - * val nPoints = 10000 - * val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) - * val xMean = Array(5.843, 3.057, 3.758, 1.199) - * val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) - * val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput( - * weights, xMean, xVariance, true, nPoints, 42), 1) - * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", " - * + x.features(2) + ", " + x.features(3)).saveAsTextFile("path") + /* + Here is the instruction describing how to export the test data into CSV format + so we can validate the training accuracy compared with R's glmnet package. + + import org.apache.spark.mllib.classification.LogisticRegressionSuite + val nPoints = 10000 + val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput( + weights, xMean, xVariance, true, nPoints, 42), 1) + data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", " + + x.features(2) + ", " + x.features(3)).saveAsTextFile("path") */ binaryDataset = { val nPoints = 10000 @@ -62,6 +63,12 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("params") { + ParamsSuite.checkParams(new LogisticRegression) + val model = new LogisticRegressionModel("logReg", Vectors.dense(0.0), 0.0) + ParamsSuite.checkParams(model) + } + test("logistic regression: default params") { val lr = new LogisticRegression assert(lr.getLabelCol === "label") @@ -204,22 +211,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val trainer = (new LogisticRegression).setFitIntercept(true) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 2.8366423 - * data.V2 -0.5895848 - * data.V3 0.8931147 - * data.V4 -0.3925051 - * data.V5 -0.7996864 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 2.8366423 + data.V2 -0.5895848 + data.V3 0.8931147 + data.V4 -0.3925051 + data.V5 -0.7996864 */ val interceptR = 2.8366423 val weightsR = Array(-0.5895848, 0.8931147, -0.3925051, -0.7996864) @@ -235,23 +242,23 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val trainer = (new LogisticRegression).setFitIntercept(false) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = - * coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * data.V2 -0.3534996 - * data.V3 1.2964482 - * data.V4 -0.3571741 - * data.V5 -0.7407946 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = + coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.3534996 + data.V3 1.2964482 + data.V4 -0.3571741 + data.V5 -0.7407946 */ val interceptR = 0.0 val weightsR = Array(-0.3534996, 1.2964482, -0.3571741, -0.7407946) @@ -268,22 +275,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(1.0).setRegParam(0.12) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) -0.05627428 - * data.V2 . - * data.V3 . - * data.V4 -0.04325749 - * data.V5 -0.02481551 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) -0.05627428 + data.V2 . + data.V3 . + data.V4 -0.04325749 + data.V5 -0.02481551 */ val interceptR = -0.05627428 val weightsR = Array(0.0, 0.0, -0.04325749, -0.02481551) @@ -300,23 +307,23 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(1.0).setRegParam(0.12) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, - * intercept=FALSE)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * data.V2 . - * data.V3 . - * data.V4 -0.05189203 - * data.V5 -0.03891782 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, + intercept=FALSE)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 . + data.V3 . + data.V4 -0.05189203 + data.V5 -0.03891782 */ val interceptR = 0.0 val weightsR = Array(0.0, 0.0, -0.05189203, -0.03891782) @@ -333,22 +340,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(0.0).setRegParam(1.37) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 0.15021751 - * data.V2 -0.07251837 - * data.V3 0.10724191 - * data.V4 -0.04865309 - * data.V5 -0.10062872 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.15021751 + data.V2 -0.07251837 + data.V3 0.10724191 + data.V4 -0.04865309 + data.V5 -0.10062872 */ val interceptR = 0.15021751 val weightsR = Array(-0.07251837, 0.10724191, -0.04865309, -0.10062872) @@ -365,23 +372,23 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(0.0).setRegParam(1.37) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, - * intercept=FALSE)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * data.V2 -0.06099165 - * data.V3 0.12857058 - * data.V4 -0.04708770 - * data.V5 -0.09799775 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, + intercept=FALSE)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.06099165 + data.V3 0.12857058 + data.V4 -0.04708770 + data.V5 -0.09799775 */ val interceptR = 0.0 val weightsR = Array(-0.06099165, 0.12857058, -0.04708770, -0.09799775) @@ -398,22 +405,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(0.38).setRegParam(0.21) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 0.57734851 - * data.V2 -0.05310287 - * data.V3 . - * data.V4 -0.08849250 - * data.V5 -0.15458796 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.57734851 + data.V2 -0.05310287 + data.V3 . + data.V4 -0.08849250 + data.V5 -0.15458796 */ val interceptR = 0.57734851 val weightsR = Array(-0.05310287, 0.0, -0.08849250, -0.15458796) @@ -430,23 +437,23 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(0.38).setRegParam(0.21) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, - * intercept=FALSE)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * data.V2 -0.001005743 - * data.V3 0.072577857 - * data.V4 -0.081203769 - * data.V5 -0.142534158 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, + intercept=FALSE)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.001005743 + data.V3 0.072577857 + data.V4 -0.081203769 + data.V5 -0.142534158 */ val interceptR = 0.0 val weightsR = Array(-0.001005743, 0.072577857, -0.081203769, -0.142534158) @@ -473,16 +480,16 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { classSummarizer1.merge(classSummarizer2) }).histogram - /** - * For binary logistic regression with strong L1 regularization, all the weights will be zeros. - * As a result, - * {{{ - * P(0) = 1 / (1 + \exp(b)), and - * P(1) = \exp(b) / (1 + \exp(b)) - * }}}, hence - * {{{ - * b = \log{P(1) / P(0)} = \log{count_1 / count_0} - * }}} + /* + For binary logistic regression with strong L1 regularization, all the weights will be zeros. + As a result, + {{{ + P(0) = 1 / (1 + \exp(b)), and + P(1) = \exp(b) / (1 + \exp(b)) + }}}, hence + {{{ + b = \log{P(1) / P(0)} = \log{count_1 / count_0} + }}} */ val interceptTheory = math.log(histogram(1).toDouble / histogram(0).toDouble) val weightsTheory = Array(0.0, 0.0, 0.0, 0.0) @@ -493,22 +500,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.weights(2) ~== weightsTheory(2) absTol 1E-6) assert(model.weights(3) ~== weightsTheory(3) absTol 1E-6) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) -0.2480643 - * data.V2 0.0000000 - * data.V3 . - * data.V4 . - * data.V5 . + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) -0.2480643 + data.V2 0.0000000 + data.V3 . + data.V4 . + data.V5 . */ val interceptR = -0.248065 val weightsR = Array(0.0, 0.0, 0.0, 0.0) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 1d04ccb509057..75cf5bd4ead4f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -19,15 +19,18 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute +import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.Metadata class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -52,6 +55,13 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { dataset = sqlContext.createDataFrame(rdd) } + test("params") { + ParamsSuite.checkParams(new OneVsRest) + val lrModel = new LogisticRegressionModel("lr", Vectors.dense(0.0), 0.0) + val model = new OneVsRestModel("ovr", Metadata.empty, Array(lrModel)) + ParamsSuite.checkParams(model) + } + test("one-vs-rest: default params") { val numClasses = 3 val ova = new OneVsRest() @@ -102,6 +112,26 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { val output = ovr.fit(dataset).transform(dataset) assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction")) } + + test("OneVsRest.copy and OneVsRestModel.copy") { + val lr = new LogisticRegression() + .setMaxIter(1) + + val ovr = new OneVsRest() + withClue("copy with classifier unset should work") { + ovr.copy(ParamMap(lr.maxIter -> 10)) + } + ovr.setClassifier(lr) + val ovr1 = ovr.copy(ParamMap(lr.maxIter -> 10)) + require(ovr.getClassifier.getOrDefault(lr.maxIter) === 1, "copy should have no side-effects") + require(ovr1.getClassifier.getOrDefault(lr.maxIter) === 10, + "copy should handle extra classifier params") + + val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.threshold -> 0.1)) + ovrModel.models.foreach { case m: LogisticRegressionModel => + require(m.getThreshold === 0.1, "copy should handle extra model params") + } + } } private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index eee9355a67be3..1b6b69c7dc71e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.tree.LeafNode import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} @@ -27,7 +29,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame - /** * Test suite for [[RandomForestClassifier]]. */ @@ -62,6 +63,13 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeatures, numClasses) } + test("params") { + ParamsSuite.checkParams(new RandomForestClassifier) + val model = new RandomForestClassificationModel("rfc", + Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0)))) + ParamsSuite.checkParams(model) + } + test("Binary classification with continuous features:" + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { val rf = new RandomForestClassifier() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala similarity index 52% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala rename to mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index f33a18d53b1a9..def869fe66777 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -15,23 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.optimizer +package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.ml.param.ParamsSuite -/** - * Overrides our expression evaluation tests and reruns them after optimization has occured. This - * is to ensure that constant folding and other optimizations do not break anything. - */ -class ExpressionOptimizationSuite extends SparkFunSuite with ExpressionEvalHelper { - override def checkEvaluation( - expression: Expression, - expected: Any, - inputRow: Row = EmptyRow): Unit = { - val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = DefaultOptimizer.execute(plan) - super.checkEvaluation(optimizedPlan.expressions.head, expected, inputRow) +class BinaryClassificationEvaluatorSuite extends SparkFunSuite { + + test("params") { + ParamsSuite.checkParams(new BinaryClassificationEvaluator) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index 36a1ac6b7996d..5b203784559e2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -18,12 +18,17 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new RegressionEvaluator) + } + test("Regression Evaluator: default params") { /** * Here is the instruction describing how to export the test data into CSV format @@ -58,7 +63,7 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext // default = rmse val evaluator = new RegressionEvaluator() - assert(evaluator.evaluate(predictions) ~== 0.1019382 absTol 0.001) + assert(evaluator.evaluate(predictions) ~== -0.1019382 absTol 0.001) // r2 score evaluator.setMetricName("r2") @@ -66,6 +71,6 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext // mae evaluator.setMetricName("mae") - assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001) + assert(evaluator.evaluate(predictions) ~== -0.08036075 absTol 0.001) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 7953bd0417191..2086043983661 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @@ -30,6 +31,10 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext { data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4) } + test("params") { + ParamsSuite.checkParams(new Binarizer) + } + test("Binarize continuous features with default parameter") { val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0) val dataFrame: DataFrame = sqlContext.createDataFrame( diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 507a8a7db24c7..ec85e0d151e07 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.feature import scala.util.Random import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -27,6 +28,10 @@ import org.apache.spark.sql.{DataFrame, Row} class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new Bucketizer) + } + test("Bucket continuous features, without -inf,inf") { // Check a set of valid feature values. val splits = Array(-0.5, 0.0, 0.5) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala index 7b2d70e644005..4157b84b29d01 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -28,8 +28,7 @@ import org.apache.spark.util.Utils class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { test("params") { - val hashingTF = new HashingTF - ParamsSuite.checkParams(hashingTF, 3) + ParamsSuite.checkParams(new HashingTF) } test("hashingTF") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index d83772e8be755..08f80af03429b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -38,6 +40,12 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("params") { + ParamsSuite.checkParams(new IDF) + val model = new IDFModel("idf", new OldIDFModel(Vectors.dense(1.0))) + ParamsSuite.checkParams(model) + } + test("compute IDF with default parameter") { val numOfFeatures = 4 val data = Array( diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala new file mode 100644 index 0000000000000..ab97e3dbc6ee0 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.beans.BeanInfo + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row} + +@BeanInfo +case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) + +class NGramSuite extends SparkFunSuite with MLlibTestSparkContext { + import org.apache.spark.ml.feature.NGramSuite._ + + test("default behavior yields bigram features") { + val nGram = new NGram() + .setInputCol("inputTokens") + .setOutputCol("nGrams") + val dataset = sqlContext.createDataFrame(Seq( + NGramTestData( + Array("Test", "for", "ngram", "."), + Array("Test for", "for ngram", "ngram .") + ))) + testNGram(nGram, dataset) + } + + test("NGramLength=4 yields length 4 n-grams") { + val nGram = new NGram() + .setInputCol("inputTokens") + .setOutputCol("nGrams") + .setN(4) + val dataset = sqlContext.createDataFrame(Seq( + NGramTestData( + Array("a", "b", "c", "d", "e"), + Array("a b c d", "b c d e") + ))) + testNGram(nGram, dataset) + } + + test("empty input yields empty output") { + val nGram = new NGram() + .setInputCol("inputTokens") + .setOutputCol("nGrams") + .setN(4) + val dataset = sqlContext.createDataFrame(Seq( + NGramTestData( + Array(), + Array() + ))) + testNGram(nGram, dataset) + } + + test("input array < n yields empty output") { + val nGram = new NGram() + .setInputCol("inputTokens") + .setOutputCol("nGrams") + .setN(6) + val dataset = sqlContext.createDataFrame(Seq( + NGramTestData( + Array("a", "b", "c", "d", "e"), + Array() + ))) + testNGram(nGram, dataset) + } +} + +object NGramSuite extends SparkFunSuite { + + def testNGram(t: NGram, dataset: DataFrame): Unit = { + t.transform(dataset) + .select("nGrams", "wantedNGrams") + .collect() + .foreach { case Row(actualNGrams, wantedNGrams) => + assert(actualNGrams === wantedNGrams) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 2e5036a844562..65846a846b7b4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame @@ -36,6 +37,10 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { indexer.transform(df) } + test("params") { + ParamsSuite.checkParams(new OneHotEncoder) + } + test("OneHotEncoder dropLast = false") { val transformed = stringIndexed() val encoder = new OneHotEncoder() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index feca866cd711d..29eebd8960ebc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.ml.param.ParamsSuite import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkFunSuite @@ -27,6 +28,10 @@ import org.apache.spark.sql.Row class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new PolynomialExpansion) + } + test("Polynomial expansion with default parameter") { val data = Array( Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 5f557e16e5150..99f82bea42688 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -19,10 +19,17 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.util.MLlibTestSparkContext class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new StringIndexer) + val model = new StringIndexerModel("indexer", Array("a", "b")) + ParamsSuite.checkParams(model) + } + test("StringIndexer") { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) val df = sqlContext.createDataFrame(data).toDF("id", "label") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index ac279cb3215c2..e5fd21c3f6fca 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -20,15 +20,27 @@ package org.apache.spark.ml.feature import scala.beans.BeanInfo import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @BeanInfo case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) +class TokenizerSuite extends SparkFunSuite { + + test("params") { + ParamsSuite.checkParams(new Tokenizer) + } +} + class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { import org.apache.spark.ml.feature.RegexTokenizerSuite._ + test("params") { + ParamsSuite.checkParams(new RegexTokenizer) + } + test("RegexTokenizer") { val tokenizer0 = new RegexTokenizer() .setGaps(false) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 489abb5af7130..bb4d5b983e0d4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row @@ -26,6 +27,10 @@ import org.apache.spark.sql.functions.col class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new VectorAssembler) + } + test("assemble") { import org.apache.spark.ml.feature.VectorAssembler.assemble assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 06affc7305cf5..8c85c96d5c6d8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -21,6 +21,7 @@ import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD @@ -91,6 +92,12 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { private def getIndexer: VectorIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexed") + test("params") { + ParamsSuite.checkParams(new VectorIndexer) + val model = new VectorIndexerModel("indexer", 1, Map.empty) + ParamsSuite.checkParams(model) + } + test("Cannot fit an empty DataFrame") { val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData)) val vectorIndexer = getIndexer diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 94ebc3aebfa37..aa6ce533fd885 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -18,13 +18,21 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new Word2Vec) + val model = new Word2VecModel("w2v", new OldWord2VecModel(Map("a" -> Array(0.0f)))) + ParamsSuite.checkParams(model) + } + test("Word2Vec") { val sqlContext = new SQLContext(sc) import sqlContext.implicits._ diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 96094d7a099aa..050d4170ea017 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -205,19 +205,27 @@ class ParamsSuite extends SparkFunSuite { object ParamsSuite extends SparkFunSuite { /** - * Checks common requirements for [[Params.params]]: 1) number of params; 2) params are ordered - * by names; 3) param parent has the same UID as the object's UID; 4) param name is the same as - * the param method name. + * Checks common requirements for [[Params.params]]: + * - params are ordered by names + * - param parent has the same UID as the object's UID + * - param name is the same as the param method name + * - obj.copy should return the same type as the obj */ - def checkParams(obj: Params, expectedNumParams: Int): Unit = { + def checkParams(obj: Params): Unit = { + val clazz = obj.getClass + val params = obj.params - require(params.length === expectedNumParams, - s"Expect $expectedNumParams params but got ${params.length}: ${params.map(_.name).toSeq}.") val paramNames = params.map(_.name) - require(paramNames === paramNames.sorted) + require(paramNames === paramNames.sorted, "params must be ordered by names") params.foreach { p => assert(p.parent === obj.uid) assert(obj.getParam(p.name) === p) + // TODO: Check that setters return self, which needs special handling for generic types. } + + val copyMethod = clazz.getMethod("copy", classOf[ParamMap]) + val copyReturnType = copyMethod.getReturnType + require(copyReturnType === obj.getClass, + s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index a9e78366ad98f..2759248344531 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -38,7 +38,5 @@ class TestParams(override val uid: String) extends Params with HasMaxIter with H require(isDefined(inputCol)) } - override def copy(extra: ParamMap): TestParams = { - super.copy(extra).asInstanceOf[TestParams] - } + override def copy(extra: ParamMap): TestParams = defaultCopy(extra) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala index eb5408d3fee7c..b3af81a3c60b6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala @@ -18,13 +18,15 @@ package org.apache.spark.ml.param.shared import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.param.Params +import org.apache.spark.ml.param.{ParamMap, Params} class SharedParamsSuite extends SparkFunSuite { test("outputCol") { - class Obj(override val uid: String) extends Params with HasOutputCol + class Obj(override val uid: String) extends Params with HasOutputCol { + override def copy(extra: ParamMap): Obj = defaultCopy(extra) + } val obj = new Obj("obj") diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 732e2c42be144..5f39d44f37352 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -26,42 +26,53 @@ import org.apache.spark.sql.{DataFrame, Row} class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var dataset: DataFrame = _ + @transient var datasetWithoutIntercept: DataFrame = _ - /** - * In `LinearRegressionSuite`, we will make sure that the model trained by SparkML - * is the same as the one trained by R's glmnet package. The following instruction - * describes how to reproduce the data in R. - * - * import org.apache.spark.mllib.util.LinearDataGenerator - * val data = - * sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), 10000, 42), 2) - * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).saveAsTextFile("path") + /* + In `LinearRegressionSuite`, we will make sure that the model trained by SparkML + is the same as the one trained by R's glmnet package. The following instruction + describes how to reproduce the data in R. + + import org.apache.spark.mllib.util.LinearDataGenerator + val data = + sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), + Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2) + data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1) + .saveAsTextFile("path") */ override def beforeAll(): Unit = { super.beforeAll() dataset = sqlContext.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) + /* + datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating + training model without intercept + */ + datasetWithoutIntercept = sqlContext.createDataFrame( + sc.parallelize(LinearDataGenerator.generateLinearInput( + 0.0, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) + } test("linear regression with intercept without regularization") { val trainer = new LinearRegression val model = trainer.fit(dataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * library("glmnet") - * data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) - * features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) - * label <- as.numeric(data$V1) - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 6.300528 - * as.numeric.data.V2. 4.701024 - * as.numeric.data.V3. 7.198257 + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) + label <- as.numeric(data$V1) + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.300528 + as.numeric.data.V2. 4.701024 + as.numeric.data.V3. 7.198257 */ val interceptR = 6.298698 val weightsR = Array(4.700706, 7.199082) @@ -78,20 +89,56 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("linear regression without intercept without regularization") { + val trainer = (new LinearRegression).setFitIntercept(false) + val model = trainer.fit(dataset) + val modelWithoutIntercept = trainer.fit(datasetWithoutIntercept) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0, + intercept = FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 6.995908 + as.numeric.data.V3. 5.275131 + */ + val weightsR = Array(6.995908, 5.275131) + + assert(model.intercept ~== 0 relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + /* + Then again with the data with no intercept: + > weightsWithoutIntercept + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data3.V2. 4.70011 + as.numeric.data3.V3. 7.19943 + */ + val weightsWithoutInterceptR = Array(4.70011, 7.19943) + + assert(modelWithoutIntercept.intercept ~== 0 relTol 1E-3) + assert(modelWithoutIntercept.weights(0) ~== weightsWithoutInterceptR(0) relTol 1E-3) + assert(modelWithoutIntercept.weights(1) ~== weightsWithoutInterceptR(1) relTol 1E-3) + } + test("linear regression with intercept with L1 regularization") { val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 6.311546 - * as.numeric.data.V2. 2.123522 - * as.numeric.data.V3. 4.605651 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.24300 + as.numeric.data.V2. 4.024821 + as.numeric.data.V3. 6.679841 */ - val interceptR = 6.243000 + val interceptR = 6.24300 val weightsR = Array(4.024821, 6.679841) assert(model.intercept ~== interceptR relTol 1E-3) @@ -106,18 +153,48 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("linear regression without intercept with L1 regularization") { + val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) + .setFitIntercept(false) + val model = trainer.fit(dataset) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57, + intercept=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 6.299752 + as.numeric.data.V3. 4.772913 + */ + val interceptR = 0.0 + val weightsR = Array(6.299752, 4.772913) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + test("linear regression with intercept with L2 regularization") { val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 6.328062 - * as.numeric.data.V2. 3.222034 - * as.numeric.data.V3. 4.926260 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.328062 + as.numeric.data.V2. 3.222034 + as.numeric.data.V3. 4.926260 */ val interceptR = 5.269376 val weightsR = Array(3.736216, 5.712356) @@ -134,18 +211,48 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("linear regression without intercept with L2 regularization") { + val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) + .setFitIntercept(false) + val model = trainer.fit(dataset) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, + intercept = FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 5.522875 + as.numeric.data.V3. 4.214502 + */ + val interceptR = 0.0 + val weightsR = Array(5.522875, 4.214502) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + test("linear regression with intercept with ElasticNet regularization") { val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 6.324108 - * as.numeric.data.V2. 3.168435 - * as.numeric.data.V3. 5.200403 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.324108 + as.numeric.data.V2. 3.168435 + as.numeric.data.V3. 5.200403 */ val interceptR = 5.696056 val weightsR = Array(3.670489, 6.001122) @@ -161,4 +268,34 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(prediction1 ~== prediction2 relTol 1E-5) } } + + test("linear regression without intercept with ElasticNet regularization") { + val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) + .setFitIntercept(false) + val model = trainer.fit(dataset) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6, + intercept=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.dataM.V2. 5.673348 + as.numeric.dataM.V3. 4.322251 + */ + val interceptR = 0.0 + val weightsR = Array(5.673348, 4.322251) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 9b3619f0046ea..db64511a76055 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -18,14 +18,14 @@ package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite - import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator} +import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol +import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.types.StructType @@ -59,6 +59,36 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { assert(cvModel.avgMetrics.length === lrParamMaps.length) } + test("cross validation with linear regression") { + val dataset = sqlContext.createDataFrame( + sc.parallelize(LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) + + val trainer = new LinearRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(trainer.regParam, Array(1000.0, 0.001)) + .addGrid(trainer.maxIter, Array(0, 10)) + .build() + val eval = new RegressionEvaluator() + val cv = new CrossValidator() + .setEstimator(trainer) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setNumFolds(3) + val cvModel = cv.fit(dataset) + val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression] + assert(parent.getRegParam === 0.001) + assert(parent.getMaxIter === 10) + assert(cvModel.avgMetrics.length === lrParamMaps.length) + + eval.setMetricName("r2") + val cvModel2 = cv.fit(dataset) + val parent2 = cvModel2.bestModel.parent.asInstanceOf[LinearRegression] + assert(parent2.getRegParam === 0.001) + assert(parent2.getMaxIter === 10) + assert(cvModel2.avgMetrics.length === lrParamMaps.length) + } + test("validateParams should check estimatorParamMaps") { import CrossValidatorSuite._ @@ -98,6 +128,8 @@ object CrossValidatorSuite { override def transformSchema(schema: StructType): StructType = { throw new UnsupportedOperationException } + + override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra) } class MyEvaluator extends Evaluator { @@ -107,5 +139,7 @@ object CrossValidatorSuite { } override val uid: String = "eval" + + override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index e98b61e13e21f..fd653296c9d97 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -158,4 +158,21 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList assert(error.head > 0.8 & error.last < 0.2) } + + // Test empty RDDs in a stream + test("handling empty RDDs in a stream") { + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.dense(-0.1)) + .setStepSize(0.01) + .setNumIterations(10) + val numBatches = 10 + val emptyInput = Seq.empty[Seq[LabeledPoint]] + val ssc = setupStreams(emptyInput, + (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + } + ) + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index 66ae3543ecc4e..1a8a1e79f2810 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { - test("FP-Growth using String type") { + test("FP-Growth frequent itemsets using String type") { val transactions = Seq( "r z h k p", "z y x w v u t s", @@ -38,12 +38,14 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model6 = fpg .setMinSupport(0.9) .setNumPartitions(1) + .setOrdered(false) .run(rdd) assert(model6.freqItemsets.count() === 0) val model3 = fpg .setMinSupport(0.5) .setNumPartitions(2) + .setOrdered(false) .run(rdd) val freqItemsets3 = model3.freqItemsets.collect().map { itemset => (itemset.items.toSet, itemset.freq) @@ -61,17 +63,59 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model2 = fpg .setMinSupport(0.3) .setNumPartitions(4) + .setOrdered(false) .run(rdd) assert(model2.freqItemsets.count() === 54) val model1 = fpg .setMinSupport(0.1) .setNumPartitions(8) + .setOrdered(false) .run(rdd) assert(model1.freqItemsets.count() === 625) } - test("FP-Growth using Int type") { + test("FP-Growth frequent sequences using String type"){ + val transactions = Seq( + "r z h k p", + "z y x w v u t s", + "s x o n r", + "x z y m t s q e", + "z", + "x z y r q t p") + .map(_.split(" ")) + val rdd = sc.parallelize(transactions, 2).cache() + + val fpg = new FPGrowth() + + val model1 = fpg + .setMinSupport(0.5) + .setNumPartitions(2) + .setOrdered(true) + .run(rdd) + + /* + Use the following R code to verify association rules using arulesSequences package. + + data = read_baskets("path", info = c("sequenceID","eventID","SIZE")) + freqItemSeq = cspade(data, parameter = list(support = 0.5)) + resSeq = as(freqItemSeq, "data.frame") + resSeq$support = resSeq$support * length(transactions) + names(resSeq)[names(resSeq) == "support"] = "freq" + resSeq + */ + val expected = Set( + (Seq("r"), 3L), (Seq("s"), 3L), (Seq("t"), 3L), (Seq("x"), 4L), (Seq("y"), 3L), + (Seq("z"), 5L), (Seq("z", "y"), 3L), (Seq("x", "t"), 3L), (Seq("y", "t"), 3L), + (Seq("z", "t"), 3L), (Seq("z", "y", "t"), 3L) + ) + val freqItemseqs1 = model1.freqItemsets.collect().map { itemset => + (itemset.items.toSeq, itemset.freq) + }.toSet + assert(freqItemseqs1 == expected) + } + + test("FP-Growth frequent itemsets using Int type") { val transactions = Seq( "1 2 3", "1 2 3 4", @@ -88,12 +132,14 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model6 = fpg .setMinSupport(0.9) .setNumPartitions(1) + .setOrdered(false) .run(rdd) assert(model6.freqItemsets.count() === 0) val model3 = fpg .setMinSupport(0.5) .setNumPartitions(2) + .setOrdered(false) .run(rdd) assert(model3.freqItemsets.first().items.getClass === Array(1).getClass, "frequent itemsets should use primitive arrays") @@ -109,12 +155,14 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model2 = fpg .setMinSupport(0.3) .setNumPartitions(4) + .setOrdered(false) .run(rdd) assert(model2.freqItemsets.count() === 15) val model1 = fpg .setMinSupport(0.1) .setNumPartitions(8) + .setOrdered(false) .run(rdd) assert(model1.freqItemsets.count() === 65) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala index d8364a06de4da..f8d0af8820e64 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala @@ -31,6 +31,11 @@ class LabeledPointSuite extends SparkFunSuite { } } + test("parse labeled points with whitespaces") { + val point = LabeledPoint.parse("(0.0, [1.0, 2.0])") + assert(point === LabeledPoint(0.0, Vectors.dense(1.0, 2.0))) + } + test("parse labeled points with v0.9 format") { val point = LabeledPoint.parse("1.0,1.0 0.0 -2.0") assert(point === LabeledPoint(1.0, Vectors.dense(1.0, 0.0, -2.0))) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index 9a379406d5061..f5e2d31056cbd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -166,4 +166,22 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList assert((error.head - error.last) > 2) } + + // Test empty RDDs in a stream + test("handling empty RDDs in a stream") { + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.dense(0.0, 0.0)) + .setStepSize(0.2) + .setNumIterations(25) + val numBatches = 10 + val nPoints = 100 + val emptyInput = Seq.empty[Seq[LabeledPoint]] + val ssc = setupStreams(emptyInput, + (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + } + ) + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala index 8dcb9ba9be108..fa4f74d71b7e7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala @@ -37,4 +37,11 @@ class NumericParserSuite extends SparkFunSuite { } } } + + test("parser with whitespaces") { + val s = "(0.0, [1.0, 2.0])" + val parsed = NumericParser.parse(s).asInstanceOf[Seq[_]] + assert(parsed(0).asInstanceOf[Double] === 0.0) + assert(parsed(1).asInstanceOf[Array[Double]] === Array(1.0, 2.0)) + } } diff --git a/network/common/pom.xml b/network/common/pom.xml index a85e0a66f4a30..7dc3068ab8cb7 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -77,7 +77,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 4b5bfcb6f04bc..532463e96fbb7 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -79,7 +79,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index dd08e24cade23..022ed88a16480 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -108,7 +108,8 @@ public ManagedBuffer getBlockData(String appId, String execId, String blockId) { if ("org.apache.spark.shuffle.hash.HashShuffleManager".equals(executor.shuffleManager)) { return getHashBasedShuffleBlockData(executor, blockId); - } else if ("org.apache.spark.shuffle.sort.SortShuffleManager".equals(executor.shuffleManager)) { + } else if ("org.apache.spark.shuffle.sort.SortShuffleManager".equals(executor.shuffleManager) + || "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager".equals(executor.shuffleManager)) { return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); } else { throw new UnsupportedOperationException( diff --git a/pom.xml b/pom.xml index 67b6375f576d3..94dd512cfb618 100644 --- a/pom.xml +++ b/pom.xml @@ -156,7 +156,6 @@ 2.10 ${scala.version} org.scala-lang - 3.6.3 1.9.13 2.4.4 1.1.1.7 @@ -179,7 +178,7 @@ compile ${session.executionRootDirectory} @@ -249,7 +248,7 @@ mapr-repo MapR Repository - http://repository.mapr.com/maven + http://repository.mapr.com/maven/ true @@ -587,7 +586,7 @@ io.netty netty-all - 4.0.23.Final + 4.0.28.Final org.apache.derby @@ -682,7 +681,7 @@ org.mockito - mockito-all + mockito-core 1.9.5 test @@ -748,6 +747,10 @@ asm asm + + org.codehaus.jackson + jackson-mapper-asl + org.ow2.asm asm @@ -760,6 +763,10 @@ commons-logging commons-logging + + org.mockito + mockito-all + org.mortbay.jetty servlet-api-2.5 @@ -1244,7 +1251,7 @@ **/*Suite.java ${project.build.directory}/surefire-reports - -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + -Xmx3g -Xss4096k -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m org.apache.maven.plugins @@ -1665,7 +1694,7 @@ hadoop-1 - 1.0.4 + 1.2.1 2.4.1 0.98.7-hadoop1 hadoop1 diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 5812b72f0aa78..f16bf989f200b 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -91,8 +91,7 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - // TODO: Change this once Spark 1.4.0 is released - val previousSparkVersion = "1.4.0-rc4" + val previousSparkVersion = "1.4.0" val fullId = "spark-" + projectRef.project + "_2.10" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 8a93ca2999510..6f86a505b3ae4 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -44,14 +44,37 @@ object MimaExcludes { // JavaRDDLike is not meant to be extended by user programs ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.api.java.JavaRDDLike.partitioner"), + // Modification of private static method + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.streaming.kafka.KafkaUtils.org$apache$spark$streaming$kafka$KafkaUtils$$leadersForRanges"), // Mima false positive (was a private[spark] class) ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.util.collection.PairIterator"), // Removing a testing method from a private class ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"), + // While private MiMa is still not happy about the changes, + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresAggregator.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresCostFun.this"), // SQL execution is considered private. - excludePackage("org.apache.spark.sql.execution") + excludePackage("org.apache.spark.sql.execution"), + // NanoTime and CatalystTimestampConverter is only used inside catalyst, + // not needed anymore + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.timestamp.NanoTime"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.timestamp.NanoTime$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.CatalystTimestampConverter"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.CatalystTimestampConverter$"), + // SPARK-6777 Implements backwards compatibility rules in CatalystSchemaConverter + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTypeInfo"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTypeInfo$") ) case v if v.startsWith("1.4") => Seq( diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index ef3a175bac209..f5f1c9a1a247a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -27,6 +27,8 @@ import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion import com.typesafe.sbt.pom.{loadEffectivePom, PomBuild, SbtPomKeys} import net.virtualvoid.sbt.graph.Plugin.graphSettings +import spray.revolver.RevolverPlugin._ + object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile @@ -51,6 +53,8 @@ object BuildCommons { // Root project. val spark = ProjectRef(buildLocation, "spark") val sparkHome = buildLocation + + val testTempDir = s"$sparkHome/target/tmp" } object SparkBuild extends PomBuild { @@ -144,7 +148,9 @@ object SparkBuild extends PomBuild { javacOptions in (Compile, doc) ++= { val Array(major, minor, _) = System.getProperty("java.version").split("\\.", 3) if (major.toInt >= 1 && minor.toInt >= 8) Seq("-Xdoclint:all", "-Xdoclint:-missing") else Seq.empty - } + }, + + javacOptions in Compile ++= Seq("-encoding", "UTF-8") ) def enable(settings: Seq[Setting[_]])(projectRef: ProjectRef) = { @@ -155,14 +161,13 @@ object SparkBuild extends PomBuild { // Note ordering of these settings matter. /* Enable shared settings on all projects */ (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ Seq(spark, tools)) - .foreach(enable(sharedSettings ++ ExludedDependencies.settings)) + .foreach(enable(sharedSettings ++ ExludedDependencies.settings ++ Revolver.settings)) /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) - // TODO: remove launcher from this list after 1.4.0 allProjects.filterNot(x => Seq(spark, hive, hiveThriftServer, catalyst, repl, - networkCommon, networkShuffle, networkYarn, launcher, unsafe).contains(x)).foreach { + networkCommon, networkShuffle, networkYarn, unsafe).contains(x)).foreach { x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) } @@ -496,6 +501,7 @@ object TestSettings { "SPARK_DIST_CLASSPATH" -> (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":"), "JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home"))), + javaOptions in Test += s"-Djava.io.tmpdir=$testTempDir", javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.port.maxRetries=100", @@ -508,7 +514,7 @@ object TestSettings { javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") .map { case (k,v) => s"-D$k=$v" }.toSeq, javaOptions in Test += "-ea", - javaOptions in Test ++= "-Xmx3g -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" + javaOptions in Test ++= "-Xmx3g -Xss4096k -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" .split(" ").toSeq, javaOptions += "-Xmx3g", // Show full stack trace and duration in test cases. @@ -518,6 +524,13 @@ object TestSettings { libraryDependencies += "com.novocode" % "junit-interface" % "0.9" % "test", // Only allow one test at a time, even across projects, since they run in the same JVM parallelExecution in Test := false, + // Make sure the test temp directory exists. + resourceGenerators in Test <+= resourceManaged in Test map { outDir: File => + if (!new File(testTempDir).isDirectory()) { + require(new File(testTempDir).mkdirs()) + } + Seq[File]() + }, concurrentRestrictions in Global += Tags.limit(Tags.Test, 1), // Remove certain packages from Scaladoc scalacOptions in (Compile, doc) := Seq( diff --git a/project/plugins.sbt b/project/plugins.sbt index 75bd604a1b857..51820460ca1a0 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -29,6 +29,8 @@ addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.3") addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") +addSbtPlugin("io.spray" % "sbt-revolver" % "0.7.2") + libraryDependencies += "org.ow2.asm" % "asm" % "5.0.3" libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.0.3" diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index adca90ddaf397..6ef8cf53cc747 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -264,4 +264,6 @@ def _start_update_server(): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 3de4615428bb6..663c9abe0881e 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -115,4 +115,6 @@ def __reduce__(self): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 44d90f1437bc9..d7466729b8f36 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -291,6 +291,21 @@ def version(self): """ return self._jsc.version() + @property + @ignore_unicode_prefix + def applicationId(self): + """ + A unique identifier for the Spark application. + Its format depends on the scheduler implementation. + (i.e. + in case of local spark app something like 'local-1433865536131' + in case of YARN something like 'application_1433865536131_34483' + ) + >>> sc.applicationId # doctest: +ELLIPSIS + u'local-...' + """ + return self._jsc.sc().applicationId() + @property def startTime(self): """Return the epoch time when the Spark Context was started.""" @@ -324,6 +339,12 @@ def stop(self): with SparkContext._lock: SparkContext._active_spark_context = None + def emptyRDD(self): + """ + Create an RDD that has no partitions or elements. + """ + return RDD(self._jsc.emptyRDD(), self, NoOpSerializer()) + def range(self, start, end=None, step=1, numSlices=None): """ Create a new RDD of int containing elements from `start` to `end` diff --git a/python/pyspark/heapq3.py b/python/pyspark/heapq3.py index 4ef2afe03544f..b27e91a4cc251 100644 --- a/python/pyspark/heapq3.py +++ b/python/pyspark/heapq3.py @@ -883,6 +883,7 @@ def nlargest(n, iterable, key=None): if __name__ == "__main__": - import doctest - print(doctest.testmod()) + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 3cee4ea6e3a35..90cd342a6cf7f 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -51,6 +51,8 @@ def launch_gateway(): on_windows = platform.system() == "Windows" script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit" submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") + if os.environ.get("SPARK_TESTING"): + submit_args = "--conf spark.ui.enabled=false " + submit_args command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args) # Start a socket that will be used by PythonGatewayServer to communicate its port to us diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index d8ddb78c6d639..595593a7f2cde 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -160,13 +160,15 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): ... >>> evaluator = RegressionEvaluator(predictionCol="raw") >>> evaluator.evaluate(dataset) - 2.842... + -2.842... >>> evaluator.evaluate(dataset, {evaluator.metricName: "r2"}) 0.993... >>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"}) - 2.649... + -2.649... """ - # a placeholder to make it appear in the generated doc + # Because we will maximize evaluation value (ref: `CrossValidator`), + # when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`), + # we take and output the negative of this metric. metricName = Param(Params._dummy(), "metricName", "metric name in evaluation (mse|rmse|r2|mae)") diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index ddb33f427ac64..8804dace849b3 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -21,7 +21,7 @@ from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer from pyspark.mllib.common import inherit_doc -__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'Normalizer', 'OneHotEncoder', +__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec', 'Word2VecModel'] @@ -265,6 +265,75 @@ class IDFModel(JavaModel): """ +@inherit_doc +@ignore_unicode_prefix +class NGram(JavaTransformer, HasInputCol, HasOutputCol): + """ + A feature transformer that converts the input array of strings into an array of n-grams. Null + values in the input array are ignored. + It returns an array of n-grams where each n-gram is represented by a space-separated string of + words. + When the input is empty, an empty array is returned. + When the input array length is less than n (number of elements per n-gram), no n-grams are + returned. + + >>> df = sqlContext.createDataFrame([Row(inputTokens=["a", "b", "c", "d", "e"])]) + >>> ngram = NGram(n=2, inputCol="inputTokens", outputCol="nGrams") + >>> ngram.transform(df).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b', u'b c', u'c d', u'd e']) + >>> # Change n-gram length + >>> ngram.setParams(n=4).transform(df).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e']) + >>> # Temporarily modify output column. + >>> ngram.transform(df, {ngram.outputCol: "output"}).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], output=[u'a b c d', u'b c d e']) + >>> ngram.transform(df).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e']) + >>> # Must use keyword arguments to specify params. + >>> ngram.setParams("text") + Traceback (most recent call last): + ... + TypeError: Method setParams forces keyword arguments. + """ + + # a placeholder to make it appear in the generated doc + n = Param(Params._dummy(), "n", "number of elements per n-gram (>=1)") + + @keyword_only + def __init__(self, n=2, inputCol=None, outputCol=None): + """ + __init__(self, n=2, inputCol=None, outputCol=None) + """ + super(NGram, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.NGram", self.uid) + self.n = Param(self, "n", "number of elements per n-gram (>=1)") + self._setDefault(n=2) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, n=2, inputCol=None, outputCol=None): + """ + setParams(self, n=2, inputCol=None, outputCol=None) + Sets params for this NGram. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setN(self, value): + """ + Sets the value of :py:attr:`n`. + """ + self._paramMap[self.n] = value + return self + + def getN(self): + """ + Gets the value of n or its default value. + """ + return self.getOrDefault(self.n) + + @inherit_doc class Normalizer(JavaTransformer, HasInputCol, HasOutputCol): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6adbf166f34a8..c151d21fd661a 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -252,6 +252,17 @@ def test_idf(self): output = idf0m.transform(dataset) self.assertIsNotNone(output.head().idf) + def test_ngram(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([ + ([["a", "b", "c", "d", "e"]])], ["input"]) + ngram0 = NGram(n=4, inputCol="input", outputCol="output") + self.assertEqual(ngram0.getN(), 4) + self.assertEqual(ngram0.getInputCol(), "input") + self.assertEqual(ngram0.getOutputCol(), "output") + transformedDF = ngram0.transform(dataset) + self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"]) + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index a70c664a71fdb..735d45ba03d27 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -21,6 +21,7 @@ from numpy import array from pyspark import RDD +from pyspark.streaming import DStream from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper @@ -28,13 +29,14 @@ __all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegressionWithLBFGS', - 'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes'] + 'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes', + 'StreamingLogisticRegressionWithSGD'] class LinearClassificationModel(LinearModel): """ - A private abstract class representing a multiclass classification model. - The categories are represented by int values: 0, 1, 2, etc. + A private abstract class representing a multiclass classification + model. The categories are represented by int values: 0, 1, 2, etc. """ def __init__(self, weights, intercept): super(LinearClassificationModel, self).__init__(weights, intercept) @@ -44,10 +46,11 @@ def setThreshold(self, value): """ .. note:: Experimental - Sets the threshold that separates positive predictions from negative - predictions. An example with prediction score greater than or equal - to this threshold is identified as an positive, and negative otherwise. - It is used for binary classification only. + Sets the threshold that separates positive predictions from + negative predictions. An example with prediction score greater + than or equal to this threshold is identified as an positive, + and negative otherwise. It is used for binary classification + only. """ self._threshold = value @@ -56,8 +59,9 @@ def threshold(self): """ .. note:: Experimental - Returns the threshold (if any) used for converting raw prediction scores - into 0/1 predictions. It is used for binary classification only. + Returns the threshold (if any) used for converting raw + prediction scores into 0/1 predictions. It is used for + binary classification only. """ return self._threshold @@ -65,22 +69,35 @@ def clearThreshold(self): """ .. note:: Experimental - Clears the threshold so that `predict` will output raw prediction scores. - It is used for binary classification only. + Clears the threshold so that `predict` will output raw + prediction scores. It is used for binary classification only. """ self._threshold = None def predict(self, test): """ - Predict values for a single data point or an RDD of points using - the model trained. + Predict values for a single data point or an RDD of points + using the model trained. """ raise NotImplementedError class LogisticRegressionModel(LinearClassificationModel): - """A linear binary classification model derived from logistic regression. + """ + Classification model trained using Multinomial/Binary Logistic + Regression. + + :param weights: Weights computed for every feature. + :param intercept: Intercept computed for this model. (Only used + in Binary Logistic Regression. In Multinomial Logistic + Regression, the intercepts will not be a single value, + so the intercepts will be part of the weights.) + :param numFeatures: the dimension of the features. + :param numClasses: the number of possible outcomes for k classes + classification problem in Multinomial Logistic Regression. + By default, it is binary logistic regression so numClasses + will be set to 2. >>> data = [ ... LabeledPoint(0.0, [0.0, 1.0]), @@ -120,8 +137,9 @@ class LogisticRegressionModel(LinearClassificationModel): 1 >>> sameModel.predict(SparseVector(2, {0: 1.0})) 0 + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except: ... pass >>> multi_class_data = [ @@ -161,8 +179,8 @@ def numClasses(self): def predict(self, x): """ - Predict values for a single data point or an RDD of points using - the model trained. + Predict values for a single data point or an RDD of points + using the model trained. """ if isinstance(x, RDD): return x.map(lambda v: self.predict(v)) @@ -225,16 +243,19 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, """ Train a logistic regression model on the given data. - :param data: The training data, an RDD of LabeledPoint. - :param iterations: The number of iterations (default: 100). + :param data: The training data, an RDD of + LabeledPoint. + :param iterations: The number of iterations + (default: 100). :param step: The step parameter used in SGD (default: 1.0). - :param miniBatchFraction: Fraction of data to be used for each SGD - iteration. + :param miniBatchFraction: Fraction of data to be used for each + SGD iteration (default: 1.0). :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter (default: 0.01). - :param regType: The type of regularizer used for training - our model. + :param regParam: The regularizer parameter + (default: 0.01). + :param regType: The type of regularizer used for + training our model. :Allowed values: - "l1" for using L1 regularization @@ -243,13 +264,14 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, (default: "l2") - :param intercept: Boolean parameter which indicates the use - or not of the augmented representation for - training data (i.e. whether bias features - are activated or not). - :param validateData: Boolean parameter which indicates if the - algorithm should validate data before training. - (default: True) + :param intercept: Boolean parameter which indicates the + use or not of the augmented representation + for training data (i.e. whether bias + features are activated or not, + default: False). + :param validateData: Boolean parameter which indicates if + the algorithm should validate data + before training. (default: True) """ def train(rdd, i): return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations), @@ -267,12 +289,15 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType """ Train a logistic regression model on the given data. - :param data: The training data, an RDD of LabeledPoint. - :param iterations: The number of iterations (default: 100). + :param data: The training data, an RDD of + LabeledPoint. + :param iterations: The number of iterations + (default: 100). :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter (default: 0.01). - :param regType: The type of regularizer used for training - our model. + :param regParam: The regularizer parameter + (default: 0.01). + :param regType: The type of regularizer used for + training our model. :Allowed values: - "l1" for using L1 regularization @@ -281,19 +306,21 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType (default: "l2") - :param intercept: Boolean parameter which indicates the use - or not of the augmented representation for - training data (i.e. whether bias features - are activated or not). - :param corrections: The number of corrections used in the LBFGS - update (default: 10). - :param tolerance: The convergence tolerance of iterations for - L-BFGS (default: 1e-4). + :param intercept: Boolean parameter which indicates the + use or not of the augmented representation + for training data (i.e. whether bias + features are activated or not, + default: False). + :param corrections: The number of corrections used in the + LBFGS update (default: 10). + :param tolerance: The convergence tolerance of iterations + for L-BFGS (default: 1e-4). :param validateData: Boolean parameter which indicates if the - algorithm should validate data before training. - (default: True) - :param numClasses: The number of classes (i.e., outcomes) a label can take - in Multinomial Logistic Regression (default: 2). + algorithm should validate data before + training. (default: True) + :param numClasses: The number of classes (i.e., outcomes) a + label can take in Multinomial Logistic + Regression (default: 2). >>> data = [ ... LabeledPoint(0.0, [0.0, 1.0]), @@ -323,7 +350,11 @@ def train(rdd, i): class SVMModel(LinearClassificationModel): - """A support vector machine. + """ + Model for Support Vector Machines (SVMs). + + :param weights: Weights computed for every feature. + :param intercept: Intercept computed for this model. >>> data = [ ... LabeledPoint(0.0, [0.0]), @@ -359,8 +390,9 @@ class SVMModel(LinearClassificationModel): 1 >>> sameModel.predict(SparseVector(2, {0: -1.0})) 0 + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except: ... pass """ @@ -370,8 +402,8 @@ def __init__(self, weights, intercept): def predict(self, x): """ - Predict values for a single data point or an RDD of points using - the model trained. + Predict values for a single data point or an RDD of points + using the model trained. """ if isinstance(x, RDD): return x.map(lambda v: self.predict(v)) @@ -409,16 +441,19 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, """ Train a support vector machine on the given data. - :param data: The training data, an RDD of LabeledPoint. - :param iterations: The number of iterations (default: 100). + :param data: The training data, an RDD of + LabeledPoint. + :param iterations: The number of iterations + (default: 100). :param step: The step parameter used in SGD (default: 1.0). - :param regParam: The regularizer parameter (default: 0.01). - :param miniBatchFraction: Fraction of data to be used for each SGD - iteration. + :param regParam: The regularizer parameter + (default: 0.01). + :param miniBatchFraction: Fraction of data to be used for each + SGD iteration (default: 1.0). :param initialWeights: The initial weights (default: None). - :param regType: The type of regularizer used for training - our model. + :param regType: The type of regularizer used for + training our model. :Allowed values: - "l1" for using L1 regularization @@ -427,13 +462,14 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, (default: "l2") - :param intercept: Boolean parameter which indicates the use - or not of the augmented representation for - training data (i.e. whether bias features - are activated or not). - :param validateData: Boolean parameter which indicates if the - algorithm should validate data before training. - (default: True) + :param intercept: Boolean parameter which indicates the + use or not of the augmented representation + for training data (i.e. whether bias + features are activated or not, + default: False). + :param validateData: Boolean parameter which indicates if + the algorithm should validate data + before training. (default: True) """ def train(rdd, i): return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step), @@ -449,9 +485,11 @@ class NaiveBayesModel(Saveable, Loader): """ Model for Naive Bayes classifiers. - Contains two parameters: - - pi: vector of logs of class priors (dimension C) - - theta: matrix of logs of class conditional probabilities (CxD) + :param labels: list of labels. + :param pi: log of class priors, whose dimension is C, + number of labels. + :param theta: log of class conditional probabilities, whose + dimension is C-by-D, where D is number of features. >>> data = [ ... LabeledPoint(0.0, [0.0, 0.0]), @@ -481,8 +519,9 @@ class NaiveBayesModel(Saveable, Loader): >>> sameModel = NaiveBayesModel.load(sc, path) >>> sameModel.predict(SparseVector(2, {0: 1.0})) == model.predict(SparseVector(2, {0: 1.0})) True + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except OSError: ... pass """ @@ -493,7 +532,10 @@ def __init__(self, labels, pi, theta): self.theta = theta def predict(self, x): - """Return the most likely class for a data vector or an RDD of vectors""" + """ + Return the most likely class for a data vector + or an RDD of vectors + """ if isinstance(x, RDD): return x.map(lambda v: self.predict(v)) x = _convert_to_vector(x) @@ -523,24 +565,118 @@ class NaiveBayes(object): @classmethod def train(cls, data, lambda_=1.0): """ - Train a Naive Bayes model given an RDD of (label, features) vectors. + Train a Naive Bayes model given an RDD of (label, features) + vectors. - This is the Multinomial NB (U{http://tinyurl.com/lsdw6p}) which can - handle all kinds of discrete data. For example, by converting - documents into TF-IDF vectors, it can be used for document - classification. By making every vector a 0-1 vector, it can also be - used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}). + This is the Multinomial NB (U{http://tinyurl.com/lsdw6p}) which + can handle all kinds of discrete data. For example, by + converting documents into TF-IDF vectors, it can be used for + document classification. By making every vector a 0-1 vector, + it can also be used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}). + The input feature values must be nonnegative. :param data: RDD of LabeledPoint. - :param lambda_: The smoothing parameter + :param lambda_: The smoothing parameter (default: 1.0). """ first = data.first() if not isinstance(first, LabeledPoint): raise ValueError("`data` should be an RDD of LabeledPoint") - labels, pi, theta = callMLlibFunc("trainNaiveBayes", data, lambda_) + labels, pi, theta = callMLlibFunc("trainNaiveBayesModel", data, lambda_) return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta)) +class StreamingLinearAlgorithm(object): + """ + Base class that has to be inherited by any StreamingLinearAlgorithm. + + Prevents reimplementation of methods predictOn and predictOnValues. + """ + def __init__(self, model): + self._model = model + + def latestModel(self): + """ + Returns the latest model. + """ + return self._model + + def _validate(self, dstream): + if not isinstance(dstream, DStream): + raise TypeError( + "dstream should be a DStream object, got %s" % type(dstream)) + if not self._model: + raise ValueError( + "Model must be intialized using setInitialWeights") + + def predictOn(self, dstream): + """ + Make predictions on a dstream. + + :return: Transformed dstream object. + """ + self._validate(dstream) + return dstream.map(lambda x: self._model.predict(x)) + + def predictOnValues(self, dstream): + """ + Make predictions on a keyed dstream. + + :return: Transformed dstream object. + """ + self._validate(dstream) + return dstream.mapValues(lambda x: self._model.predict(x)) + + +@inherit_doc +class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm): + """ + Run LogisticRegression with SGD on a stream of data. + + The weights obtained at the end of training a stream are used as initial + weights for the next stream. + + :param stepSize: Step size for each iteration of gradient descent. + :param numIterations: Number of iterations run for each batch of data. + :param miniBatchFraction: Fraction of data on which SGD is run for each + iteration. + :param regParam: L2 Regularization parameter. + """ + def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.01): + self.stepSize = stepSize + self.numIterations = numIterations + self.regParam = regParam + self.miniBatchFraction = miniBatchFraction + self._model = None + super(StreamingLogisticRegressionWithSGD, self).__init__( + model=self._model) + + def setInitialWeights(self, initialWeights): + """ + Set the initial value of weights. + + This must be set before running trainOn and predictOn. + """ + initialWeights = _convert_to_vector(initialWeights) + + # LogisticRegressionWithSGD does only binary classification. + self._model = LogisticRegressionModel( + initialWeights, 0, initialWeights.size, 2) + return self + + def trainOn(self, dstream): + """Train the model on the incoming dstream.""" + self._validate(dstream) + + def update(rdd): + # LogisticRegressionWithSGD.train raises an error for an empty RDD. + if not rdd.isEmpty(): + self._model = LogisticRegressionWithSGD.train( + rdd, self.numIterations, self.stepSize, + self.miniBatchFraction, self._model.weights) + + dstream.foreachRDD(update) + + def _test(): import doctest from pyspark import SparkContext diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index b55583f82223f..e3c8a24c4a751 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -21,16 +21,23 @@ if sys.version > '3': xrange = range -from numpy import array +from math import exp, log + +from numpy import array, random, tile + +from collections import namedtuple -from pyspark import RDD from pyspark import SparkContext -from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _py2java, _java2py -from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark.rdd import RDD, ignore_unicode_prefix +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, callJavaFunc, _py2java, _java2py +from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector from pyspark.mllib.stat.distribution import MultivariateGaussian -from pyspark.mllib.util import Saveable, Loader, inherit_doc +from pyspark.mllib.util import Saveable, Loader, inherit_doc, JavaLoader, JavaSaveable +from pyspark.streaming import DStream -__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture'] +__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture', + 'PowerIterationClusteringModel', 'PowerIterationClustering', + 'StreamingKMeans', 'StreamingKMeansModel'] @inherit_doc @@ -75,8 +82,9 @@ class KMeansModel(Saveable, Loader): >>> sameModel = KMeansModel.load(sc, path) >>> sameModel.predict(sparse_data[0]) == model.predict(sparse_data[0]) True + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except OSError: ... pass """ @@ -98,6 +106,9 @@ def predict(self, x): """Find the cluster to which x belongs in this model.""" best = 0 best_distance = float("inf") + if isinstance(x, RDD): + return x.map(self.predict) + x = _convert_to_vector(x) for i in xrange(len(self.centers)): distance = x.squared_distance(self.centers[i]) @@ -257,16 +268,297 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia initialModelWeights = initialModel.weights initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)] initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)] - weight, mu, sigma = callMLlibFunc("trainGaussianMixture", rdd.map(_convert_to_vector), k, - convergenceTol, maxIterations, seed, initialModelWeights, - initialModelMu, initialModelSigma) + weight, mu, sigma = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector), + k, convergenceTol, maxIterations, seed, + initialModelWeights, initialModelMu, initialModelSigma) mvg_obj = [MultivariateGaussian(mu[i], sigma[i]) for i in range(k)] return GaussianMixtureModel(weight, mvg_obj) +class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader): + + """ + .. note:: Experimental + + Model produced by [[PowerIterationClustering]]. + + >>> data = [(0, 1, 1.0), (0, 2, 1.0), (1, 3, 1.0), (2, 3, 1.0), + ... (0, 3, 1.0), (1, 2, 1.0), (0, 4, 0.1)] + >>> rdd = sc.parallelize(data, 2) + >>> model = PowerIterationClustering.train(rdd, 2, 100) + >>> model.k + 2 + >>> sorted(model.assignments().collect()) + [Assignment(id=0, cluster=1), Assignment(id=1, cluster=0), ... + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = PowerIterationClusteringModel.load(sc, path) + >>> sameModel.k + 2 + >>> sorted(sameModel.assignments().collect()) + [Assignment(id=0, cluster=1), Assignment(id=1, cluster=0), ... + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass + """ + + @property + def k(self): + """ + Returns the number of clusters. + """ + return self.call("k") + + def assignments(self): + """ + Returns the cluster assignments of this model. + """ + return self.call("getAssignments").map( + lambda x: (PowerIterationClustering.Assignment(*x))) + + @classmethod + def load(cls, sc, path): + model = cls._load_java(sc, path) + wrapper = sc._jvm.PowerIterationClusteringModelWrapper(model) + return PowerIterationClusteringModel(wrapper) + + +class PowerIterationClustering(object): + """ + .. note:: Experimental + + Power Iteration Clustering (PIC), a scalable graph clustering algorithm + developed by [[http://www.icml2010.org/papers/387.pdf Lin and Cohen]]. + From the abstract: PIC finds a very low-dimensional embedding of a + dataset using truncated power iteration on a normalized pair-wise + similarity matrix of the data. + """ + + @classmethod + def train(cls, rdd, k, maxIterations=100, initMode="random"): + """ + :param rdd: an RDD of (i, j, s,,ij,,) tuples representing the + affinity matrix, which is the matrix A in the PIC paper. + The similarity s,,ij,, must be nonnegative. + This is a symmetric matrix and hence s,,ij,, = s,,ji,,. + For any (i, j) with nonzero similarity, there should be + either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. + Tuples with i = j are ignored, because we assume + s,,ij,, = 0.0. + :param k: Number of clusters. + :param maxIterations: Maximum number of iterations of the + PIC algorithm. + :param initMode: Initialization mode. + """ + model = callMLlibFunc("trainPowerIterationClusteringModel", + rdd.map(_convert_to_vector), int(k), int(maxIterations), initMode) + return PowerIterationClusteringModel(model) + + class Assignment(namedtuple("Assignment", ["id", "cluster"])): + """ + Represents an (id, cluster) tuple. + """ + + +class StreamingKMeansModel(KMeansModel): + """ + .. note:: Experimental + + Clustering model which can perform an online update of the centroids. + + The update formula for each centroid is given by + + * c_t+1 = ((c_t * n_t * a) + (x_t * m_t)) / (n_t + m_t) + * n_t+1 = n_t * a + m_t + + where + + * c_t: Centroid at the n_th iteration. + * n_t: Number of samples (or) weights associated with the centroid + at the n_th iteration. + * x_t: Centroid of the new data closest to c_t. + * m_t: Number of samples (or) weights of the new data closest to c_t + * c_t+1: New centroid. + * n_t+1: New number of weights. + * a: Decay Factor, which gives the forgetfulness. + + Note that if a is set to 1, it is the weighted mean of the previous + and new data. If it set to zero, the old centroids are completely + forgotten. + + :param clusterCenters: Initial cluster centers. + :param clusterWeights: List of weights assigned to each cluster. + + >>> initCenters = [[0.0, 0.0], [1.0, 1.0]] + >>> initWeights = [1.0, 1.0] + >>> stkm = StreamingKMeansModel(initCenters, initWeights) + >>> data = sc.parallelize([[-0.1, -0.1], [0.1, 0.1], + ... [0.9, 0.9], [1.1, 1.1]]) + >>> stkm = stkm.update(data, 1.0, u"batches") + >>> stkm.centers + array([[ 0., 0.], + [ 1., 1.]]) + >>> stkm.predict([-0.1, -0.1]) + 0 + >>> stkm.predict([0.9, 0.9]) + 1 + >>> stkm.clusterWeights + [3.0, 3.0] + >>> decayFactor = 0.0 + >>> data = sc.parallelize([DenseVector([1.5, 1.5]), DenseVector([0.2, 0.2])]) + >>> stkm = stkm.update(data, 0.0, u"batches") + >>> stkm.centers + array([[ 0.2, 0.2], + [ 1.5, 1.5]]) + >>> stkm.clusterWeights + [1.0, 1.0] + >>> stkm.predict([0.2, 0.2]) + 0 + >>> stkm.predict([1.5, 1.5]) + 1 + """ + def __init__(self, clusterCenters, clusterWeights): + super(StreamingKMeansModel, self).__init__(centers=clusterCenters) + self._clusterWeights = list(clusterWeights) + + @property + def clusterWeights(self): + """Return the cluster weights.""" + return self._clusterWeights + + @ignore_unicode_prefix + def update(self, data, decayFactor, timeUnit): + """Update the centroids, according to data + + :param data: Should be a RDD that represents the new data. + :param decayFactor: forgetfulness of the previous centroids. + :param timeUnit: Can be "batches" or "points". If points, then the + decay factor is raised to the power of number of new + points and if batches, it is used as it is. + """ + if not isinstance(data, RDD): + raise TypeError("Data should be of an RDD, got %s." % type(data)) + data = data.map(_convert_to_vector) + decayFactor = float(decayFactor) + if timeUnit not in ["batches", "points"]: + raise ValueError( + "timeUnit should be 'batches' or 'points', got %s." % timeUnit) + vectorCenters = [_convert_to_vector(center) for center in self.centers] + updatedModel = callMLlibFunc( + "updateStreamingKMeansModel", vectorCenters, self._clusterWeights, + data, decayFactor, timeUnit) + self.centers = array(updatedModel[0]) + self._clusterWeights = list(updatedModel[1]) + return self + + +class StreamingKMeans(object): + """ + .. note:: Experimental + + Provides methods to set k, decayFactor, timeUnit to configure the + KMeans algorithm for fitting and predicting on incoming dstreams. + More details on how the centroids are updated are provided under the + docs of StreamingKMeansModel. + + :param k: int, number of clusters + :param decayFactor: float, forgetfulness of the previous centroids. + :param timeUnit: can be "batches" or "points". If points, then the + decayfactor is raised to the power of no. of new points. + """ + def __init__(self, k=2, decayFactor=1.0, timeUnit="batches"): + self._k = k + self._decayFactor = decayFactor + if timeUnit not in ["batches", "points"]: + raise ValueError( + "timeUnit should be 'batches' or 'points', got %s." % timeUnit) + self._timeUnit = timeUnit + self._model = None + + def latestModel(self): + """Return the latest model""" + return self._model + + def _validate(self, dstream): + if self._model is None: + raise ValueError( + "Initial centers should be set either by setInitialCenters " + "or setRandomCenters.") + if not isinstance(dstream, DStream): + raise TypeError( + "Expected dstream to be of type DStream, " + "got type %s" % type(dstream)) + + def setK(self, k): + """Set number of clusters.""" + self._k = k + return self + + def setDecayFactor(self, decayFactor): + """Set decay factor.""" + self._decayFactor = decayFactor + return self + + def setHalfLife(self, halfLife, timeUnit): + """ + Set number of batches after which the centroids of that + particular batch has half the weightage. + """ + self._timeUnit = timeUnit + self._decayFactor = exp(log(0.5) / halfLife) + return self + + def setInitialCenters(self, centers, weights): + """ + Set initial centers. Should be set before calling trainOn. + """ + self._model = StreamingKMeansModel(centers, weights) + return self + + def setRandomCenters(self, dim, weight, seed): + """ + Set the initial centres to be random samples from + a gaussian population with constant weights. + """ + rng = random.RandomState(seed) + clusterCenters = rng.randn(self._k, dim) + clusterWeights = tile(weight, self._k) + self._model = StreamingKMeansModel(clusterCenters, clusterWeights) + return self + + def trainOn(self, dstream): + """Train the model on the incoming dstream.""" + self._validate(dstream) + + def update(rdd): + self._model.update(rdd, self._decayFactor, self._timeUnit) + + dstream.foreachRDD(update) + + def predictOn(self, dstream): + """ + Make predictions on a dstream. + Returns a transformed dstream object + """ + self._validate(dstream) + return dstream.map(lambda x: self._model.predict(x)) + + def predictOnValues(self, dstream): + """ + Make predictions on a keyed dstream. + Returns a transformed dstream object. + """ + self._validate(dstream) + return dstream.mapValues(lambda x: self._model.predict(x)) + + def _test(): import doctest - globs = globals().copy() + import pyspark.mllib.clustering + globs = pyspark.mllib.clustering.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index da90554f41437..b5138773fd61b 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -33,12 +33,13 @@ from pyspark import SparkContext from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector, _convert_to_vector +from pyspark.mllib.linalg import ( + Vector, Vectors, DenseVector, SparseVector, _convert_to_vector) from pyspark.mllib.regression import LabeledPoint __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler', 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel', - 'ChiSqSelector', 'ChiSqSelectorModel'] + 'ChiSqSelector', 'ChiSqSelectorModel', 'ElementwiseProduct'] class VectorTransformer(object): @@ -110,6 +111,15 @@ class JavaVectorTransformer(JavaModelWrapper, VectorTransformer): """ def transform(self, vector): + """ + Applies transformation on a vector or an RDD[Vector]. + + Note: In Python, transform cannot currently be used within + an RDD transformation or action. + Call transform directly on the RDD instead. + + :param vector: Vector or RDD of Vector to be transformed. + """ if isinstance(vector, RDD): vector = vector.map(_convert_to_vector) else: @@ -190,7 +200,7 @@ def fit(self, dataset): Computes the mean and variance and stores as a model to be used for later scaling. - :param data: The data used to compute the mean and variance + :param dataset: The data used to compute the mean and variance to build the transformation model. :return: a StandardScalarModel """ @@ -251,6 +261,41 @@ def fit(self, data): return ChiSqSelectorModel(jmodel) +class PCAModel(JavaVectorTransformer): + """ + Model fitted by [[PCA]] that can project vectors to a low-dimensional space using PCA. + """ + + +class PCA(object): + """ + A feature transformer that projects vectors to a low-dimensional space using PCA. + + >>> data = [Vectors.sparse(5, [(1, 1.0), (3, 7.0)]), + ... Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]), + ... Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0])] + >>> model = PCA(2).fit(sc.parallelize(data)) + >>> pcArray = model.transform(Vectors.sparse(5, [(1, 1.0), (3, 7.0)])).toArray() + >>> pcArray[0] + 1.648... + >>> pcArray[1] + -4.013... + """ + def __init__(self, k): + """ + :param k: number of principal components. + """ + self.k = int(k) + + def fit(self, data): + """ + Computes a [[PCAModel]] that contains the principal components of the input vectors. + :param data: source vectors + """ + jmodel = callMLlibFunc("fitPCA", self.k, data) + return PCAModel(jmodel) + + class HashingTF(object): """ .. note:: Experimental @@ -310,10 +355,6 @@ def transform(self, x): vector :return: an RDD of TF-IDF vectors or a TF-IDF vector """ - if isinstance(x, RDD): - return JavaVectorTransformer.transform(self, x) - - x = _convert_to_vector(x) return JavaVectorTransformer.transform(self, x) def idf(self): @@ -513,13 +554,45 @@ def fit(self, data): """ if not isinstance(data, RDD): raise TypeError("data should be an RDD of list of string") - jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize), + jmodel = callMLlibFunc("trainWord2VecModel", data, int(self.vectorSize), float(self.learningRate), int(self.numPartitions), int(self.numIterations), int(self.seed), int(self.minCount)) return Word2VecModel(jmodel) +class ElementwiseProduct(VectorTransformer): + """ + .. note:: Experimental + + Scales each column of the vector, with the supplied weight vector. + i.e the elementwise product. + + >>> weight = Vectors.dense([1.0, 2.0, 3.0]) + >>> eprod = ElementwiseProduct(weight) + >>> a = Vectors.dense([2.0, 1.0, 3.0]) + >>> eprod.transform(a) + DenseVector([2.0, 2.0, 9.0]) + >>> b = Vectors.dense([9.0, 3.0, 4.0]) + >>> rdd = sc.parallelize([a, b]) + >>> eprod.transform(rdd).collect() + [DenseVector([2.0, 2.0, 9.0]), DenseVector([9.0, 6.0, 12.0])] + """ + def __init__(self, scalingVector): + self.scalingVector = _convert_to_vector(scalingVector) + + def transform(self, vector): + """ + Computes the Hadamard product of the vector. + """ + if isinstance(vector, RDD): + vector = vector.map(_convert_to_vector) + + else: + vector = _convert_to_vector(vector) + return callMLlibFunc("elementwiseProductVector", self.scalingVector, vector) + + def _test(): import doctest from pyspark import SparkContext diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index bdc4a132b1b18..b7f00d60069e6 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -39,8 +39,8 @@ class FPGrowthModel(JavaModelWrapper): >>> data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] >>> rdd = sc.parallelize(data, 2) >>> model = FPGrowth.train(rdd, 0.6, 2) - >>> sorted(model.freqItemsets().collect()) - [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ... + >>> sorted(model.freqItemsets().collect(), key=lambda x: x.items) + [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'a', u'c'], freq=3), ... """ def freqItemsets(self): diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 23d1a79ffe511..e96c5ef87df86 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -36,7 +36,7 @@ import numpy as np from pyspark.sql.types import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ - IntegerType, ByteType + IntegerType, ByteType, BooleanType __all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', @@ -163,6 +163,59 @@ def simpleString(self): return "vector" +class MatrixUDT(UserDefinedType): + """ + SQL user-defined type (UDT) for Matrix. + """ + + @classmethod + def sqlType(cls): + return StructType([ + StructField("type", ByteType(), False), + StructField("numRows", IntegerType(), False), + StructField("numCols", IntegerType(), False), + StructField("colPtrs", ArrayType(IntegerType(), False), True), + StructField("rowIndices", ArrayType(IntegerType(), False), True), + StructField("values", ArrayType(DoubleType(), False), True), + StructField("isTransposed", BooleanType(), False)]) + + @classmethod + def module(cls): + return "pyspark.mllib.linalg" + + @classmethod + def scalaUDT(cls): + return "org.apache.spark.mllib.linalg.MatrixUDT" + + def serialize(self, obj): + if isinstance(obj, SparseMatrix): + colPtrs = [int(i) for i in obj.colPtrs] + rowIndices = [int(i) for i in obj.rowIndices] + values = [float(v) for v in obj.values] + return (0, obj.numRows, obj.numCols, colPtrs, + rowIndices, values, bool(obj.isTransposed)) + elif isinstance(obj, DenseMatrix): + values = [float(v) for v in obj.values] + return (1, obj.numRows, obj.numCols, None, None, values, + bool(obj.isTransposed)) + else: + raise TypeError("cannot serialize type %r" % (type(obj))) + + def deserialize(self, datum): + assert len(datum) == 7, \ + "MatrixUDT.deserialize given row with length %d but requires 7" % len(datum) + tpe = datum[0] + if tpe == 0: + return SparseMatrix(*datum[1:]) + elif tpe == 1: + return DenseMatrix(datum[1], datum[2], datum[5], datum[6]) + else: + raise ValueError("do not recognize type %r" % tpe) + + def simpleString(self): + return "matrix" + + class Vector(object): __UDT__ = VectorUDT() @@ -781,10 +834,12 @@ def zeros(size): class Matrix(object): + + __UDT__ = MatrixUDT() + """ Represents a local matrix. """ - def __init__(self, numRows, numCols, isTransposed=False): self.numRows = numRows self.numCols = numCols diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 9c4647ddfdcfd..506ca2151cce7 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -106,8 +106,9 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): 0.4... >>> sameModel.predictAll(testset).collect() [Rating(... + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except OSError: ... pass """ diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 41bde2ce3e60b..5ddbbee4babdd 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -33,12 +33,12 @@ class LabeledPoint(object): """ - The features and labels of a data point. + Class that represents the features and labels of a data point. :param label: Label for this data point. :param features: Vector of features for this point (NumPy array, - list, pyspark.mllib.linalg.SparseVector, or scipy.sparse - column matrix) + list, pyspark.mllib.linalg.SparseVector, or scipy.sparse + column matrix) Note: 'label' and 'features' are accessible as class attributes. """ @@ -59,7 +59,12 @@ def __repr__(self): class LinearModel(object): - """A linear model that has a vector of coefficients and an intercept.""" + """ + A linear model that has a vector of coefficients and an intercept. + + :param weights: Weights computed for every feature. + :param intercept: Intercept computed for this model. + """ def __init__(self, weights, intercept): self._coeff = _convert_to_vector(weights) @@ -128,10 +133,11 @@ class LinearRegressionModel(LinearRegressionModelBase): True >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except: - ... pass + ... pass >>> data = [ ... LabeledPoint(0.0, SparseVector(1, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(1, {0: 1.0})), @@ -193,18 +199,28 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None, regParam=0.0, regType=None, intercept=False, validateData=True): """ - Train a linear regression model on the given data. - - :param data: The training data. - :param iterations: The number of iterations (default: 100). + Train a linear regression model using Stochastic Gradient + Descent (SGD). + This solves the least squares regression formulation + f(weights) = 1/n ||A weights-y||^2^ + (which is the mean squared error). + Here the data matrix has n rows, and the input RDD holds the + set of rows of A, each with its corresponding right hand side + label y. See also the documentation for the precise formulation. + + :param data: The training data, an RDD of + LabeledPoint. + :param iterations: The number of iterations + (default: 100). :param step: The step parameter used in SGD (default: 1.0). - :param miniBatchFraction: Fraction of data to be used for each SGD - iteration. + :param miniBatchFraction: Fraction of data to be used for each + SGD iteration (default: 1.0). :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter (default: 0.0). - :param regType: The type of regularizer used for training - our model. + :param regParam: The regularizer parameter + (default: 0.0). + :param regType: The type of regularizer used for + training our model. :Allowed values: - "l1" for using L1 regularization (lasso), @@ -213,13 +229,14 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, (default: None) - :param intercept: Boolean parameter which indicates the use - or not of the augmented representation for - training data (i.e. whether bias features - are activated or not). (default: False) - :param validateData: Boolean parameter which indicates if the - algorithm should validate data before training. - (default: True) + :param intercept: Boolean parameter which indicates the + use or not of the augmented representation + for training data (i.e. whether bias + features are activated or not, + default: False). + :param validateData: Boolean parameter which indicates if + the algorithm should validate data + before training. (default: True) """ def train(rdd, i): return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), @@ -232,8 +249,8 @@ def train(rdd, i): @inherit_doc class LassoModel(LinearRegressionModelBase): - """A linear regression model derived from a least-squares fit with an - l_1 penalty term. + """A linear regression model derived from a least-squares fit with + an l_1 penalty term. >>> from pyspark.mllib.regression import LabeledPoint >>> data = [ @@ -259,8 +276,9 @@ class LassoModel(LinearRegressionModelBase): True >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except: ... pass >>> data = [ @@ -304,7 +322,36 @@ class LassoWithSGD(object): def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, validateData=True): - """Train a Lasso regression model on the given data.""" + """ + Train a regression model with L1-regularization using + Stochastic Gradient Descent. + This solves the l1-regularized least squares regression + formulation + f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1 + Here the data matrix has n rows, and the input RDD holds the + set of rows of A, each with its corresponding right hand side + label y. See also the documentation for the precise formulation. + + :param data: The training data, an RDD of + LabeledPoint. + :param iterations: The number of iterations + (default: 100). + :param step: The step parameter used in SGD + (default: 1.0). + :param regParam: The regularizer parameter + (default: 0.01). + :param miniBatchFraction: Fraction of data to be used for each + SGD iteration (default: 1.0). + :param initialWeights: The initial weights (default: None). + :param intercept: Boolean parameter which indicates the + use or not of the augmented representation + for training data (i.e. whether bias + features are activated or not, + default: False). + :param validateData: Boolean parameter which indicates if + the algorithm should validate data + before training. (default: True) + """ def train(rdd, i): return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), float(regParam), float(miniBatchFraction), i, bool(intercept), @@ -316,8 +363,8 @@ def train(rdd, i): @inherit_doc class RidgeRegressionModel(LinearRegressionModelBase): - """A linear regression model derived from a least-squares fit with an - l_2 penalty term. + """A linear regression model derived from a least-squares fit with + an l_2 penalty term. >>> from pyspark.mllib.regression import LabeledPoint >>> data = [ @@ -344,8 +391,9 @@ class RidgeRegressionModel(LinearRegressionModelBase): True >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except: ... pass >>> data = [ @@ -389,7 +437,36 @@ class RidgeRegressionWithSGD(object): def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, validateData=True): - """Train a ridge regression model on the given data.""" + """ + Train a regression model with L2-regularization using + Stochastic Gradient Descent. + This solves the l2-regularized least squares regression + formulation + f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^ + Here the data matrix has n rows, and the input RDD holds the + set of rows of A, each with its corresponding right hand side + label y. See also the documentation for the precise formulation. + + :param data: The training data, an RDD of + LabeledPoint. + :param iterations: The number of iterations + (default: 100). + :param step: The step parameter used in SGD + (default: 1.0). + :param regParam: The regularizer parameter + (default: 0.01). + :param miniBatchFraction: Fraction of data to be used for each + SGD iteration (default: 1.0). + :param initialWeights: The initial weights (default: None). + :param intercept: Boolean parameter which indicates the + use or not of the augmented representation + for training data (i.e. whether bias + features are activated or not, + default: False). + :param validateData: Boolean parameter which indicates if + the algorithm should validate data + before training. (default: True) + """ def train(rdd, i): return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), float(regParam), float(miniBatchFraction), i, bool(intercept), @@ -400,7 +477,15 @@ def train(rdd, i): class IsotonicRegressionModel(Saveable, Loader): - """Regression model for isotonic regression. + """ + Regression model for isotonic regression. + + :param boundaries: Array of boundaries for which predictions are + known. Boundaries must be sorted in increasing order. + :param predictions: Array of predictions associated to the + boundaries at the same index. Results of isotonic + regression and therefore monotone. + :param isotonic: indicates whether this is isotonic or antitonic. >>> data = [(1, 0, 1), (2, 1, 1), (3, 2, 1), (1, 3, 1), (6, 4, 1), (17, 5, 1), (16, 6, 1)] >>> irm = IsotonicRegression.train(sc.parallelize(data)) @@ -418,8 +503,9 @@ class IsotonicRegressionModel(Saveable, Loader): 2.0 >>> sameModel.predict(5) 16.5 + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except OSError: ... pass """ @@ -430,6 +516,25 @@ def __init__(self, boundaries, predictions, isotonic): self.isotonic = isotonic def predict(self, x): + """ + Predict labels for provided features. + Using a piecewise linear function. + 1) If x exactly matches a boundary then associated prediction + is returned. In case there are multiple predictions with the + same boundary then one of them is returned. Which one is + undefined (same as java.util.Arrays.binarySearch). + 2) If x is lower or higher than all boundaries then first or + last prediction is returned respectively. In case there are + multiple predictions with the same boundary then the lowest + or highest is returned respectively. + 3) If x falls between two values in boundary array then + prediction is treated as piecewise linear function and + interpolated value is returned. In case there are multiple + values with the same boundary then the same rules as in 2) + are used. + + :param x: Feature or RDD of Features to be labeled. + """ if isinstance(x, RDD): return x.map(lambda v: self.predict(v)) return np.interp(x, self.boundaries, self.predictions) @@ -451,15 +556,15 @@ def load(cls, sc, path): class IsotonicRegression(object): - """ - Run IsotonicRegression algorithm to obtain isotonic regression model. - :param data: RDD of (label, feature, weight) tuples. - :param isotonic: Whether this is isotonic or antitonic. - """ @classmethod def train(cls, data, isotonic=True): - """Train a isotonic regression model on the given data.""" + """ + Train a isotonic regression model on the given data. + + :param data: RDD of (label, feature, weight) tuples. + :param isotonic: Whether this is isotonic or antitonic. + """ boundaries, predictions = callMLlibFunc("trainIsotonicRegressionModel", data.map(_convert_to_vector), bool(isotonic)) return IsotonicRegressionModel(boundaries.toArray(), predictions.toArray(), isotonic) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 36a4c7a5408c6..cd80c3e07a4f7 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -23,8 +23,12 @@ import sys import tempfile import array as pyarray +from time import time, sleep +from shutil import rmtree -from numpy import array, array_equal, zeros, inf +from numpy import ( + array, array_equal, zeros, inf, random, exp, dot, all, mean) +from numpy import sum as array_sum from py4j.protocol import Py4JJavaError if sys.version_info[:2] <= (2, 6): @@ -38,15 +42,19 @@ from pyspark import SparkContext from pyspark.mllib.common import _to_java_object_rdd +from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ - DenseMatrix, SparseMatrix, Vectors, Matrices + DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics from pyspark.mllib.feature import Word2Vec from pyspark.mllib.feature import IDF -from pyspark.mllib.feature import StandardScaler +from pyspark.mllib.feature import StandardScaler, ElementwiseProduct +from pyspark.mllib.util import LinearDataGenerator from pyspark.serializers import PickleSerializer +from pyspark.streaming import StreamingContext from pyspark.sql import SQLContext _have_scipy = False @@ -66,6 +74,20 @@ def setUp(self): self.sc = sc +class MLLibStreamingTestCase(unittest.TestCase): + def setUp(self): + self.sc = sc + self.ssc = StreamingContext(self.sc, 1.0) + + def tearDown(self): + self.ssc.stop(False) + + @staticmethod + def _ssc_wait(start_time, end_time, sleep_time): + while time() - start_time < end_time: + sleep(0.01) + + def _squared_distance(a, b): if isinstance(a, Vector): return a.squared_distance(b) @@ -379,7 +401,7 @@ def test_classification(self): self.assertEqual(same_gbt_model.toDebugString(), gbt_model.toDebugString()) try: - os.removedirs(temp_dir) + rmtree(temp_dir) except OSError: pass @@ -443,6 +465,13 @@ def test_regression(self): except ValueError: self.fail() + # Verify that maxBins is being passed through + GradientBoostedTrees.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=32) + with self.assertRaises(Exception) as cm: + GradientBoostedTrees.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=1) + class StatTests(MLlibTestCase): # SPARK-4023 @@ -507,6 +536,38 @@ def test_infer_schema(self): raise TypeError("expecting a vector but got %r of type %r" % (v, type(v))) +class MatrixUDTTests(MLlibTestCase): + + dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10]) + dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True) + sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0]) + sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True) + udt = MatrixUDT() + + def test_json_schema(self): + self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for m in [self.dm1, self.dm2, self.sm1, self.sm2]: + self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m))) + + def test_infer_schema(self): + sqlCtx = SQLContext(self.sc) + rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)]) + df = rdd.toDF() + schema = df.schema + self.assertTrue(schema.fields[1].dataType, self.udt) + matrices = df.map(lambda x: x._2).collect() + self.assertEqual(len(matrices), 2) + for m in matrices: + if isinstance(m, DenseMatrix): + self.assertTrue(m, self.dm1) + elif isinstance(m, SparseMatrix): + self.assertTrue(m, self.sm1) + else: + raise ValueError("Expected a matrix but got type %r" % type(m)) + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(MLlibTestCase): @@ -818,6 +879,297 @@ def test_model_transform(self): self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0])) +class ElementwiseProductTests(MLlibTestCase): + def test_model_transform(self): + weight = Vectors.dense([3, 2, 1]) + + densevec = Vectors.dense([4, 5, 6]) + sparsevec = Vectors.sparse(3, [0], [1]) + eprod = ElementwiseProduct(weight) + self.assertEqual(eprod.transform(densevec), DenseVector([12, 10, 6])) + self.assertEqual( + eprod.transform(sparsevec), SparseVector(3, [0], [3])) + + +class StreamingKMeansTest(MLLibStreamingTestCase): + def test_model_params(self): + """Test that the model params are set correctly""" + stkm = StreamingKMeans() + stkm.setK(5).setDecayFactor(0.0) + self.assertEquals(stkm._k, 5) + self.assertEquals(stkm._decayFactor, 0.0) + + # Model not set yet. + self.assertIsNone(stkm.latestModel()) + self.assertRaises(ValueError, stkm.trainOn, [0.0, 1.0]) + + stkm.setInitialCenters( + centers=[[0.0, 0.0], [1.0, 1.0]], weights=[1.0, 1.0]) + self.assertEquals( + stkm.latestModel().centers, [[0.0, 0.0], [1.0, 1.0]]) + self.assertEquals(stkm.latestModel().clusterWeights, [1.0, 1.0]) + + def test_accuracy_for_single_center(self): + """Test that parameters obtained are correct for a single center.""" + centers, batches = self.streamingKMeansDataGenerator( + batches=5, numPoints=5, k=1, d=5, r=0.1, seed=0) + stkm = StreamingKMeans(1) + stkm.setInitialCenters([[0., 0., 0., 0., 0.]], [0.]) + input_stream = self.ssc.queueStream( + [self.sc.parallelize(batch, 1) for batch in batches]) + stkm.trainOn(input_stream) + + t = time() + self.ssc.start() + self._ssc_wait(t, 10.0, 0.01) + self.assertEquals(stkm.latestModel().clusterWeights, [25.0]) + realCenters = array_sum(array(centers), axis=0) + for i in range(5): + modelCenters = stkm.latestModel().centers[0][i] + self.assertAlmostEqual(centers[0][i], modelCenters, 1) + self.assertAlmostEqual(realCenters[i], modelCenters, 1) + + def streamingKMeansDataGenerator(self, batches, numPoints, + k, d, r, seed, centers=None): + rng = random.RandomState(seed) + + # Generate centers. + centers = [rng.randn(d) for i in range(k)] + + return centers, [[Vectors.dense(centers[j % k] + r * rng.randn(d)) + for j in range(numPoints)] + for i in range(batches)] + + def test_trainOn_model(self): + """Test the model on toy data with four clusters.""" + stkm = StreamingKMeans() + initCenters = [[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]] + stkm.setInitialCenters( + centers=initCenters, weights=[1.0, 1.0, 1.0, 1.0]) + + # Create a toy dataset by setting a tiny offest for each point. + offsets = [[0, 0.1], [0, -0.1], [0.1, 0], [-0.1, 0]] + batches = [] + for offset in offsets: + batches.append([[offset[0] + center[0], offset[1] + center[1]] + for center in initCenters]) + + batches = [self.sc.parallelize(batch, 1) for batch in batches] + input_stream = self.ssc.queueStream(batches) + stkm.trainOn(input_stream) + t = time() + self.ssc.start() + + # Give enough time to train the model. + self._ssc_wait(t, 6.0, 0.01) + finalModel = stkm.latestModel() + self.assertTrue(all(finalModel.centers == array(initCenters))) + self.assertEquals(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) + + def test_predictOn_model(self): + """Test that the model predicts correctly on toy data.""" + stkm = StreamingKMeans() + stkm._model = StreamingKMeansModel( + clusterCenters=[[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]], + clusterWeights=[1.0, 1.0, 1.0, 1.0]) + + predict_data = [[[1.5, 1.5]], [[-1.5, 1.5]], [[-1.5, -1.5]], [[1.5, -1.5]]] + predict_data = [sc.parallelize(batch, 1) for batch in predict_data] + predict_stream = self.ssc.queueStream(predict_data) + predict_val = stkm.predictOn(predict_stream) + + result = [] + + def update(rdd): + rdd_collect = rdd.collect() + if rdd_collect: + result.append(rdd_collect) + + predict_val.foreachRDD(update) + t = time() + self.ssc.start() + self._ssc_wait(t, 6.0, 0.01) + self.assertEquals(result, [[0], [1], [2], [3]]) + + def test_trainOn_predictOn(self): + """Test that prediction happens on the updated model.""" + stkm = StreamingKMeans(decayFactor=0.0, k=2) + stkm.setInitialCenters([[0.0], [1.0]], [1.0, 1.0]) + + # Since decay factor is set to zero, once the first batch + # is passed the clusterCenters are updated to [-0.5, 0.7] + # which causes 0.2 & 0.3 to be classified as 1, even though the + # classification based in the initial model would have been 0 + # proving that the model is updated. + batches = [[[-0.5], [0.6], [0.8]], [[0.2], [-0.1], [0.3]]] + batches = [sc.parallelize(batch) for batch in batches] + input_stream = self.ssc.queueStream(batches) + predict_results = [] + + def collect(rdd): + rdd_collect = rdd.collect() + if rdd_collect: + predict_results.append(rdd_collect) + + stkm.trainOn(input_stream) + predict_stream = stkm.predictOn(input_stream) + predict_stream.foreachRDD(collect) + + t = time() + self.ssc.start() + self._ssc_wait(t, 6.0, 0.01) + self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]]) + + +class LinearDataGeneratorTests(MLlibTestCase): + def test_dim(self): + linear_data = LinearDataGenerator.generateLinearInput( + intercept=0.0, weights=[0.0, 0.0, 0.0], + xMean=[0.0, 0.0, 0.0], xVariance=[0.33, 0.33, 0.33], + nPoints=4, seed=0, eps=0.1) + self.assertEqual(len(linear_data), 4) + for point in linear_data: + self.assertEqual(len(point.features), 3) + + linear_data = LinearDataGenerator.generateLinearRDD( + sc=sc, nexamples=6, nfeatures=2, eps=0.1, + nParts=2, intercept=0.0).collect() + self.assertEqual(len(linear_data), 6) + for point in linear_data: + self.assertEqual(len(point.features), 2) + + +class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase): + + @staticmethod + def generateLogisticInput(offset, scale, nPoints, seed): + """ + Generate 1 / (1 + exp(-x * scale + offset)) + + where, + x is randomnly distributed and the threshold + and labels for each sample in x is obtained from a random uniform + distribution. + """ + rng = random.RandomState(seed) + x = rng.randn(nPoints) + sigmoid = 1. / (1 + exp(-(dot(x, scale) + offset))) + y_p = rng.rand(nPoints) + cut_off = y_p <= sigmoid + y_p[cut_off] = 1.0 + y_p[~cut_off] = 0.0 + return [ + LabeledPoint(y_p[i], Vectors.dense([x[i]])) + for i in range(nPoints)] + + def test_parameter_accuracy(self): + """ + Test that the final value of weights is close to the desired value. + """ + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + input_stream = self.ssc.queueStream(input_batches) + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + slr.trainOn(input_stream) + + t = time() + self.ssc.start() + self._ssc_wait(t, 20.0, 0.01) + rel = (1.5 - slr.latestModel().weights.array[0]) / 1.5 + self.assertAlmostEqual(rel, 0.1, 1) + + def test_convergence(self): + """ + Test that weights converge to the required value on toy data. + """ + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + input_stream = self.ssc.queueStream(input_batches) + models = [] + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + slr.trainOn(input_stream) + input_stream.foreachRDD( + lambda x: models.append(slr.latestModel().weights[0])) + + t = time() + self.ssc.start() + self._ssc_wait(t, 15.0, 0.01) + t_models = array(models) + diff = t_models[1:] - t_models[:-1] + + # Test that weights improve with a small tolerance, + self.assertTrue(all(diff >= -0.1)) + self.assertTrue(array_sum(diff > 0) > 1) + + @staticmethod + def calculate_accuracy_error(true, predicted): + return sum(abs(array(true) - array(predicted))) / len(true) + + def test_predictions(self): + """Test predicted values on a toy model.""" + input_batches = [] + for i in range(20): + batch = self.sc.parallelize( + self.generateLogisticInput(0, 1.5, 100, 42 + i)) + input_batches.append(batch.map(lambda x: (x.label, x.features))) + input_stream = self.ssc.queueStream(input_batches) + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([1.5]) + predict_stream = slr.predictOnValues(input_stream) + true_predicted = [] + predict_stream.foreachRDD(lambda x: true_predicted.append(x.collect())) + t = time() + self.ssc.start() + self._ssc_wait(t, 5.0, 0.01) + + # Test that the accuracy error is no more than 0.4 on each batch. + for batch in true_predicted: + true, predicted = zip(*batch) + self.assertTrue( + self.calculate_accuracy_error(true, predicted) < 0.4) + + def test_training_and_prediction(self): + """Test that the model improves on toy data with no. of batches""" + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + predict_batches = [ + b.map(lambda lp: (lp.label, lp.features)) for b in input_batches] + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.01, numIterations=25) + slr.setInitialWeights([-0.1]) + errors = [] + + def collect_errors(rdd): + true, predicted = zip(*rdd.collect()) + errors.append(self.calculate_accuracy_error(true, predicted)) + + true_predicted = [] + input_stream = self.ssc.queueStream(input_batches) + predict_stream = self.ssc.queueStream(predict_batches) + slr.trainOn(input_stream) + ps = slr.predictOnValues(predict_stream) + ps.foreachRDD(lambda x: collect_errors(x)) + + t = time() + self.ssc.start() + self._ssc_wait(t, 20.0, 0.01) + + # Test that the improvement in error is atleast 0.3 + self.assertTrue(errors[1] - errors[-1] > 0.3) + + if __name__ == "__main__": if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index cfcbea573fd22..372b86a7c95d9 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -299,7 +299,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, 1 internal node + 2 leaf nodes. (default: 4) :param maxBins: maximum number of bins used for splitting features - (default: 100) + (default: 32) :param seed: Random seed for bootstrapping and choosing feature subsets. :return: RandomForestModel that can be used for prediction @@ -377,7 +377,7 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (default: 4) :param maxBins: maximum number of bins used for splitting - features (default: 100) + features (default: 32) :param seed: Random seed for bootstrapping and choosing feature subsets. :return: RandomForestModel that can be used for prediction @@ -435,16 +435,17 @@ class GradientBoostedTrees(object): @classmethod def _train(cls, data, algo, categoricalFeaturesInfo, - loss, numIterations, learningRate, maxDepth): + loss, numIterations, learningRate, maxDepth, maxBins): first = data.first() assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint" model = callMLlibFunc("trainGradientBoostedTreesModel", data, algo, categoricalFeaturesInfo, - loss, numIterations, learningRate, maxDepth) + loss, numIterations, learningRate, maxDepth, maxBins) return GradientBoostedTreesModel(model) @classmethod def trainClassifier(cls, data, categoricalFeaturesInfo, - loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3): + loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3, + maxBins=32): """ Method to train a gradient-boosted trees model for classification. @@ -467,6 +468,8 @@ def trainClassifier(cls, data, categoricalFeaturesInfo, :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (default: 3) + :param maxBins: maximum number of bins used for splitting + features (default: 32) DecisionTree requires maxBins >= max categories :return: GradientBoostedTreesModel that can be used for prediction @@ -499,11 +502,12 @@ def trainClassifier(cls, data, categoricalFeaturesInfo, [1.0, 0.0] """ return cls._train(data, "classification", categoricalFeaturesInfo, - loss, numIterations, learningRate, maxDepth) + loss, numIterations, learningRate, maxDepth, maxBins) @classmethod def trainRegressor(cls, data, categoricalFeaturesInfo, - loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3): + loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3, + maxBins=32): """ Method to train a gradient-boosted trees model for regression. @@ -522,6 +526,8 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, contribution of each estimator. The learning rate should be between in the interval (0, 1]. (default: 0.1) + :param maxBins: maximum number of bins used for splitting + features (default: 32) DecisionTree requires maxBins >= max categories :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (default: 3) @@ -556,7 +562,7 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, [1.0, 0.0] """ return cls._train(data, "regression", categoricalFeaturesInfo, - loss, numIterations, learningRate, maxDepth) + loss, numIterations, learningRate, maxDepth, maxBins) def _test(): diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 16a90db146ef0..348238319e407 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -257,6 +257,41 @@ def load(cls, sc, path): return cls(java_model) +class LinearDataGenerator(object): + """Utils for generating linear data""" + + @staticmethod + def generateLinearInput(intercept, weights, xMean, xVariance, + nPoints, seed, eps): + """ + :param: intercept bias factor, the term c in X'w + c + :param: weights feature vector, the term w in X'w + c + :param: xMean Point around which the data X is centered. + :param: xVariance Variance of the given data + :param: nPoints Number of points to be generated + :param: seed Random Seed + :param: eps Used to scale the noise. If eps is set high, + the amount of gaussian noise added is more. + Returns a list of LabeledPoints of length nPoints + """ + weights = [float(weight) for weight in weights] + xMean = [float(mean) for mean in xMean] + xVariance = [float(var) for var in xVariance] + return list(callMLlibFunc( + "generateLinearInputWrapper", float(intercept), weights, xMean, + xVariance, int(nPoints), int(seed), float(eps))) + + @staticmethod + def generateLinearRDD(sc, nexamples, nfeatures, eps, + nParts=2, intercept=0.0): + """ + Generate a RDD of LabeledPoints. + """ + return callMLlibFunc( + "generateLinearRDDWrapper", sc, int(nexamples), int(nfeatures), + float(eps), int(nParts), float(intercept)) + + def _test(): import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py index d18daaabfcb3c..44d17bd629473 100644 --- a/python/pyspark/profiler.py +++ b/python/pyspark/profiler.py @@ -90,9 +90,11 @@ class Profiler(object): >>> sc = SparkContext('local', 'test', conf=conf, profiler_cls=MyCustomProfiler) >>> sc.parallelize(range(1000)).map(lambda x: 2 * x).take(10) [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] + >>> sc.parallelize(range(1000)).count() + 1000 >>> sc.show_profiles() My custom profiles for RDD:1 - My custom profiles for RDD:2 + My custom profiles for RDD:3 >>> sc.stop() """ @@ -169,4 +171,6 @@ def stats(self): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 98a8ff8606366..cb20bc8b54027 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -121,10 +121,22 @@ def _parse_memory(s): def _load_from_socket(port, serializer): - sock = socket.socket() - sock.settimeout(3) + sock = None + # Support for both IPv4 and IPv6. + # On most of IPv6-ready systems, IPv6 will take precedence. + for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + try: + sock = socket.socket(af, socktype, proto) + sock.settimeout(3) + sock.connect(sa) + except socket.error: + sock = None + continue + break + if not sock: + raise Exception("could not open socket") try: - sock.connect(("localhost", port)) rf = sock.makefile("rb", 65536) for item in serializer.load_stream(rf): yield item @@ -960,7 +972,7 @@ def sum(self): >>> sc.parallelize([1.0, 2.0, 3.0]).sum() 6.0 """ - return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add) + return self.mapPartitions(lambda x: [sum(x)]).fold(0, operator.add) def count(self): """ @@ -2198,7 +2210,7 @@ def sumApprox(self, timeout, confidence=0.95): >>> rdd = sc.parallelize(range(1000), 10) >>> r = sum(range(1000)) - >>> (rdd.sumApprox(1000) - r) / r < 0.05 + >>> abs(rdd.sumApprox(1000) - r) / r < 0.05 True """ jrdd = self.mapPartitions(lambda it: [float(sum(it))])._to_java_object_rdd() @@ -2215,7 +2227,7 @@ def meanApprox(self, timeout, confidence=0.95): >>> rdd = sc.parallelize(range(1000), 10) >>> r = sum(range(1000)) / 1000.0 - >>> (rdd.meanApprox(1000) - r) / r < 0.05 + >>> abs(rdd.meanApprox(1000) - r) / r < 0.05 True """ jrdd = self.map(float)._to_java_object_rdd() diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index d8cdcda3a3783..411b4dbf481f1 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -44,8 +44,8 @@ >>> rdd.glom().collect() [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] ->>> rdd._jrdd.count() -8L +>>> int(rdd._jrdd.count()) +8 >>> sc.stop() """ @@ -272,7 +272,7 @@ def dump_stream(self, iterator, stream): if size < best: batch *= 2 elif size > best * 10 and batch > 1: - batch /= 2 + batch //= 2 def __repr__(self): return "AutoBatchedSerializer(%s)" % self.serializer @@ -556,4 +556,6 @@ def write_with_length(obj, stream): if __name__ == '__main__': import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 81c420ce16541..8fb71bac64a5e 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -486,7 +486,7 @@ def sorted(self, iterator, key=None, reverse=False): goes above the limit. """ global MemoryBytesSpilled, DiskBytesSpilled - batch, limit = 100, self.memory_limit + batch, limit = 100, self._next_limit() chunks, current_chunk = [], [] iterator = iter(iterator) while True: @@ -512,9 +512,6 @@ def load(f): f.close() chunks.append(load(open(path, 'rb'))) current_chunk = [] - gc.collect() - batch //= 2 - limit = self._next_limit() MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 DiskBytesSpilled += os.path.getsize(path) os.unlink(path) # data will be deleted after close @@ -841,4 +838,6 @@ def load_partition(j): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 1ecec5b126505..0a85da7443d3d 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -396,6 +396,11 @@ def over(self, window): jc = self._jc.over(window._jspec) return Column(jc) + def __nonzero__(self): + raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', " + "'~' for 'not' when building DataFrame boolean expressions.") + __bool__ = __nonzero__ + def __repr__(self): return 'Column<%s>' % self._jc.toString().encode('utf8') diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 599c9ac5794a2..4dda3b430cfbf 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -86,7 +86,8 @@ def __init__(self, sparkContext, sqlContext=None): >>> df.registerTempTable("allTypes") >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 'from allTypes where b and i > 0').collect() - [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)] + [Row(_c0=2, _c1=2.0, _c2=False, _c3=2, _c4=0, \ + time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)] >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect() [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] """ @@ -176,17 +177,17 @@ def registerFunction(self, name, f, returnType=StringType()): >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) >>> sqlContext.sql("SELECT stringLengthString('test')").collect() - [Row(c0=u'4')] + [Row(_c0=u'4')] >>> from pyspark.sql.types import IntegerType >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(c0=4)] + [Row(_c0=4)] >>> from pyspark.sql.types import IntegerType >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(c0=4)] + [Row(_c0=4)] """ func = lambda _, it: map(lambda x: f(*x), it) ser = AutoBatchedSerializer(PickleSerializer()) @@ -202,7 +203,37 @@ def registerFunction(self, name, f, returnType=StringType()): self._sc._javaAccumulator, returnType.json()) + def _inferSchemaFromList(self, data): + """ + Infer schema from list of Row or tuple. + + :param data: list of Row or tuple + :return: StructType + """ + if not data: + raise ValueError("can not infer schema from empty dataset") + first = data[0] + if type(first) is dict: + warnings.warn("inferring schema from dict is deprecated," + "please use pyspark.sql.Row instead") + schema = _infer_schema(first) + if _has_nulltype(schema): + for r in data: + schema = _merge_type(schema, _infer_schema(r)) + if not _has_nulltype(schema): + break + else: + raise ValueError("Some of types cannot be determined after inferring") + return schema + def _inferSchema(self, rdd, samplingRatio=None): + """ + Infer schema from an RDD of Row or tuple. + + :param rdd: an RDD of Row or tuple + :param samplingRatio: sampling ratio, or no sampling (default) + :return: StructType + """ first = rdd.first() if not first: raise ValueError("The first row in RDD is empty, " @@ -321,6 +352,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): data = [r.tolist() for r in data.to_records(index=False)] if not isinstance(data, RDD): + if not isinstance(data, list): + data = list(data) try: # data could be list, tuple, generator ... rdd = self._sc.parallelize(data) @@ -329,28 +362,26 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): else: rdd = data - if schema is None: - schema = self._inferSchema(rdd, samplingRatio) + if schema is None or isinstance(schema, (list, tuple)): + if isinstance(data, RDD): + struct = self._inferSchema(rdd, samplingRatio) + else: + struct = self._inferSchemaFromList(data) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + schema = struct converter = _create_converter(schema) rdd = rdd.map(converter) - if isinstance(schema, (list, tuple)): - first = rdd.first() - if not isinstance(first, (list, tuple)): - raise TypeError("each row in `rdd` should be list or tuple, " - "but got %r" % type(first)) - row_cls = Row(*schema) - schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio) - - # take the first few rows to verify schema - rows = rdd.take(10) - # Row() cannot been deserialized by Pyrolite - if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': - rdd = rdd.map(tuple) + elif isinstance(schema, StructType): + # take the first few rows to verify schema rows = rdd.take(10) + for row in rows: + _verify_type(row, schema) - for row in rows: - _verify_type(row, schema) + else: + raise TypeError("schema should be StructType or list or None") # convert python objects to sql data converter = _python_to_sql_converter(schema) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 2d8c59518b35a..4b9efa0a210fb 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -22,6 +22,7 @@ if sys.version >= '3': basestring = unicode = str long = int + from functools import reduce else: from itertools import imap as map @@ -193,7 +194,11 @@ def schema(self): StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true))) """ if self._schema is None: - self._schema = _parse_datatype_json_string(self._jdf.schema().json()) + try: + self._schema = _parse_datatype_json_string(self._jdf.schema().json()) + except AttributeError as e: + raise Exception( + "Unable to parse datatype from schema. %s" % e) return self._schema @since(1.3) @@ -242,9 +247,12 @@ def isLocal(self): return self._jdf.isLocal() @since(1.3) - def show(self, n=20): + def show(self, n=20, truncate=True): """Prints the first ``n`` rows to the console. + :param n: Number of rows to show. + :param truncate: Whether truncate long strings and align cells right. + >>> df DataFrame[age: int, name: string] >>> df.show() @@ -255,7 +263,7 @@ def show(self, n=20): | 5| Bob| +---+-----+ """ - print(self._jdf.showString(n)) + print(self._jdf.showString(n, truncate)) def __repr__(self): return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) @@ -503,36 +511,52 @@ def alias(self, alias): @ignore_unicode_prefix @since(1.3) - def join(self, other, joinExprs=None, joinType=None): + def join(self, other, on=None, how=None): """Joins with another :class:`DataFrame`, using the given join expression. The following performs a full outer join between ``df1`` and ``df2``. :param other: Right side of the join - :param joinExprs: a string for join column name, or a join expression (Column). - If joinExprs is a string indicating the name of the join column, - the column must exist on both sides, and this performs an inner equi-join. - :param joinType: str, default 'inner'. + :param on: a string for join column name, a list of column names, + , a join expression (Column) or a list of Columns. + If `on` is a string or a list of string indicating the name of the join column(s), + the column(s) must exist on both sides, and this performs an inner equi-join. + :param how: str, default 'inner'. One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] + >>> cond = [df.name == df3.name, df.age == df3.age] + >>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect() + [Row(name=u'Bob', age=5), Row(name=u'Alice', age=2)] + >>> df.join(df2, 'name').select(df.name, df2.height).collect() [Row(name=u'Bob', height=85)] + + >>> df.join(df4, ['name', 'age']).select(df.name, df.age).collect() + [Row(name=u'Bob', age=5)] """ - if joinExprs is None: + if on is not None and not isinstance(on, list): + on = [on] + + if on is None or len(on) == 0: jdf = self._jdf.join(other._jdf) - elif isinstance(joinExprs, basestring): - jdf = self._jdf.join(other._jdf, joinExprs) + + if isinstance(on[0], basestring): + jdf = self._jdf.join(other._jdf, self._jseq(on)) else: - assert isinstance(joinExprs, Column), "joinExprs should be Column" - if joinType is None: - jdf = self._jdf.join(other._jdf, joinExprs._jc) + assert isinstance(on[0], Column), "on should be Column or list of Column" + if len(on) > 1: + on = reduce(lambda x, y: x.__and__(y), on) + else: + on = on[0] + if how is None: + jdf = self._jdf.join(other._jdf, on._jc, "inner") else: - assert isinstance(joinType, basestring), "joinType should be basestring" - jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType) + assert isinstance(how, basestring), "how should be basestring" + jdf = self._jdf.join(other._jdf, on._jc, how) return DataFrame(jdf, self.sql_ctx) @ignore_unicode_prefix @@ -729,7 +753,7 @@ def selectExpr(self, *expr): This is a variant of :func:`select` that accepts SQL expressions. >>> df.selectExpr("age * 2", "abs(age)").collect() - [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)] + [Row((age * 2)=4, 'abs(age)=2), Row((age * 2)=10, 'abs(age)=5)] """ if len(expr) == 1 and isinstance(expr[0], list): expr = expr[0] @@ -1315,6 +1339,8 @@ def _test(): .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF() + globs['df3'] = sc.parallelize([Row(name='Alice', age=2), + Row(name='Bob', age=5)]).toDF() globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80), Row(name='Bob', age=5, height=None), Row(name='Tom', age=None, height=None), diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index bbf465aca8d4d..45ecd826bd3bd 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -18,6 +18,7 @@ """ A collections of builtin functions """ +import math import sys if sys.version < "3": @@ -34,12 +35,15 @@ __all__ = [ 'array', 'approxCountDistinct', + 'bin', 'coalesce', 'countDistinct', 'explode', 'monotonicallyIncreasingId', 'rand', 'randn', + 'sha1', + 'sha2', 'sparkPartitionId', 'struct', 'udf', @@ -143,7 +147,7 @@ def _(): 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' + 'polar coordinates (r, theta).', 'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.', - 'pow': 'Returns the value of the first argument raised to the power of the second argument.' + 'pow': 'Returns the value of the first argument raised to the power of the second argument.', } _window_functions = { @@ -230,6 +234,19 @@ def approxCountDistinct(col, rsd=None): return Column(jc) +@ignore_unicode_prefix +@since(1.5) +def bin(col): + """Returns the string representation of the binary value of the given column. + + >>> df.select(bin(df.age).alias('c')).collect() + [Row(c=u'10'), Row(c=u'101')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.bin(_to_java_column(col)) + return Column(jc) + + @since(1.4) def coalesce(*cols): """Returns the first column that is not null. @@ -348,6 +365,37 @@ def randn(seed=None): return Column(jc) +@ignore_unicode_prefix +@since(1.5) +def sha2(col, numBits): + """Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384, + and SHA-512). The numBits indicates the desired bit length of the result, which must have a + value of 224, 256, 384, 512, or 0 (which is equivalent to 256). + + >>> digests = df.select(sha2(df.name, 256).alias('s')).collect() + >>> digests[0] + Row(s=u'3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043') + >>> digests[1] + Row(s=u'cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961') + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.sha2(_to_java_column(col), numBits) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def sha1(col): + """Returns the hex string result of SHA-1. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect() + [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.sha1(_to_java_column(col)) + return Column(jc) + + @since(1.4) def sparkPartitionId(): """A column for partition ID of the Spark task. @@ -403,6 +451,26 @@ def when(condition, value): return Column(jc) +@since(1.5) +def log(arg1, arg2=None): + """Returns the first argument-based logarithm of the second argument. + + If there is only one argument, then this takes the natural logarithm of the argument. + + >>> df.select(log(10.0, df.age).alias('ten')).map(lambda l: str(l.ten)[:7]).collect() + ['0.30102', '0.69897'] + + >>> df.select(log(df.age).alias('e')).map(lambda l: str(l.e)[:7]).collect() + ['0.69314', '1.60943'] + """ + sc = SparkContext._active_spark_context + if arg2 is None: + jc = sc._jvm.functions.log(_to_java_column(arg1)) + else: + jc = sc._jvm.functions.log(arg1, _to_java_column(arg2)) + return Column(jc) + + @since(1.4) def lag(col, count=1, default=None): """ diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index f036644acc961..882a03090ec13 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -73,6 +73,13 @@ def schema(self, schema): self._jreader = self._jreader.schema(jschema) return self + @since(1.5) + def option(self, key, value): + """Adds an input option for the underlying data source. + """ + self._jreader = self._jreader.option(key, value) + return self + @since(1.4) def options(self, **options): """Adds input options for the underlying data source. @@ -218,7 +225,10 @@ def mode(self, saveMode): >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ - self._jwrite = self._jwrite.mode(saveMode) + # At the JVM side, the default value of mode is already set to "error". + # So, if the given saveMode is None, we will not call JVM-side's mode method. + if saveMode is not None: + self._jwrite = self._jwrite.mode(saveMode) return self @since(1.4) @@ -232,6 +242,13 @@ def format(self, source): self._jwrite = self._jwrite.format(source) return self + @since(1.5) + def option(self, key, value): + """Adds an output option for the underlying data source. + """ + self._jwrite = self._jwrite.option(key, value) + return self + @since(1.4) def options(self, **options): """Adds output options for the underlying data source. @@ -257,7 +274,7 @@ def partitionBy(self, *cols): return self @since(1.4) - def save(self, path=None, format=None, mode="error", **options): + def save(self, path=None, format=None, mode=None, partitionBy=None, **options): """Saves the contents of the :class:`DataFrame` to a data source. The data source is specified by the ``format`` and a set of ``options``. @@ -272,11 +289,14 @@ def save(self, path=None, format=None, mode="error", **options): * ``overwrite``: Overwrite existing data. * ``ignore``: Silently ignore this operation if data already exists. * ``error`` (default case): Throw an exception if data already exists. + :param partitionBy: names of partitioning columns :param options: all other string options >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode).options(**options) + if partitionBy is not None: + self.partitionBy(partitionBy) if format is not None: self.format(format) if path is None: @@ -296,7 +316,7 @@ def insertInto(self, tableName, overwrite=False): self._jwrite.mode("overwrite" if overwrite else "append").insertInto(tableName) @since(1.4) - def saveAsTable(self, name, format=None, mode="error", **options): + def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options): """Saves the content of the :class:`DataFrame` as the specified table. In the case the table already exists, behavior of this function depends on the @@ -312,15 +332,18 @@ def saveAsTable(self, name, format=None, mode="error", **options): :param name: the table name :param format: the format used to save :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) + :param partitionBy: names of partitioning columns :param options: all other string options """ self.mode(mode).options(**options) + if partitionBy is not None: + self.partitionBy(partitionBy) if format is not None: self.format(format) self._jwrite.saveAsTable(name) @since(1.4) - def json(self, path, mode="error"): + def json(self, path, mode=None): """Saves the content of the :class:`DataFrame` in JSON format at the specified path. :param path: the path in any Hadoop supported file system @@ -333,10 +356,10 @@ def json(self, path, mode="error"): >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ - self._jwrite.mode(mode).json(path) + self.mode(mode)._jwrite.json(path) @since(1.4) - def parquet(self, path, mode="error"): + def parquet(self, path, mode=None, partitionBy=None): """Saves the content of the :class:`DataFrame` in Parquet format at the specified path. :param path: the path in any Hadoop supported file system @@ -346,13 +369,17 @@ def parquet(self, path, mode="error"): * ``overwrite``: Overwrite existing data. * ``ignore``: Silently ignore this operation if data already exists. * ``error`` (default case): Throw an exception if data already exists. + :param partitionBy: names of partitioning columns >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ - self._jwrite.mode(mode).parquet(path) + self.mode(mode) + if partitionBy is not None: + self.partitionBy(partitionBy) + self._jwrite.parquet(path) @since(1.4) - def jdbc(self, url, table, mode="error", properties={}): + def jdbc(self, url, table, mode=None, properties={}): """Saves the content of the :class:`DataFrame` to a external database table via JDBC. .. note:: Don't create too many partitions in parallel on a large cluster;\ diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a6fce50c76c2b..34f397d0ffef0 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -26,6 +26,7 @@ import tempfile import pickle import functools +import time import datetime import py4j @@ -47,6 +48,20 @@ from pyspark.sql.window import Window +class UTC(datetime.tzinfo): + """UTC""" + ZERO = datetime.timedelta(0) + + def utcoffset(self, dt): + return self.ZERO + + def tzname(self, dt): + return "UTC" + + def dst(self, dt): + return self.ZERO + + class ExamplePointUDT(UserDefinedType): """ User-defined type (UDT) for ExamplePoint. @@ -149,6 +164,14 @@ def test_explode(self): self.assertEqual(result[0][0], "a") self.assertEqual(result[0][1], "b") + def test_and_in_expression(self): + self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count()) + self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2")) + self.assertEqual(14, self.df.filter((self.df.key <= 3) | (self.df.value < "2")).count()) + self.assertRaises(ValueError, lambda: self.df.key <= 3 or self.df.value < "2") + self.assertEqual(99, self.df.filter(~(self.df.key == 1)).count()) + self.assertRaises(ValueError, lambda: not self.df.key == 1) + def test_udf_with_callable(self): d = [Row(number=i, squared=i**2) for i in range(10)] rdd = self.sc.parallelize(d) @@ -393,7 +416,7 @@ def test_column_operators(self): self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) self.assertTrue(all(isinstance(c, Column) for c in rcc)) - cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs] + cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7] self.assertTrue(all(isinstance(c, Column) for c in cb)) cbool = (ci & ci), (ci | ci), (~ci) self.assertTrue(all(isinstance(c, Column) for c in cbool)) @@ -493,6 +516,35 @@ def test_between_function(self): self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)], df.filter(df.a.between(df.b, df.c)).collect()) + def test_struct_type(self): + from pyspark.sql.types import StructType, StringType, StructField + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + struct2 = StructType([StructField("f1", StringType(), True), + StructField("f2", StringType(), True, None)]) + self.assertEqual(struct1, struct2) + + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + struct2 = StructType([StructField("f1", StringType(), True)]) + self.assertNotEqual(struct1, struct2) + + struct1 = (StructType().add(StructField("f1", StringType(), True)) + .add(StructField("f2", StringType(), True, None))) + struct2 = StructType([StructField("f1", StringType(), True), + StructField("f2", StringType(), True, None)]) + self.assertEqual(struct1, struct2) + + struct1 = (StructType().add(StructField("f1", StringType(), True)) + .add(StructField("f2", StringType(), True, None))) + struct2 = StructType([StructField("f1", StringType(), True)]) + self.assertNotEqual(struct1, struct2) + + # Catch exception raised during improper construction + try: + struct1 = StructType().add("name") + self.assertEqual(1, 0) + except ValueError: + self.assertEqual(1, 1) + def test_save_and_load(self): df = self.df tmpPath = tempfile.mkdtemp() @@ -524,6 +576,39 @@ def test_save_and_load(self): shutil.rmtree(tmpPath) + def test_save_and_load_builder(self): + df = self.df + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.write.json(tmpPath) + actual = self.sqlCtx.read.json(tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + schema = StructType([StructField("value", StringType(), True)]) + actual = self.sqlCtx.read.json(tmpPath, schema) + self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) + + df.write.mode("overwrite").json(tmpPath) + actual = self.sqlCtx.read.json(tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + df.write.mode("overwrite").options(noUse="this options will not be used in save.")\ + .option("noUse", "this option will not be used in save.")\ + .format("json").save(path=tmpPath) + actual =\ + self.sqlCtx.read.format("json")\ + .load(path=tmpPath, noUse="this options will not be used in load.") + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + actual = self.sqlCtx.load(path=tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + + shutil.rmtree(tmpPath) + def test_help_command(self): # Regression test for SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) @@ -588,6 +673,23 @@ def test_filter_with_datetime(self): self.assertEqual(0, df.filter(df.date > date).count()) self.assertEqual(0, df.filter(df.time > time).count()) + def test_time_with_timezone(self): + day = datetime.date.today() + now = datetime.datetime.now() + ts = time.mktime(now.timetuple()) + now.microsecond / 1e6 + # class in __main__ is not serializable + from pyspark.sql.tests import UTC + utc = UTC() + utcnow = datetime.datetime.fromtimestamp(ts, utc) + df = self.sqlCtx.createDataFrame([(day, now, utcnow)]) + day1, now1, utcnow1 = df.first() + # Pyrolite serialize java.sql.Date as datetime, will be fixed in new version + self.assertEqual(day1.date(), day) + # Pyrolite does not support microsecond, the error should be + # less than 1 millisecond + self.assertTrue(now - now1 < datetime.timedelta(0.001)) + self.assertTrue(now - utcnow1 < datetime.timedelta(0.001)) + def test_dropna(self): schema = StructType([ StructField("name", StringType(), True), diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index b6ec6137c9180..ae9344e6106a4 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -19,6 +19,7 @@ import decimal import time import datetime +import calendar import keyword import warnings import json @@ -354,8 +355,7 @@ class StructType(DataType): This is the data type representing a :class:`Row`. """ - - def __init__(self, fields): + def __init__(self, fields=None): """ >>> struct1 = StructType([StructField("f1", StringType(), True)]) >>> struct2 = StructType([StructField("f1", StringType(), True)]) @@ -367,8 +367,53 @@ def __init__(self, fields): >>> struct1 == struct2 False """ - assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType" - self.fields = fields + if not fields: + self.fields = [] + else: + self.fields = fields + assert all(isinstance(f, StructField) for f in fields),\ + "fields should be a list of StructField" + + def add(self, field, data_type=None, nullable=True, metadata=None): + """ + Construct a StructType by adding new elements to it to define the schema. The method accepts + either: + a) A single parameter which is a StructField object. + b) Between 2 and 4 parameters as (name, data_type, nullable (optional), + metadata(optional). The data_type parameter may be either a String or a DataType object + + >>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + >>> struct2 = StructType([StructField("f1", StringType(), True),\ + StructField("f2", StringType(), True, None)]) + >>> struct1 == struct2 + True + >>> struct1 = StructType().add(StructField("f1", StringType(), True)) + >>> struct2 = StructType([StructField("f1", StringType(), True)]) + >>> struct1 == struct2 + True + >>> struct1 = StructType().add("f1", "string", True) + >>> struct2 = StructType([StructField("f1", StringType(), True)]) + >>> struct1 == struct2 + True + + :param field: Either the name of the field or a StructField object + :param data_type: If present, the DataType of the StructField to create + :param nullable: Whether the field to add should be nullable (default True) + :param metadata: Any additional metadata (default None) + :return: a new updated StructType + """ + if isinstance(field, StructField): + self.fields.append(field) + else: + if isinstance(field, str) and data_type is None: + raise ValueError("Must specify DataType if passing name of struct_field to create.") + + if isinstance(data_type, str): + data_type_f = _parse_datatype_json_value(data_type) + else: + data_type_f = data_type + self.fields.append(StructField(field, data_type_f, nullable, metadata)) + return self def simpleString(self): return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields)) @@ -634,7 +679,7 @@ def _need_python_to_sql_conversion(dataType): >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), ... StructField("values", ArrayType(DoubleType(), False), False)]) >>> _need_python_to_sql_conversion(schema0) - False + True >>> _need_python_to_sql_conversion(ExamplePointUDT()) True >>> schema1 = ArrayType(ExamplePointUDT(), False) @@ -646,7 +691,8 @@ def _need_python_to_sql_conversion(dataType): True """ if isinstance(dataType, StructType): - return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) + # convert namedtuple or Row into tuple + return True elif isinstance(dataType, ArrayType): return _need_python_to_sql_conversion(dataType.elementType) elif isinstance(dataType, MapType): @@ -654,10 +700,15 @@ def _need_python_to_sql_conversion(dataType): _need_python_to_sql_conversion(dataType.valueType) elif isinstance(dataType, UserDefinedType): return True + elif isinstance(dataType, (DateType, TimestampType)): + return True else: return False +EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal() + + def _python_to_sql_converter(dataType): """ Returns a converter that converts a Python object into a SQL datum for the given type. @@ -682,31 +733,49 @@ def _python_to_sql_converter(dataType): if isinstance(dataType, StructType): names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) - converters = [_python_to_sql_converter(t) for t in types] - - def converter(obj): - if isinstance(obj, dict): - return tuple(c(obj.get(n)) for n, c in zip(names, converters)) - elif isinstance(obj, tuple): - if hasattr(obj, "__fields__") or hasattr(obj, "_fields"): - return tuple(c(v) for c, v in zip(converters, obj)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs - d = dict(obj) - return tuple(c(d.get(n)) for n, c in zip(names, converters)) + if any(_need_python_to_sql_conversion(t) for t in types): + converters = [_python_to_sql_converter(t) for t in types] + + def converter(obj): + if isinstance(obj, dict): + return tuple(c(obj.get(n)) for n, c in zip(names, converters)) + elif isinstance(obj, tuple): + if hasattr(obj, "__fields__") or hasattr(obj, "_fields"): + return tuple(c(v) for c, v in zip(converters, obj)) + else: + return tuple(c(v) for c, v in zip(converters, obj)) + elif obj is not None: + raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) + else: + def converter(obj): + if isinstance(obj, dict): + return tuple(obj.get(n) for n in names) else: - return tuple(c(v) for c, v in zip(converters, obj)) - else: - raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) + return tuple(obj) return converter elif isinstance(dataType, ArrayType): element_converter = _python_to_sql_converter(dataType.elementType) - return lambda a: [element_converter(v) for v in a] + return lambda a: a and [element_converter(v) for v in a] elif isinstance(dataType, MapType): key_converter = _python_to_sql_converter(dataType.keyType) value_converter = _python_to_sql_converter(dataType.valueType) - return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) + return lambda m: m and dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) + elif isinstance(dataType, UserDefinedType): - return lambda obj: dataType.serialize(obj) + return lambda obj: obj and dataType.serialize(obj) + + elif isinstance(dataType, DateType): + return lambda d: d and d.toordinal() - EPOCH_ORDINAL + + elif isinstance(dataType, TimestampType): + + def to_posix_timstamp(dt): + if dt: + seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo + else time.mktime(dt.timetuple())) + return int(seconds * 1e7 + dt.microsecond * 10) + return to_posix_timstamp + else: raise ValueError("Unexpected type %r" % dataType) @@ -1007,10 +1076,13 @@ def _verify_type(obj, dataType): _type = type(dataType) assert _type in _acceptable_types, "unknown datatype: %s" % dataType - # subclass of them can not be deserialized in JVM - if type(obj) not in _acceptable_types[_type]: - raise TypeError("%s can not accept object in type %s" - % (dataType, type(obj))) + if _type is StructType: + if not isinstance(obj, (tuple, list)): + raise TypeError("StructType can not accept object in type %s" % type(obj)) + else: + # subclass of them can not be deserialized in JVM + if type(obj) not in _acceptable_types[_type]: + raise TypeError("%s can not accept object in type %s" % (dataType, type(obj))) if isinstance(dataType, ArrayType): for i in obj: diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index ff097985fae3e..8dcb9645cdc6b 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -176,7 +176,7 @@ def takeAndPrint(time, rdd): print(record) if len(taken) > num: print("...") - print() + print("") self.foreachRDD(takeAndPrint) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 57049beea4dba..91ce681fbe169 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -15,6 +15,7 @@ # limitations under the License. # +import glob import os import sys from itertools import chain @@ -677,4 +678,19 @@ def test_kafka_rdd_with_leaders(self): self._validateRddResult(sendData, rdd) if __name__ == "__main__": + SPARK_HOME = os.environ["SPARK_HOME"] + kafka_assembly_dir = os.path.join(SPARK_HOME, "external/kafka-assembly") + jars = glob.glob( + os.path.join(kafka_assembly_dir, "target/scala-*/spark-streaming-kafka-assembly-*.jar")) + if not jars: + raise Exception( + ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or " + "'build/mvn package' before running this test") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming Kafka assembly JARs in %s; please " + "remove all but one") % kafka_assembly_dir) + else: + os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars[0] unittest.main() diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index 34291f30a5652..a9bfec2aab8fc 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -125,4 +125,6 @@ def rddToFileName(prefix, suffix, timestamp): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index f9fb37f7fc139..17256dfc95744 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -179,9 +179,12 @@ def test_in_memory_sort(self): list(sorter.sorted(l, key=lambda x: -x, reverse=True))) def test_external_sort(self): + class CustomizedSorter(ExternalSorter): + def _next_limit(self): + return self.memory_limit l = list(range(1024)) random.shuffle(l) - sorter = ExternalSorter(1) + sorter = CustomizedSorter(1) self.assertEqual(sorted(l), list(sorter.sorted(l))) self.assertGreater(shuffle.DiskBytesSpilled, 0) last = shuffle.DiskBytesSpilled @@ -458,6 +461,14 @@ def test_id(self): self.assertEqual(id + 1, id2) self.assertEqual(id2, rdd2.id()) + def test_empty_rdd(self): + rdd = self.sc.emptyRDD() + self.assertTrue(rdd.isEmpty()) + + def test_sum(self): + self.assertEqual(0, self.sc.emptyRDD().sum()) + self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum()) + def test_save_as_textfile_with_unicode(self): # Regression test for SPARK-970 x = u"\u00A1Hola, mundo!" @@ -1410,7 +1421,8 @@ def do_termination_test(self, terminator): # start daemon daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py") - daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE) + python_exec = sys.executable or os.environ.get("PYSPARK_PYTHON") + daemon = Popen([python_exec, daemon_path], stdin=PIPE, stdout=PIPE) # read the port number port = read_int(daemon.stdout) diff --git a/python/run-tests b/python/run-tests index 4468fdb3f267e..24949657ed7ab 100755 --- a/python/run-tests +++ b/python/run-tests @@ -18,165 +18,7 @@ # -# Figure out where the Spark framework is installed -FWDIR="$(cd "`dirname "$0"`"; cd ../; pwd)" +FWDIR="$(cd "`dirname $0`"/..; pwd)" +cd "$FWDIR" -. "$FWDIR"/bin/load-spark-env.sh - -# CD into the python directory to find things on the right path -cd "$FWDIR/python" - -FAILED=0 -LOG_FILE=unit-tests.log -START=$(date +"%s") - -rm -f $LOG_FILE - -# Remove the metastore and warehouse directory created by the HiveContext tests in Spark SQL -rm -rf metastore warehouse - -function run_test() { - echo -en "Running test: $1 ... " | tee -a $LOG_FILE - start=$(date +"%s") - SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 > $LOG_FILE 2>&1 - - FAILED=$((PIPESTATUS[0]||$FAILED)) - - # Fail and exit on the first test failure. - if [[ $FAILED != 0 ]]; then - cat $LOG_FILE | grep -v "^[0-9][0-9]*" # filter all lines starting with a number. - echo -en "\033[31m" # Red - echo "Had test failures; see logs." - echo -en "\033[0m" # No color - exit -1 - else - now=$(date +"%s") - echo "ok ($(($now - $start))s)" - fi -} - -function run_core_tests() { - echo "Run core tests ..." - run_test "pyspark.rdd" - run_test "pyspark.context" - run_test "pyspark.conf" - run_test "pyspark.broadcast" - run_test "pyspark.accumulators" - run_test "pyspark.serializers" - run_test "pyspark.profiler" - run_test "pyspark.shuffle" - run_test "pyspark.tests" -} - -function run_sql_tests() { - echo "Run sql tests ..." - run_test "pyspark.sql.types" - run_test "pyspark.sql.context" - run_test "pyspark.sql.column" - run_test "pyspark.sql.dataframe" - run_test "pyspark.sql.group" - run_test "pyspark.sql.functions" - run_test "pyspark.sql.readwriter" - run_test "pyspark.sql.window" - run_test "pyspark.sql.tests" -} - -function run_mllib_tests() { - echo "Run mllib tests ..." - run_test "pyspark.mllib.classification" - run_test "pyspark.mllib.clustering" - run_test "pyspark.mllib.evaluation" - run_test "pyspark.mllib.feature" - run_test "pyspark.mllib.fpm" - run_test "pyspark.mllib.linalg" - run_test "pyspark.mllib.random" - run_test "pyspark.mllib.recommendation" - run_test "pyspark.mllib.regression" - run_test "pyspark.mllib.stat._statistics" - run_test "pyspark.mllib.stat.KernelDensity" - run_test "pyspark.mllib.tree" - run_test "pyspark.mllib.util" - run_test "pyspark.mllib.tests" -} - -function run_ml_tests() { - echo "Run ml tests ..." - run_test "pyspark.ml.feature" - run_test "pyspark.ml.classification" - run_test "pyspark.ml.recommendation" - run_test "pyspark.ml.regression" - run_test "pyspark.ml.tuning" - run_test "pyspark.ml.tests" - run_test "pyspark.ml.evaluation" -} - -function run_streaming_tests() { - echo "Run streaming tests ..." - - KAFKA_ASSEMBLY_DIR="$FWDIR"/external/kafka-assembly - JAR_PATH="${KAFKA_ASSEMBLY_DIR}/target/scala-${SPARK_SCALA_VERSION}" - for f in "${JAR_PATH}"/spark-streaming-kafka-assembly-*.jar; do - if [[ ! -e "$f" ]]; then - echo "Failed to find Spark Streaming Kafka assembly jar in $KAFKA_ASSEMBLY_DIR" 1>&2 - echo "You need to build Spark with " \ - "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or" \ - "'build/mvn package' before running this program" 1>&2 - exit 1 - fi - KAFKA_ASSEMBLY_JAR="$f" - done - - export PYSPARK_SUBMIT_ARGS="--jars ${KAFKA_ASSEMBLY_JAR} pyspark-shell" - run_test "pyspark.streaming.util" - run_test "pyspark.streaming.tests" -} - -echo "Running PySpark tests. Output is in python/$LOG_FILE." - -export PYSPARK_PYTHON="python" - -# Try to test with Python 2.6, since that's the minimum version that we support: -if [ $(which python2.6) ]; then - export PYSPARK_PYTHON="python2.6" -fi - -echo "Testing with Python version:" -$PYSPARK_PYTHON --version - -run_core_tests -run_sql_tests -run_mllib_tests -run_ml_tests -run_streaming_tests - -# Try to test with Python 3 -if [ $(which python3.4) ]; then - export PYSPARK_PYTHON="python3.4" - echo "Testing with Python3.4 version:" - $PYSPARK_PYTHON --version - - run_core_tests - run_sql_tests - run_mllib_tests - run_ml_tests - run_streaming_tests -fi - -# Try to test with PyPy -if [ $(which pypy) ]; then - export PYSPARK_PYTHON="pypy" - echo "Testing with PyPy version:" - $PYSPARK_PYTHON --version - - run_core_tests - run_sql_tests - run_streaming_tests -fi - -if [[ $FAILED == 0 ]]; then - now=$(date +"%s") - echo -e "\033[32mTests passed \033[0min $(($now - $START)) seconds" -fi - -# TODO: in the long-run, it would be nice to use a test runner like `nose`. -# The doctest fixtures are the current barrier to doing this. +exec python -u ./python/run-tests.py "$@" diff --git a/python/run-tests.py b/python/run-tests.py new file mode 100755 index 0000000000000..b7737650daa54 --- /dev/null +++ b/python/run-tests.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function +import logging +from optparse import OptionParser +import os +import re +import subprocess +import sys +import tempfile +from threading import Thread, Lock +import time +if sys.version < '3': + import Queue +else: + import queue as Queue + + +# Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../dev/")) + + +from sparktestsupport import SPARK_HOME # noqa (suppress pep8 warnings) +from sparktestsupport.shellutils import which # noqa +from sparktestsupport.modules import all_modules # noqa + + +python_modules = dict((m.name, m) for m in all_modules if m.python_test_goals if m.name != 'root') + + +def print_red(text): + print('\033[31m' + text + '\033[0m') + + +LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log") +FAILURE_REPORTING_LOCK = Lock() +LOGGER = logging.getLogger() + + +def run_individual_python_test(test_name, pyspark_python): + env = {'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)} + LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name) + start_time = time.time() + try: + per_test_output = tempfile.TemporaryFile() + retcode = subprocess.Popen( + [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], + stderr=per_test_output, stdout=per_test_output, env=env).wait() + except: + LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python) + # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if + # this code is invoked from a thread other than the main thread. + os._exit(1) + duration = time.time() - start_time + # Exit on the first failure. + if retcode != 0: + try: + with FAILURE_REPORTING_LOCK: + with open(LOG_FILE, 'ab') as log_file: + per_test_output.seek(0) + log_file.writelines(per_test_output) + per_test_output.seek(0) + for line in per_test_output: + decoded_line = line.decode() + if not re.match('[0-9]+', decoded_line): + print(decoded_line, end='') + per_test_output.close() + except: + LOGGER.exception("Got an exception while trying to print failed test output") + finally: + print_red("\nHad test failures in %s with %s; see logs." % (test_name, pyspark_python)) + # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if + # this code is invoked from a thread other than the main thread. + os._exit(-1) + else: + per_test_output.close() + LOGGER.info("Finished test(%s): %s (%is)", pyspark_python, test_name, duration) + + +def get_default_python_executables(): + python_execs = [x for x in ["python2.6", "python3.4", "pypy"] if which(x)] + if "python2.6" not in python_execs: + LOGGER.warning("Not testing against `python2.6` because it could not be found; falling" + " back to `python` instead") + python_execs.insert(0, "python") + return python_execs + + +def parse_opts(): + parser = OptionParser( + prog="run-tests" + ) + parser.add_option( + "--python-executables", type="string", default=','.join(get_default_python_executables()), + help="A comma-separated list of Python executables to test against (default: %default)" + ) + parser.add_option( + "--modules", type="string", + default=",".join(sorted(python_modules.keys())), + help="A comma-separated list of Python modules to test (default: %default)" + ) + parser.add_option( + "-p", "--parallelism", type="int", default=4, + help="The number of suites to test in parallel (default %default)" + ) + parser.add_option( + "--verbose", action="store_true", + help="Enable additional debug logging" + ) + + (opts, args) = parser.parse_args() + if args: + parser.error("Unsupported arguments: %s" % ' '.join(args)) + if opts.parallelism < 1: + parser.error("Parallelism cannot be less than 1") + return opts + + +def main(): + opts = parse_opts() + if (opts.verbose): + log_level = logging.DEBUG + else: + log_level = logging.INFO + logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s") + LOGGER.info("Running PySpark tests. Output is in python/%s", LOG_FILE) + if os.path.exists(LOG_FILE): + os.remove(LOG_FILE) + python_execs = opts.python_executables.split(',') + modules_to_test = [] + for module_name in opts.modules.split(','): + if module_name in python_modules: + modules_to_test.append(python_modules[module_name]) + else: + print("Error: unrecognized module %s" % module_name) + sys.exit(-1) + LOGGER.info("Will test against the following Python executables: %s", python_execs) + LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test]) + + task_queue = Queue.Queue() + for python_exec in python_execs: + python_implementation = subprocess.check_output( + [python_exec, "-c", "import platform; print(platform.python_implementation())"], + universal_newlines=True).strip() + LOGGER.debug("%s python_implementation is %s", python_exec, python_implementation) + LOGGER.debug("%s version is: %s", python_exec, subprocess.check_output( + [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip()) + for module in modules_to_test: + if python_implementation not in module.blacklisted_python_implementations: + for test_goal in module.python_test_goals: + task_queue.put((python_exec, test_goal)) + + def process_queue(task_queue): + while True: + try: + (python_exec, test_goal) = task_queue.get_nowait() + except Queue.Empty: + break + try: + run_individual_python_test(test_goal, python_exec) + finally: + task_queue.task_done() + + start_time = time.time() + for _ in range(opts.parallelism): + worker = Thread(target=process_queue, args=(task_queue,)) + worker.daemon = True + worker.start() + try: + task_queue.join() + except (KeyboardInterrupt, SystemExit): + print_red("Exiting due to interrupt") + sys.exit(-1) + total_duration = time.time() - start_time + LOGGER.info("Tests passed in %i seconds", total_duration) + + +if __name__ == "__main__": + main() diff --git a/repl/pom.xml b/repl/pom.xml index 85f7bc8ac1024..370b2bc2fa8ed 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -93,7 +93,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 50fd43a418bca..f150fec7db945 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -267,6 +267,17 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("Exception", output) } + test("SPARK-8461 SQL with codegen") { + val output = runInterpreter("local", + """ + |val sqlContext = new org.apache.spark.sql.SQLContext(sc) + |sqlContext.setConf("spark.sql.codegen", "true") + |sqlContext.range(0, 100).filter('id > 50).count() + """.stripMargin) + assertContains("Long = 49", output) + assertDoesNotContain("java.lang.ClassNotFoundException", output) + } + test("SPARK-2632 importing a method from non serializable class and not using it.") { val output = runInterpreter("local", """ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java index 299ff3728a6d9..1e79f4b2e88e5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -17,12 +17,13 @@ package org.apache.spark.sql.catalyst.expressions; -import java.util.Arrays; import java.util.Iterator; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; +import scala.Function1; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ObjectPool; +import org.apache.spark.sql.catalyst.util.UniqueObjectPool; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryLocation; @@ -39,16 +40,28 @@ public final class UnsafeFixedWidthAggregationMap { * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the * map, we copy this buffer and use it as the value. */ - private final long[] emptyAggregationBuffer; + private final byte[] emptyBuffer; + + /** + * An empty row used by `initProjection` + */ + private static final InternalRow emptyRow = new GenericInternalRow(); - private final StructType aggregationBufferSchema; + /** + * Whether can the empty aggregation buffer be reuse without calling `initProjection` or not. + */ + private final boolean reuseEmptyBuffer; - private final StructType groupingKeySchema; + /** + * The projection used to initialize the emptyBuffer + */ + private final Function1 initProjection; /** - * Encodes grouping keys as UnsafeRows. + * Encodes grouping keys or buffers as UnsafeRows. */ - private final UnsafeRowConverter groupingKeyToUnsafeRowConverter; + private final UnsafeRowConverter keyConverter; + private final UnsafeRowConverter bufferConverter; /** * A hashmap which maps from opaque bytearray keys to bytearray values. @@ -56,134 +69,115 @@ public final class UnsafeFixedWidthAggregationMap { private final BytesToBytesMap map; /** - * Re-used pointer to the current aggregation buffer + * An object pool for objects that are used in grouping keys. */ - private final UnsafeRow currentAggregationBuffer = new UnsafeRow(); + private final UniqueObjectPool keyPool; /** - * Scratch space that is used when encoding grouping keys into UnsafeRow format. - * - * By default, this is a 1MB array, but it will grow as necessary in case larger keys are - * encountered. + * An object pool for objects that are used in aggregation buffers. */ - private long[] groupingKeyConversionScratchSpace = new long[1024 / 8]; - - private final boolean enablePerfMetrics; + private final ObjectPool bufferPool; /** - * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema, - * false otherwise. + * Re-used pointer to the current aggregation buffer */ - public static boolean supportsGroupKeySchema(StructType schema) { - for (StructField field: schema.fields()) { - if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) { - return false; - } - } - return true; - } + private final UnsafeRow currentBuffer = new UnsafeRow(); /** - * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given - * schema, false otherwise. + * Scratch space that is used when encoding grouping keys into UnsafeRow format. + * + * By default, this is a 8 kb array, but it will grow as necessary in case larger keys are + * encountered. */ - public static boolean supportsAggregationBufferSchema(StructType schema) { - for (StructField field: schema.fields()) { - if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) { - return false; - } - } - return true; - } + private byte[] groupingKeyConversionScratchSpace = new byte[1024 * 8]; + + private final boolean enablePerfMetrics; /** * Create a new UnsafeFixedWidthAggregationMap. * - * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function) - * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. - * @param groupingKeySchema the schema of the grouping key, used for row conversion. + * @param initProjection the default value for new keys (a "zero" of the agg. function) + * @param keyConverter the converter of the grouping key, used for row conversion. + * @param bufferConverter the converter of the aggregation buffer, used for row conversion. * @param memoryManager the memory manager used to allocate our Unsafe memory structures. * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact) */ public UnsafeFixedWidthAggregationMap( - Row emptyAggregationBuffer, - StructType aggregationBufferSchema, - StructType groupingKeySchema, + Function1 initProjection, + UnsafeRowConverter keyConverter, + UnsafeRowConverter bufferConverter, TaskMemoryManager memoryManager, int initialCapacity, boolean enablePerfMetrics) { - this.emptyAggregationBuffer = - convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema); - this.aggregationBufferSchema = aggregationBufferSchema; - this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema); - this.groupingKeySchema = groupingKeySchema; - this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics); + this.initProjection = initProjection; + this.keyConverter = keyConverter; + this.bufferConverter = bufferConverter; this.enablePerfMetrics = enablePerfMetrics; - } - /** - * Convert a Java object row into an UnsafeRow, allocating it into a new long array. - */ - private static long[] convertToUnsafeRow(Row javaRow, StructType schema) { - final UnsafeRowConverter converter = new UnsafeRowConverter(schema); - final long[] unsafeRow = new long[converter.getSizeRequirement(javaRow)]; - final long writtenLength = - converter.writeRow(javaRow, unsafeRow, PlatformDependent.LONG_ARRAY_OFFSET); - assert (writtenLength == unsafeRow.length): "Size requirement calculation was wrong!"; - return unsafeRow; + this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics); + this.keyPool = new UniqueObjectPool(100); + this.bufferPool = new ObjectPool(initialCapacity); + + InternalRow initRow = initProjection.apply(emptyRow); + this.emptyBuffer = new byte[bufferConverter.getSizeRequirement(initRow)]; + int writtenLength = bufferConverter.writeRow( + initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, bufferPool); + assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!"; + // re-use the empty buffer only when there is no object saved in pool. + reuseEmptyBuffer = bufferPool.size() == 0; } /** * Return the aggregation buffer for the current group. For efficiency, all calls to this method * return the same object. */ - public UnsafeRow getAggregationBuffer(Row groupingKey) { - final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey); + public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { + final int groupingKeySize = keyConverter.getSizeRequirement(groupingKey); // Make sure that the buffer is large enough to hold the key. If it's not, grow it: if (groupingKeySize > groupingKeyConversionScratchSpace.length) { - // This new array will be initially zero, so there's no need to zero it out here - groupingKeyConversionScratchSpace = new long[groupingKeySize]; - } else { - // Zero out the buffer that's used to hold the current row. This is necessary in order - // to ensure that rows hash properly, since garbage data from the previous row could - // otherwise end up as padding in this row. As a performance optimization, we only zero out - // the portion of the buffer that we'll actually write to. - Arrays.fill(groupingKeyConversionScratchSpace, 0, groupingKeySize, 0); + groupingKeyConversionScratchSpace = new byte[groupingKeySize]; } - final long actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow( + final int actualGroupingKeySize = keyConverter.writeRow( groupingKey, groupingKeyConversionScratchSpace, - PlatformDependent.LONG_ARRAY_OFFSET); + PlatformDependent.BYTE_ARRAY_OFFSET, + keyPool); assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!"; // Probe our map using the serialized key final BytesToBytesMap.Location loc = map.lookup( groupingKeyConversionScratchSpace, - PlatformDependent.LONG_ARRAY_OFFSET, + PlatformDependent.BYTE_ARRAY_OFFSET, groupingKeySize); if (!loc.isDefined()) { // This is the first time that we've seen this grouping key, so we'll insert a copy of the // empty aggregation buffer into the map: + if (!reuseEmptyBuffer) { + // There is some objects referenced by emptyBuffer, so generate a new one + InternalRow initRow = initProjection.apply(emptyRow); + bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, + bufferPool); + } loc.putNewKey( groupingKeyConversionScratchSpace, - PlatformDependent.LONG_ARRAY_OFFSET, + PlatformDependent.BYTE_ARRAY_OFFSET, groupingKeySize, - emptyAggregationBuffer, - PlatformDependent.LONG_ARRAY_OFFSET, - emptyAggregationBuffer.length + emptyBuffer, + PlatformDependent.BYTE_ARRAY_OFFSET, + emptyBuffer.length ); } // Reset the pointer to point to the value that we just stored or looked up: final MemoryLocation address = loc.getValueAddress(); - currentAggregationBuffer.pointTo( + currentBuffer.pointTo( address.getBaseObject(), address.getBaseOffset(), - aggregationBufferSchema.length(), - aggregationBufferSchema + bufferConverter.numFields(), + bufferPool ); - return currentAggregationBuffer; + return currentBuffer; } /** @@ -219,14 +213,14 @@ public MapEntry next() { entry.key.pointTo( keyAddress.getBaseObject(), keyAddress.getBaseOffset(), - groupingKeySchema.length(), - groupingKeySchema + keyConverter.numFields(), + keyPool ); entry.value.pointTo( valueAddress.getBaseObject(), valueAddress.getBaseOffset(), - aggregationBufferSchema.length(), - aggregationBufferSchema + bufferConverter.numFields(), + bufferPool ); return entry; } @@ -254,6 +248,8 @@ public void printPerfMetrics() { System.out.println("Number of hash collisions: " + map.getNumHashCollisions()); System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs()); System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption()); + System.out.println("Number of unique objects in keys: " + keyPool.size()); + System.out.println("Number of objects in buffers: " + bufferPool.size()); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index ec97fe603c44f..f077064a02ec0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,24 +17,12 @@ package org.apache.spark.sql.catalyst.expressions; -import javax.annotation.Nullable; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.Set; - -import scala.collection.Seq; -import scala.collection.mutable.ArraySeq; - -import org.apache.spark.sql.Row; -import org.apache.spark.sql.BaseMutableRow; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.types.UTF8String; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ObjectPool; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.bitset.BitSetMethods; +import org.apache.spark.unsafe.types.UTF8String; -import static org.apache.spark.sql.types.DataTypes.*; /** * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. @@ -47,32 +35,43 @@ * In the `values` region, we store one 8-byte word per field. For fields that hold fixed-length * primitive types, such as long, double, or int, we store the value directly in the word. For * fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the - * base address of the row) that points to the beginning of the variable-length field. + * base address of the row) that points to the beginning of the variable-length field, and length + * (they are combined into a long). For other objects, they are stored in a pool, the indexes of + * them are hold in the the word. + * + * In order to support fast hashing and equality checks for UnsafeRows that contain objects + * when used as grouping key in BytesToBytesMap, we put the objects in an UniqueObjectPool to make + * sure all the key have the same index for same object, then we can hash/compare the objects by + * hash/compare the index. + * + * For non-primitive types, the word of a field could be: + * UNION { + * [1] [offset: 31bits] [length: 31bits] // StringType + * [0] [offset: 31bits] [length: 31bits] // BinaryType + * - [index: 63bits] // StringType, Binary, index to object in pool + * } * * Instances of `UnsafeRow` act as pointers to row data stored in this format. */ -public final class UnsafeRow extends BaseMutableRow { +public final class UnsafeRow extends MutableRow { private Object baseObject; private long baseOffset; + /** A pool to hold non-primitive objects */ + private ObjectPool pool; + Object getBaseObject() { return baseObject; } long getBaseOffset() { return baseOffset; } + ObjectPool getPool() { return pool; } /** The number of fields in this row, used for calculating the bitset width (and in assertions) */ private int numFields; + public int length() { return numFields; } + /** The width of the null tracking bit set, in bytes */ private int bitSetWidthInBytes; - /** - * This optional schema is required if you want to call generic get() and set() methods on - * this UnsafeRow, but is optional if callers will only use type-specific getTYPE() and setTYPE() - * methods. This should be removed after the planned InternalRow / Row split; right now, it's only - * needed by the generic get() method, which is only called internally by code that accesses - * UTF8String-typed columns. - */ - @Nullable - private StructType schema; private long getFieldOffset(int ordinal) { return baseOffset + bitSetWidthInBytes + ordinal * 8L; @@ -82,38 +81,7 @@ public static int calculateBitSetWidthInBytes(int numFields) { return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8; } - /** - * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types) - */ - public static final Set settableFieldTypes; - - /** - * Fields types can be read(but not set (e.g. set() will throw UnsupportedOperationException). - */ - public static final Set readableFieldTypes; - - static { - settableFieldTypes = Collections.unmodifiableSet( - new HashSet( - Arrays.asList(new DataType[] { - NullType, - BooleanType, - ByteType, - ShortType, - IntegerType, - LongType, - FloatType, - DoubleType - }))); - - // We support get() on a superset of the types for which we support set(): - final Set _readableFieldTypes = new HashSet( - Arrays.asList(new DataType[]{ - StringType - })); - _readableFieldTypes.addAll(settableFieldTypes); - readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes); - } + public static final long OFFSET_BITS = 31L; /** * Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called, @@ -127,22 +95,15 @@ public UnsafeRow() { } * @param baseObject the base object * @param baseOffset the offset within the base object * @param numFields the number of fields in this row - * @param schema an optional schema; this is necessary if you want to call generic get() or set() - * methods on this row, but is optional if the caller will only use type-specific - * getTYPE() and setTYPE() methods. + * @param pool the object pool to hold arbitrary objects */ - public void pointTo( - Object baseObject, - long baseOffset, - int numFields, - @Nullable StructType schema) { + public void pointTo(Object baseObject, long baseOffset, int numFields, ObjectPool pool) { assert numFields >= 0 : "numFields should >= 0"; - assert schema == null || schema.fields().length == numFields; this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); this.baseObject = baseObject; this.baseOffset = baseOffset; this.numFields = numFields; - this.schema = schema; + this.pool = pool; } private void assertIndexIsValid(int index) { @@ -165,9 +126,68 @@ private void setNotNullAt(int i) { BitSetMethods.unset(baseObject, baseOffset, i); } + /** + * Updates the column `i` as Object `value`, which cannot be primitive types. + */ @Override - public void update(int ordinal, Object value) { - throw new UnsupportedOperationException(); + public void update(int i, Object value) { + if (value == null) { + if (!isNullAt(i)) { + // remove the old value from pool + long idx = getLong(i); + if (idx <= 0) { + // this is the index of old value in pool, remove it + pool.replace((int)-idx, null); + } else { + // there will be some garbage left (UTF8String or byte[]) + } + setNullAt(i); + } + return; + } + + if (isNullAt(i)) { + // there is not an old value, put the new value into pool + int idx = pool.put(value); + setLong(i, (long)-idx); + } else { + // there is an old value, check the type, then replace it or update it + long v = getLong(i); + if (v <= 0) { + // it's the index in the pool, replace old value with new one + int idx = (int)-v; + pool.replace(idx, value); + } else { + // old value is UTF8String or byte[], try to reuse the space + boolean isString; + byte[] newBytes; + if (value instanceof UTF8String) { + newBytes = ((UTF8String) value).getBytes(); + isString = true; + } else { + newBytes = (byte[]) value; + isString = false; + } + int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE); + int oldLength = (int) (v & Integer.MAX_VALUE); + if (newBytes.length <= oldLength) { + // the new value can fit in the old buffer, re-use it + PlatformDependent.copyMemory( + newBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + baseObject, + baseOffset + offset, + newBytes.length); + long flag = isString ? 1L << (OFFSET_BITS * 2) : 0L; + setLong(i, flag | (((long) offset) << OFFSET_BITS) | (long) newBytes.length); + } else { + // Cannot fit in the buffer + int idx = pool.put(value); + setLong(i, (long) -idx); + } + } + } + setNotNullAt(i); } @Override @@ -219,36 +239,43 @@ public void setFloat(int ordinal, float value) { PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); } - @Override - public void setString(int ordinal, String value) { - throw new UnsupportedOperationException(); - } - @Override public int size() { return numFields; } - @Override - public StructType schema() { - return schema; - } - + /** + * Returns the object for column `i`, which should not be primitive type. + */ @Override public Object get(int i) { assertIndexIsValid(i); - assert (schema != null) : "Schema must be defined when calling generic get() method"; - final DataType dataType = schema.fields()[i].dataType(); - // UnsafeRow is only designed to be invoked by internal code, which only invokes this generic - // get() method when trying to access UTF8String-typed columns. If we refactor the codebase to - // separate the internal and external row interfaces, then internal code can fetch strings via - // a new getUTF8String() method and we'll be able to remove this method. if (isNullAt(i)) { return null; - } else if (dataType == StringType) { - return getUTF8String(i); + } + long v = PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(i)); + if (v <= 0) { + // It's an index to object in the pool. + int idx = (int)-v; + return pool.get(idx); } else { - throw new UnsupportedOperationException(); + // The column could be StingType or BinaryType + boolean isString = (v >> (OFFSET_BITS * 2)) > 0; + int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE); + int size = (int) (v & Integer.MAX_VALUE); + final byte[] bytes = new byte[size]; + PlatformDependent.copyMemory( + baseObject, + baseOffset + offset, + bytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + size + ); + if (isString) { + return UTF8String.fromBytes(bytes); + } else { + return bytes; + } } } @@ -308,33 +335,8 @@ public double getDouble(int i) { } } - public UTF8String getUTF8String(int i) { - assertIndexIsValid(i); - final UTF8String str = new UTF8String(); - final long offsetToStringSize = getLong(i); - final int stringSizeInBytes = - (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offsetToStringSize); - final byte[] strBytes = new byte[stringSizeInBytes]; - PlatformDependent.copyMemory( - baseObject, - baseOffset + offsetToStringSize + 8, // The `+ 8` is to skip past the size to get the data - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - stringSizeInBytes - ); - str.set(strBytes); - return str; - } - - @Override - public String getString(int i) { - return getUTF8String(i).toString(); - } - - - @Override - public Row copy() { + public InternalRow copy() { throw new UnsupportedOperationException(); } @@ -342,13 +344,4 @@ public Row copy() { public boolean anyNull() { return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes); } - - @Override - public Seq toSeq() { - final ArraySeq values = new ArraySeq(numFields); - for (int fieldNumber = 0; fieldNumber < numFields; fieldNumber++) { - values.update(fieldNumber, get(fieldNumber)); - } - return values; - } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java new file mode 100644 index 0000000000000..97f89a7d0b758 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util; + +/** + * A object pool stores a collection of objects in array, then they can be referenced by the + * pool plus an index. + */ +public class ObjectPool { + + /** + * An array to hold objects, which will grow as needed. + */ + private Object[] objects; + + /** + * How many objects in the pool. + */ + private int numObj; + + public ObjectPool(int capacity) { + objects = new Object[capacity]; + numObj = 0; + } + + /** + * Returns how many objects in the pool. + */ + public int size() { + return numObj; + } + + /** + * Returns the object at position `idx` in the array. + */ + public Object get(int idx) { + assert (idx < numObj); + return objects[idx]; + } + + /** + * Puts an object `obj` at the end of array, returns the index of it. + *

+ * The array will grow as needed. + */ + public int put(Object obj) { + if (numObj >= objects.length) { + Object[] tmp = new Object[objects.length * 2]; + System.arraycopy(objects, 0, tmp, 0, objects.length); + objects = tmp; + } + objects[numObj++] = obj; + return numObj - 1; + } + + /** + * Replaces the object at `idx` with new one `obj`. + */ + public void replace(int idx, Object obj) { + assert (idx < numObj); + objects[idx] = obj; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java new file mode 100644 index 0000000000000..d512392dcaacc --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util; + +import java.util.HashMap; + +/** + * An unique object pool stores a collection of unique objects in it. + */ +public class UniqueObjectPool extends ObjectPool { + + /** + * A hash map from objects to their indexes in the array. + */ + private HashMap objIndex; + + public UniqueObjectPool(int capacity) { + super(capacity); + objIndex = new HashMap(); + } + + /** + * Put an object `obj` into the pool. If there is an existing object equals to `obj`, it will + * return the index of the existing one. + */ + @Override + public int put(Object obj) { + if (objIndex.containsKey(obj)) { + return objIndex.get(obj); + } else { + int idx = super.put(obj); + objIndex.put(obj, idx); + return idx; + } + } + + /** + * The objects can not be replaced. + */ + @Override + public void replace(int idx, Object obj) { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java deleted file mode 100644 index acec2bf4520f2..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql; - -import org.apache.spark.sql.catalyst.expressions.MutableRow; - -public abstract class BaseMutableRow extends BaseRow implements MutableRow { - - @Override - public void update(int ordinal, Object value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setInt(int ordinal, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setLong(int ordinal, long value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setDouble(int ordinal, double value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setBoolean(int ordinal, boolean value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setShort(int ordinal, short value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setByte(int ordinal, byte value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setFloat(int ordinal, float value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setString(int ordinal, String value) { - throw new UnsupportedOperationException(); - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java deleted file mode 100644 index d138b43a3482b..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql; - -import java.math.BigDecimal; -import java.sql.Date; -import java.util.List; - -import scala.collection.Seq; -import scala.collection.mutable.ArraySeq; - -import org.apache.spark.sql.catalyst.expressions.GenericRow; -import org.apache.spark.sql.types.StructType; - -public abstract class BaseRow implements Row { - - @Override - final public int length() { - return size(); - } - - @Override - public boolean anyNull() { - final int n = size(); - for (int i=0; i < n; i++) { - if (isNullAt(i)) { - return true; - } - } - return false; - } - - @Override - public StructType schema() { throw new UnsupportedOperationException(); } - - @Override - final public Object apply(int i) { - return get(i); - } - - @Override - public int getInt(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public long getLong(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public float getFloat(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public double getDouble(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public byte getByte(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public short getShort(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean getBoolean(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public String getString(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public BigDecimal getDecimal(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Seq getSeq(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public List getList(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public scala.collection.Map getMap(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public scala.collection.immutable.Map getValuesMap(Seq fieldNames) { - throw new UnsupportedOperationException(); - } - - @Override - public java.util.Map getJavaMap(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Row getStruct(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public T getAs(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public T getAs(String fieldName) { - throw new UnsupportedOperationException(); - } - - @Override - public int fieldIndex(String name) { - throw new UnsupportedOperationException(); - } - - @Override - public Row copy() { - final int n = size(); - Object[] arr = new Object[n]; - for (int i = 0; i < n; i++) { - arr[i] = get(i); - } - return new GenericRow(arr); - } - - @Override - public Seq toSeq() { - final int n = size(); - final ArraySeq values = new ArraySeq(n); - for (int i = 0; i < n; i++) { - values.update(i, get(i)); - } - return values; - } - - @Override - public String toString() { - return mkString("[", ",", "]"); - } - - @Override - public String mkString() { - return toSeq().mkString(); - } - - @Override - public String mkString(String sep) { - return toSeq().mkString(sep); - } - - @Override - public String mkString(String start, String sep, String end) { - return toSeq().mkString(start, sep, end); - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 0d460b634d9b0..0f2fd6a86d177 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import scala.util.hashing.MurmurHash3 - import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.types.StructType @@ -181,7 +179,7 @@ trait Row extends Serializable { def get(i: Int): Any = apply(i) /** Checks whether the value at position i is null. */ - def isNullAt(i: Int): Boolean + def isNullAt(i: Int): Boolean = apply(i) == null /** * Returns the value at position i as a primitive boolean. @@ -189,7 +187,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getBoolean(i: Int): Boolean + def getBoolean(i: Int): Boolean = getAs[Boolean](i) /** * Returns the value at position i as a primitive byte. @@ -197,7 +195,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getByte(i: Int): Byte + def getByte(i: Int): Byte = getAs[Byte](i) /** * Returns the value at position i as a primitive short. @@ -205,7 +203,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getShort(i: Int): Short + def getShort(i: Int): Short = getAs[Short](i) /** * Returns the value at position i as a primitive int. @@ -213,7 +211,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getInt(i: Int): Int + def getInt(i: Int): Int = getAs[Int](i) /** * Returns the value at position i as a primitive long. @@ -221,7 +219,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getLong(i: Int): Long + def getLong(i: Int): Long = getAs[Long](i) /** * Returns the value at position i as a primitive float. @@ -230,7 +228,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getFloat(i: Int): Float + def getFloat(i: Int): Float = getAs[Float](i) /** * Returns the value at position i as a primitive double. @@ -238,7 +236,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getDouble(i: Int): Double + def getDouble(i: Int): Double = getAs[Double](i) /** * Returns the value at position i as a String object. @@ -246,29 +244,35 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getString(i: Int): String + def getString(i: Int): String = getAs[String](i) /** * Returns the value at position i of decimal type as java.math.BigDecimal. * * @throws ClassCastException when data type does not match. */ - def getDecimal(i: Int): java.math.BigDecimal = apply(i).asInstanceOf[java.math.BigDecimal] + def getDecimal(i: Int): java.math.BigDecimal = getAs[java.math.BigDecimal](i) /** * Returns the value at position i of date type as java.sql.Date. * * @throws ClassCastException when data type does not match. */ - // TODO(davies): This is not the right default implementation, we use Int as Date internally - def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date] + def getDate(i: Int): java.sql.Date = getAs[java.sql.Date](i) + + /** + * Returns the value at position i of date type as java.sql.Timestamp. + * + * @throws ClassCastException when data type does not match. + */ + def getTimestamp(i: Int): java.sql.Timestamp = getAs[java.sql.Timestamp](i) /** * Returns the value at position i of array type as a Scala Seq. * * @throws ClassCastException when data type does not match. */ - def getSeq[T](i: Int): Seq[T] = apply(i).asInstanceOf[Seq[T]] + def getSeq[T](i: Int): Seq[T] = getAs[Seq[T]](i) /** * Returns the value at position i of array type as [[java.util.List]]. @@ -284,7 +288,7 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ - def getMap[K, V](i: Int): scala.collection.Map[K, V] = apply(i).asInstanceOf[Map[K, V]] + def getMap[K, V](i: Int): scala.collection.Map[K, V] = getAs[Map[K, V]](i) /** * Returns the value at position i of array type as a [[java.util.Map]]. @@ -359,42 +363,21 @@ trait Row extends Serializable { false } - override def equals(that: Any): Boolean = that match { - case null => false - case that: Row => - if (this.length != that.length) { - return false - } - var i = 0 - val len = this.length - while (i < len) { - if (apply(i) != that.apply(i)) { - return false - } - i += 1 - } - true - case _ => false - } - - override def hashCode: Int = { - // Using Scala's Seq hash code implementation. - var n = 0 - var h = MurmurHash3.seqSeed - val len = length - while (n < len) { - h = MurmurHash3.mix(h, apply(n).##) - n += 1 - } - MurmurHash3.finalizeHash(h, n) - } - /* ---------------------- utility methods for Scala ---------------------- */ /** - * Return a Scala Seq representing the row. ELements are placed in the same order in the Seq. + * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq. */ - def toSeq: Seq[Any] + def toSeq: Seq[Any] = { + val n = length + val values = new Array[Any](n) + var i = 0 + while (i < n) { + values.update(i, get(i)) + i += 1 + } + values.toSeq + } /** Displays all elements of this sequence in a string (without a separator). */ def mkString: String = toSeq.mkString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 2e7b4c236d8f8..8f63d2120ad0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -19,15 +19,17 @@ package org.apache.spark.sql.catalyst import java.lang.{Iterable => JavaIterable} import java.math.{BigDecimal => JavaBigDecimal} -import java.sql.Date +import java.sql.{Date, Timestamp} import java.util.{Map => JavaMap} import javax.annotation.Nullable import scala.collection.mutable.HashMap +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Functions to convert Scala types to Catalyst types and vice versa. @@ -50,6 +52,13 @@ object CatalystTypeConverters { } } + private def isWholePrimitive(dt: DataType): Boolean = dt match { + case dt if isPrimitive(dt) => true + case ArrayType(elementType, _) => isWholePrimitive(elementType) + case MapType(keyType, valueType, _) => isWholePrimitive(keyType) && isWholePrimitive(valueType) + case _ => false + } + private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = { val converter = dataType match { case udt: UserDefinedType[_] => UDTConverter(udt) @@ -58,6 +67,7 @@ object CatalystTypeConverters { case structType: StructType => StructConverter(structType) case StringType => StringConverter case DateType => DateConverter + case TimestampType => TimestampConverter case dt: DecimalType => BigDecimalConverter case BooleanType => BooleanConverter case ByteType => ByteConverter @@ -103,7 +113,7 @@ object CatalystTypeConverters { /** * Given a Catalyst row, convert the value at column `column` to its Scala equivalent. */ - final def toScala(row: Row, column: Int): ScalaOutputType = { + final def toScala(row: InternalRow, column: Int): ScalaOutputType = { if (row.isNullAt(column)) null.asInstanceOf[ScalaOutputType] else toScalaImpl(row, column) } @@ -123,20 +133,20 @@ object CatalystTypeConverters { * Given a Catalyst row, convert the value at column `column` to its Scala equivalent. * This method will only be called on non-null columns. */ - protected def toScalaImpl(row: Row, column: Int): ScalaOutputType + protected def toScalaImpl(row: InternalRow, column: Int): ScalaOutputType } private object IdentityConverter extends CatalystTypeConverter[Any, Any, Any] { override def toCatalystImpl(scalaValue: Any): Any = scalaValue override def toScala(catalystValue: Any): Any = catalystValue - override def toScalaImpl(row: Row, column: Int): Any = row(column) + override def toScalaImpl(row: InternalRow, column: Int): Any = row(column) } private case class UDTConverter( udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] { override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue) override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue) - override def toScalaImpl(row: Row, column: Int): Any = toScala(row(column)) + override def toScalaImpl(row: InternalRow, column: Int): Any = toScala(row(column)) } /** Converter for arrays, sequences, and Java iterables. */ @@ -145,6 +155,8 @@ object CatalystTypeConverters { private[this] val elementConverter = getConverterForType(elementType) + private[this] val isNoChange = isWholePrimitive(elementType) + override def toCatalystImpl(scalaValue: Any): Seq[Any] = { scalaValue match { case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst) @@ -163,12 +175,14 @@ object CatalystTypeConverters { override def toScala(catalystValue: Seq[Any]): Seq[Any] = { if (catalystValue == null) { null + } else if (isNoChange) { + catalystValue } else { - catalystValue.asInstanceOf[Seq[_]].map(elementConverter.toScala) + catalystValue.map(elementConverter.toScala) } } - override def toScalaImpl(row: Row, column: Int): Seq[Any] = + override def toScalaImpl(row: InternalRow, column: Int): Seq[Any] = toScala(row(column).asInstanceOf[Seq[Any]]) } @@ -180,6 +194,8 @@ object CatalystTypeConverters { private[this] val keyConverter = getConverterForType(keyType) private[this] val valueConverter = getConverterForType(valueType) + private[this] val isNoChange = isWholePrimitive(keyType) && isWholePrimitive(valueType) + override def toCatalystImpl(scalaValue: Any): Map[Any, Any] = scalaValue match { case m: Map[_, _] => m.map { case (k, v) => @@ -200,6 +216,8 @@ object CatalystTypeConverters { override def toScala(catalystValue: Map[Any, Any]): Map[Any, Any] = { if (catalystValue == null) { null + } else if (isNoChange) { + catalystValue } else { catalystValue.map { case (k, v) => keyConverter.toScala(k) -> valueConverter.toScala(v) @@ -207,16 +225,16 @@ object CatalystTypeConverters { } } - override def toScalaImpl(row: Row, column: Int): Map[Any, Any] = + override def toScalaImpl(row: InternalRow, column: Int): Map[Any, Any] = toScala(row(column).asInstanceOf[Map[Any, Any]]) } private case class StructConverter( - structType: StructType) extends CatalystTypeConverter[Any, Row, Row] { + structType: StructType) extends CatalystTypeConverter[Any, Row, InternalRow] { private[this] val converters = structType.fields.map { f => getConverterForType(f.dataType) } - override def toCatalystImpl(scalaValue: Any): Row = scalaValue match { + override def toCatalystImpl(scalaValue: Any): InternalRow = scalaValue match { case row: Row => val ar = new Array[Any](row.size) var idx = 0 @@ -224,7 +242,7 @@ object CatalystTypeConverters { ar(idx) = converters(idx).toCatalyst(row(idx)) idx += 1 } - new GenericRowWithSchema(ar, structType) + new GenericInternalRow(ar) case p: Product => val ar = new Array[Any](structType.size) @@ -234,10 +252,10 @@ object CatalystTypeConverters { ar(idx) = converters(idx).toCatalyst(iter.next()) idx += 1 } - new GenericRowWithSchema(ar, structType) + new GenericInternalRow(ar) } - override def toScala(row: Row): Row = { + override def toScala(row: InternalRow): Row = { if (row == null) { null } else { @@ -251,27 +269,36 @@ object CatalystTypeConverters { } } - override def toScalaImpl(row: Row, column: Int): Row = toScala(row(column).asInstanceOf[Row]) + override def toScalaImpl(row: InternalRow, column: Int): Row = + toScala(row(column).asInstanceOf[InternalRow]) } - private object StringConverter extends CatalystTypeConverter[Any, String, Any] { + private object StringConverter extends CatalystTypeConverter[Any, String, UTF8String] { override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match { - case str: String => UTF8String(str) + case str: String => UTF8String.fromString(str) case utf8: UTF8String => utf8 } - override def toScala(catalystValue: Any): String = catalystValue match { - case null => null - case str: String => str - case utf8: UTF8String => utf8.toString() - } - override def toScalaImpl(row: Row, column: Int): String = row(column).toString + override def toScala(catalystValue: UTF8String): String = + if (catalystValue == null) null else catalystValue.toString + override def toScalaImpl(row: InternalRow, column: Int): String = row(column).toString } private object DateConverter extends CatalystTypeConverter[Date, Date, Any] { - override def toCatalystImpl(scalaValue: Date): Int = DateUtils.fromJavaDate(scalaValue) + override def toCatalystImpl(scalaValue: Date): Int = DateTimeUtils.fromJavaDate(scalaValue) override def toScala(catalystValue: Any): Date = - if (catalystValue == null) null else DateUtils.toJavaDate(catalystValue.asInstanceOf[Int]) - override def toScalaImpl(row: Row, column: Int): Date = toScala(row.getInt(column)) + if (catalystValue == null) null else DateTimeUtils.toJavaDate(catalystValue.asInstanceOf[Int]) + override def toScalaImpl(row: InternalRow, column: Int): Date = + DateTimeUtils.toJavaDate(row.getInt(column)) + } + + private object TimestampConverter extends CatalystTypeConverter[Timestamp, Timestamp, Any] { + override def toCatalystImpl(scalaValue: Timestamp): Long = + DateTimeUtils.fromJavaTimestamp(scalaValue) + override def toScala(catalystValue: Any): Timestamp = + if (catalystValue == null) null + else DateTimeUtils.toJavaTimestamp(catalystValue.asInstanceOf[Long]) + override def toScalaImpl(row: InternalRow, column: Int): Timestamp = + DateTimeUtils.toJavaTimestamp(row.getLong(column)) } private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { @@ -281,10 +308,8 @@ object CatalystTypeConverters { case d: Decimal => d } override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal - override def toScalaImpl(row: Row, column: Int): JavaBigDecimal = row.get(column) match { - case d: JavaBigDecimal => d - case d: Decimal => d.toJavaBigDecimal - } + override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal = + row.get(column).asInstanceOf[Decimal].toJavaBigDecimal } private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] { @@ -293,41 +318,31 @@ object CatalystTypeConverters { } private object BooleanConverter extends PrimitiveConverter[Boolean] { - override def toScalaImpl(row: Row, column: Int): Boolean = row.getBoolean(column) + override def toScalaImpl(row: InternalRow, column: Int): Boolean = row.getBoolean(column) } private object ByteConverter extends PrimitiveConverter[Byte] { - override def toScalaImpl(row: Row, column: Int): Byte = row.getByte(column) + override def toScalaImpl(row: InternalRow, column: Int): Byte = row.getByte(column) } private object ShortConverter extends PrimitiveConverter[Short] { - override def toScalaImpl(row: Row, column: Int): Short = row.getShort(column) + override def toScalaImpl(row: InternalRow, column: Int): Short = row.getShort(column) } private object IntConverter extends PrimitiveConverter[Int] { - override def toScalaImpl(row: Row, column: Int): Int = row.getInt(column) + override def toScalaImpl(row: InternalRow, column: Int): Int = row.getInt(column) } private object LongConverter extends PrimitiveConverter[Long] { - override def toScalaImpl(row: Row, column: Int): Long = row.getLong(column) + override def toScalaImpl(row: InternalRow, column: Int): Long = row.getLong(column) } private object FloatConverter extends PrimitiveConverter[Float] { - override def toScalaImpl(row: Row, column: Int): Float = row.getFloat(column) + override def toScalaImpl(row: InternalRow, column: Int): Float = row.getFloat(column) } private object DoubleConverter extends PrimitiveConverter[Double] { - override def toScalaImpl(row: Row, column: Int): Double = row.getDouble(column) - } - - /** - * Converts Scala objects to catalyst rows / types. This method is slow, and for batch - * conversion you should be using converter produced by createToCatalystConverter. - * Note: This is always called after schemaFor has been called. - * This ordering is important for UDT registration. - */ - def convertToCatalyst(scalaValue: Any, dataType: DataType): Any = { - getConverterForType(dataType).toCatalyst(scalaValue) + override def toScalaImpl(row: InternalRow, column: Int): Double = row.getDouble(column) } /** @@ -357,6 +372,19 @@ object CatalystTypeConverters { } } + /** + * Creates a converter function that will convert Catalyst types to Scala type. + * Typical use case would be converting a collection of rows that have the same schema. You will + * call this function once to get a converter, and apply it to every row. + */ + private[sql] def createToScalaConverter(dataType: DataType): Any => Any = { + if (isPrimitive(dataType)) { + identity + } else { + getConverterForType(dataType).toScala + } + } + /** * Converts Scala objects to Catalyst rows / types. * @@ -367,10 +395,11 @@ object CatalystTypeConverters { def convertToCatalyst(a: Any): Any = a match { case s: String => StringConverter.toCatalyst(s) case d: Date => DateConverter.toCatalyst(d) + case t: Timestamp => TimestampConverter.toCatalyst(t) case d: BigDecimal => BigDecimalConverter.toCatalyst(d) case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d) case seq: Seq[Any] => seq.map(convertToCatalyst) - case r: Row => Row(r.toSeq.map(convertToCatalyst): _*) + case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*) case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray case m: Map[Any, Any] => m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap @@ -383,15 +412,6 @@ object CatalystTypeConverters { * produced by createToScalaConverter. */ def convertToScala(catalystValue: Any, dataType: DataType): Any = { - getConverterForType(dataType).toScala(catalystValue) - } - - /** - * Creates a converter function that will convert Catalyst types to Scala type. - * Typical use case would be converting a collection of rows that have the same schema. You will - * call this function once to get a converter, and apply it to every row. - */ - private[sql] def createToScalaConverter(dataType: DataType): Any => Any = { - getConverterForType(dataType).toScala + createToScalaConverter(dataType)(catalystValue) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala new file mode 100644 index 0000000000000..57de0f26a9720 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * An abstract class for row used internal in Spark SQL, which only contain the columns as + * internal types. + */ +abstract class InternalRow extends Row { + + // This is only use for test + override def getString(i: Int): String = { + val str = getAs[UTF8String](i) + if (str != null) str.toString else null + } + + // These expensive API should not be used internally. + final override def getDecimal(i: Int): java.math.BigDecimal = + throw new UnsupportedOperationException + final override def getDate(i: Int): java.sql.Date = + throw new UnsupportedOperationException + final override def getTimestamp(i: Int): java.sql.Timestamp = + throw new UnsupportedOperationException + final override def getSeq[T](i: Int): Seq[T] = throw new UnsupportedOperationException + final override def getList[T](i: Int): java.util.List[T] = throw new UnsupportedOperationException + final override def getMap[K, V](i: Int): scala.collection.Map[K, V] = + throw new UnsupportedOperationException + final override def getJavaMap[K, V](i: Int): java.util.Map[K, V] = + throw new UnsupportedOperationException + final override def getStruct(i: Int): Row = throw new UnsupportedOperationException + final override def getAs[T](fieldName: String): T = throw new UnsupportedOperationException + final override def getValuesMap[T](fieldNames: Seq[String]): Map[String, T] = + throw new UnsupportedOperationException + + // A default implementation to change the return type + override def copy(): InternalRow = this + override def apply(i: Int): Any = get(i) + + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[Row]) { + return false + } + + val other = o.asInstanceOf[Row] + if (length != other.length) { + return false + } + + var i = 0 + while (i < length) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = apply(i) + val o2 = other.apply(i) + if (o1.isInstanceOf[Array[Byte]]) { + // handle equality of Array[Byte] + val b1 = o1.asInstanceOf[Array[Byte]] + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + } else if (o1 != o2) { + return false + } + } + i += 1 + } + true + } + + // Custom hashCode function that matches the efficient code generated version. + override def hashCode: Int = { + var result: Int = 37 + var i = 0 + while (i < length) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + apply(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case a: Array[Byte] => java.util.Arrays.hashCode(a) + case other => other.hashCode() + } + } + result = 37 * result + update + i += 1 + } + result + } +} + +object InternalRow { + /** + * This method can be used to construct a [[Row]] with the given values. + */ + def apply(values: Any*): InternalRow = new GenericInternalRow(values.toArray) + + /** + * This method can be used to construct a [[Row]] from a [[Seq]] of values. + */ + def fromSeq(values: Seq[Any]): InternalRow = new GenericInternalRow(values.toArray) + + /** Returns an empty row. */ + val empty = apply() +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 6998cc8d9666d..21b1de1ab9cb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation @@ -27,7 +28,11 @@ import org.apache.spark.sql.types._ */ object ScalaReflection extends ScalaReflection { val universe: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe - val mirror: universe.Mirror = universe.runtimeMirror(Thread.currentThread().getContextClassLoader) + // Since we are creating a runtime mirror usign the class loader of current thread, + // we need to use def at here. So, every time we call mirror, it is using the + // class loader of the current thread. + override def mirror: universe.Mirror = + universe.runtimeMirror(Thread.currentThread().getContextClassLoader) } /** @@ -38,7 +43,7 @@ trait ScalaReflection { val universe: scala.reflect.api.Universe /** The mirror used to access types in the universe */ - val mirror: universe.Mirror + def mirror: universe.Mirror import universe._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index e85312aee7d16..79f526e823cd4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst import scala.language.implicitConversions +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -48,31 +49,25 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` // properties via reflection the class in runtime for constructing the SqlLexical object - protected val ABS = Keyword("ABS") protected val ALL = Keyword("ALL") protected val AND = Keyword("AND") protected val APPROXIMATE = Keyword("APPROXIMATE") protected val AS = Keyword("AS") protected val ASC = Keyword("ASC") - protected val AVG = Keyword("AVG") protected val BETWEEN = Keyword("BETWEEN") protected val BY = Keyword("BY") protected val CASE = Keyword("CASE") protected val CAST = Keyword("CAST") - protected val COALESCE = Keyword("COALESCE") - protected val COUNT = Keyword("COUNT") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") protected val ELSE = Keyword("ELSE") protected val END = Keyword("END") protected val EXCEPT = Keyword("EXCEPT") protected val FALSE = Keyword("FALSE") - protected val FIRST = Keyword("FIRST") protected val FROM = Keyword("FROM") protected val FULL = Keyword("FULL") protected val GROUP = Keyword("GROUP") protected val HAVING = Keyword("HAVING") - protected val IF = Keyword("IF") protected val IN = Keyword("IN") protected val INNER = Keyword("INNER") protected val INSERT = Keyword("INSERT") @@ -80,13 +75,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected val INTO = Keyword("INTO") protected val IS = Keyword("IS") protected val JOIN = Keyword("JOIN") - protected val LAST = Keyword("LAST") protected val LEFT = Keyword("LEFT") protected val LIKE = Keyword("LIKE") protected val LIMIT = Keyword("LIMIT") - protected val LOWER = Keyword("LOWER") - protected val MAX = Keyword("MAX") - protected val MIN = Keyword("MIN") protected val NOT = Keyword("NOT") protected val NULL = Keyword("NULL") protected val ON = Keyword("ON") @@ -100,26 +91,14 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected val RLIKE = Keyword("RLIKE") protected val SELECT = Keyword("SELECT") protected val SEMI = Keyword("SEMI") - protected val SQRT = Keyword("SQRT") - protected val SUBSTR = Keyword("SUBSTR") - protected val SUBSTRING = Keyword("SUBSTRING") - protected val SUM = Keyword("SUM") protected val TABLE = Keyword("TABLE") protected val THEN = Keyword("THEN") protected val TRUE = Keyword("TRUE") protected val UNION = Keyword("UNION") - protected val UPPER = Keyword("UPPER") protected val WHEN = Keyword("WHEN") protected val WHERE = Keyword("WHERE") protected val WITH = Keyword("WITH") - protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = { - exprs.zipWithIndex.map { - case (ne: NamedExpression, _) => ne - case (e, i) => Alias(e, s"c$i")() - } - } - protected lazy val start: Parser[LogicalPlan] = start1 | insert | cte @@ -144,8 +123,8 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { val base = r.getOrElse(OneRowRelation) val withFilter = f.map(Filter(_, base)).getOrElse(base) val withProjection = g - .map(Aggregate(_, assignAliases(p), withFilter)) - .getOrElse(Project(assignAliases(p), withFilter)) + .map(Aggregate(_, p.map(UnresolvedAlias(_)), withFilter)) + .getOrElse(Project(p.map(UnresolvedAlias(_)), withFilter)) val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection) val withHaving = h.map(Filter(_, withDistinct)).getOrElse(withDistinct) val withOrder = o.map(_(withHaving)).getOrElse(withHaving) @@ -277,25 +256,37 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { ) protected lazy val function: Parser[Expression] = - ( SUM ~> "(" ~> expression <~ ")" ^^ { case exp => Sum(exp) } - | SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) } - | COUNT ~ "(" ~> "*" <~ ")" ^^ { case _ => Count(Literal(1)) } - | COUNT ~ "(" ~> expression <~ ")" ^^ { case exp => Count(exp) } - | COUNT ~> "(" ~> DISTINCT ~> repsep(expression, ",") <~ ")" ^^ - { case exps => CountDistinct(exps) } - | APPROXIMATE ~ COUNT ~ "(" ~ DISTINCT ~> expression <~ ")" ^^ - { case exp => ApproxCountDistinct(exp) } - | APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ - { case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble) } - | FIRST ~ "(" ~> expression <~ ")" ^^ { case exp => First(exp) } - | LAST ~ "(" ~> expression <~ ")" ^^ { case exp => Last(exp) } - | AVG ~ "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } - | MIN ~ "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } - | MAX ~ "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } - | UPPER ~ "(" ~> expression <~ ")" ^^ { case exp => Upper(exp) } - | LOWER ~ "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) } - | IF ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ - { case c ~ t ~ f => If(c, t, f) } + ( ident <~ ("(" ~ "*" ~ ")") ^^ { case udfName => + if (lexical.normalizeKeyword(udfName) == "count") { + Count(Literal(1)) + } else { + throw new AnalysisException(s"invalid expression $udfName(*)") + } + } + | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ + { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) } + | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs => + lexical.normalizeKeyword(udfName) match { + case "sum" => SumDistinct(exprs.head) + case "count" => CountDistinct(exprs) + case _ => throw new AnalysisException(s"function $udfName does not support DISTINCT") + } + } + | APPROXIMATE ~> ident ~ ("(" ~ DISTINCT ~> expression <~ ")") ^^ { case udfName ~ exp => + if (lexical.normalizeKeyword(udfName) == "count") { + ApproxCountDistinct(exp) + } else { + throw new AnalysisException(s"invalid function approximate $udfName") + } + } + | APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ ident ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ + { case s ~ _ ~ udfName ~ _ ~ _ ~ exp => + if (lexical.normalizeKeyword(udfName) == "count") { + ApproxCountDistinct(exp, s.toDouble) + } else { + throw new AnalysisException(s"invalid function approximate($floatLit) $udfName") + } + } | CASE ~> expression.? ~ rep1(WHEN ~> expression ~ (THEN ~> expression)) ~ (ELSE ~> expression).? <~ END ^^ { case casePart ~ altPart ~ elsePart => @@ -304,16 +295,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } ++ elsePart casePart.map(CaseKeyWhen(_, branches)).getOrElse(CaseWhen(branches)) } - | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) <~ ")" ^^ - { case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) } - | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ - { case s ~ p ~ l => Substring(s, p, l) } - | COALESCE ~ "(" ~> repsep(expression, ",") <~ ")" ^^ { case exprs => Coalesce(exprs) } - | SQRT ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } - | ABS ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } - | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ - { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) } - ) + ) protected lazy val cast: Parser[Expression] = CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5883d938b676d..117c87a785fdb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.catalyst.analysis -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types._ +import scala.collection.mutable.ArrayBuffer /** * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing @@ -73,10 +72,10 @@ class Analyzer( ResolveSortReferences :: ResolveGenerate :: ResolveFunctions :: + ResolveAliases :: ExtractWindowExpressions :: GlobalAggregates :: UnresolvedHavingClauseAttributes :: - TrimGroupingAliases :: typeCoercionRules ++ extendedResolutionRules : _*) ) @@ -131,12 +130,38 @@ class Analyzer( } /** - * Removes no-op Alias expressions from the plan. + * Replaces [[UnresolvedAlias]]s with concrete aliases. */ - object TrimGroupingAliases extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Aggregate(groups, aggs, child) => - Aggregate(groups.map(_.transform { case Alias(c, _) => c }), aggs, child) + object ResolveAliases extends Rule[LogicalPlan] { + private def assignAliases(exprs: Seq[NamedExpression]) = { + // The `UnresolvedAlias`s will appear only at root of a expression tree, we don't need + // to transform down the whole tree. + exprs.zipWithIndex.map { + case (u @ UnresolvedAlias(child), i) => + child match { + case _: UnresolvedAttribute => u + case ne: NamedExpression => ne + case ev: ExtractValueWithStruct => Alias(ev, ev.field.name)() + case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil) + case e if !e.resolved => u + case other => Alias(other, s"_c$i")() + } + case (other, _) => other + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case Aggregate(groups, aggs, child) + if child.resolved && aggs.exists(_.isInstanceOf[UnresolvedAlias]) => + Aggregate(groups, assignAliases(aggs), child) + + case g: GroupingAnalytics + if g.child.resolved && g.aggregations.exists(_.isInstanceOf[UnresolvedAlias]) => + g.withNewAggs(assignAliases(g.aggregations)) + + case Project(projectList, child) + if child.resolved && projectList.exists(_.isInstanceOf[UnresolvedAlias]) => + Project(assignAliases(projectList), child) } } @@ -167,49 +192,17 @@ class Analyzer( Seq.tabulate(1 << c.groupByExprs.length)(i => i) } - /** - * Create an array of Projections for the child projection, and replace the projections' - * expressions which equal GroupBy expressions with Literal(null), if those expressions - * are not set for this grouping set (according to the bit mask). - */ - private[this] def expand(g: GroupingSets): Seq[GroupExpression] = { - val result = new scala.collection.mutable.ArrayBuffer[GroupExpression] - - g.bitmasks.foreach { bitmask => - // get the non selected grouping attributes according to the bit mask - val nonSelectedGroupExprs = ArrayBuffer.empty[Expression] - var bit = g.groupByExprs.length - 1 - while (bit >= 0) { - if (((bitmask >> bit) & 1) == 0) nonSelectedGroupExprs += g.groupByExprs(bit) - bit -= 1 - } - - val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown { - case x: Expression if nonSelectedGroupExprs.find(_ semanticEquals x).isDefined => - // if the input attribute in the Invalid Grouping Expression set of for this group - // replace it with constant null - Literal.create(null, expr.dataType) - case x if x == g.gid => - // replace the groupingId with concrete value (the bit mask) - Literal.create(bitmask, IntegerType) - }) - - result += GroupExpression(substitution) - } - - result.toSeq - } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case a: Cube if a.resolved => - GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid) - case a: Rollup if a.resolved => - GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid) - case x: GroupingSets if x.resolved => + case a: Cube => + GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) + case a: Rollup => + GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) + case x: GroupingSets => + val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() Aggregate( - x.groupByExprs :+ x.gid, + x.groupByExprs :+ VirtualColumn.groupingIdAttribute, x.aggregations, - Expand(expand(x), x.child.output :+ x.gid, x.child)) + Expand(x.bitmasks, x.groupByExprs, gid, x.child)) } } @@ -227,7 +220,7 @@ class Analyzer( } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case i@InsertIntoTable(u: UnresolvedRelation, _, _, _, _) => + case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _, _) => i.copy(table = EliminateSubQueries(getTable(u))) case u: UnresolvedRelation => getTable(u) @@ -247,24 +240,24 @@ class Analyzer( Project( projectList.flatMap { case s: Star => s.expand(child.output, resolver) - case Alias(f @ UnresolvedFunction(_, args), name) if containsStar(args) => + case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) => val expandedArgs = args.flatMap { case s: Star => s.expand(child.output, resolver) case o => o :: Nil } - Alias(child = f.copy(children = expandedArgs), name)() :: Nil - case Alias(c @ CreateArray(args), name) if containsStar(args) => + UnresolvedAlias(child = f.copy(children = expandedArgs)) :: Nil + case UnresolvedAlias(c @ CreateArray(args)) if containsStar(args) => val expandedArgs = args.flatMap { case s: Star => s.expand(child.output, resolver) case o => o :: Nil } - Alias(c.copy(children = expandedArgs), name)() :: Nil - case Alias(c @ CreateStruct(args), name) if containsStar(args) => + UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil + case UnresolvedAlias(c @ CreateStruct(args)) if containsStar(args) => val expandedArgs = args.flatMap { case s: Star => s.expand(child.output, resolver) case o => o :: Nil } - Alias(c.copy(children = expandedArgs), name)() :: Nil + UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil case o => o :: Nil }, child) @@ -290,7 +283,7 @@ class Analyzer( val conflictingAttributes = left.outputSet.intersect(right.outputSet) logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j") - val (oldRelation, newRelation) = right.collect { + right.collect { // Handle base relations that might appear more than once. case oldVersion: MultiInstanceRelation if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => @@ -315,38 +308,43 @@ class Analyzer( if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes) .nonEmpty => (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))) - }.headOption.getOrElse { // Only handle first case, others will be fixed on the next pass. - sys.error( - s""" - |Failure when resolving conflicting references in Join: - |$plan - | - |Conflicting attributes: ${conflictingAttributes.mkString(",")} - """.stripMargin) } - - val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) - val newRight = right transformUp { - case r if r == oldRelation => newRelation - } transformUp { - case other => other transformExpressions { - case a: Attribute => attributeRewrites.get(a).getOrElse(a) - } + // Only handle first case, others will be fixed on the next pass. + .headOption match { + case None => + /* + * No result implies that there is a logical plan node that produces new references + * that this rule cannot handle. When that is the case, there must be another rule + * that resolves these conflicts. Otherwise, the analysis will fail. + */ + j + case Some((oldRelation, newRelation)) => + val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) + val newRight = right transformUp { + case r if r == oldRelation => newRelation + } transformUp { + case other => other transformExpressions { + case a: Attribute => attributeRewrites.get(a).getOrElse(a) + } + } + j.copy(right = newRight) } - j.copy(right = newRight) + + // When resolve `SortOrder`s in Sort based on child, don't report errors as + // we still have chance to resolve it based on grandchild + case s @ Sort(ordering, global, child) if child.resolved && !s.resolved => + val newOrdering = resolveSortOrders(ordering, child, throws = false) + Sort(newOrdering, global, child) case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressionsUp { - case u @ UnresolvedAttribute(nameParts) if nameParts.length == 1 && - resolver(nameParts(0), VirtualColumn.groupingIdName) && - q.isInstanceOf[GroupingAnalytics] => - // Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics - q.asInstanceOf[GroupingAnalytics].gid case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = - withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } + withPosition(u) { + q.resolveChildren(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u) + } logDebug(s"Resolving $u to $result") result case UnresolvedExtractValue(child, fieldExpr) if child.resolved => @@ -372,6 +370,31 @@ class Analyzer( exprs.exists(_.collect { case _: Star => true }.nonEmpty) } + private def trimUnresolvedAlias(ne: NamedExpression) = ne match { + case UnresolvedAlias(child) => child + case other => other + } + + private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = { + ordering.map { order => + // Resolve SortOrder in one round. + // If throws == false or the desired attribute doesn't exist + // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one. + // Else, throw exception. + try { + val newOrder = order transformUp { + case u @ UnresolvedAttribute(nameParts) => + plan.resolve(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u) + case UnresolvedExtractValue(child, fieldName) if child.resolved => + ExtractValue(child, fieldName, resolver) + } + newOrder.asInstanceOf[SortOrder] + } catch { + case a: AnalysisException if !throws => order + } + } + } + /** * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT * clause. This rule detects such queries and adds the required attributes to the original @@ -382,13 +405,13 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case s @ Sort(ordering, global, p @ Project(projectList, child)) if !s.resolved && p.resolved => - val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, p, child) + val (newOrdering, missing) = resolveAndFindMissing(ordering, p, child) // If this rule was not a no-op, return the transformed plan, otherwise return the original. if (missing.nonEmpty) { // Add missing attributes and then project them away after the sort. Project(p.output, - Sort(resolvedOrdering, global, + Sort(newOrdering, global, Project(projectList ++ missing, child))) } else { logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}") @@ -396,19 +419,31 @@ class Analyzer( } case s @ Sort(ordering, global, a @ Aggregate(grouping, aggs, child)) if !s.resolved && a.resolved => - val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) // A small hack to create an object that will allow us to resolve any references that // refer to named expressions that are present in the grouping expressions. val groupingRelation = LocalRelation( grouping.collect { case ne: NamedExpression => ne.toAttribute } ) - val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, a, groupingRelation) + // Find sort attributes that are projected away so we can temporarily add them back in. + val (newOrdering, missingAttr) = resolveAndFindMissing(ordering, a, groupingRelation) + + // Find aggregate expressions and evaluate them early, since they can't be evaluated in a + // Sort. + val (withAggsRemoved, aliasedAggregateList) = newOrdering.map { + case aggOrdering if aggOrdering.collect { case a: AggregateExpression => a }.nonEmpty => + val aliased = Alias(aggOrdering.child, "_aggOrdering")() + (aggOrdering.copy(child = aliased.toAttribute), Some(aliased)) + + case other => (other, None) + }.unzip + + val missing = missingAttr ++ aliasedAggregateList.flatten if (missing.nonEmpty) { // Add missing grouping exprs and then project them away after the sort. Project(a.output, - Sort(resolvedOrdering, global, + Sort(withAggsRemoved, global, Aggregate(grouping, aggs ++ missing, child))) } else { s // Nothing we can do here. Return original plan. @@ -416,40 +451,25 @@ class Analyzer( } /** - * Given a child and a grandchild that are present beneath a sort operator, returns - * a resolved sort ordering and a list of attributes that are missing from the child - * but are present in the grandchild. + * Given a child and a grandchild that are present beneath a sort operator, try to resolve + * the sort ordering and returns it with a list of attributes that are missing from the + * child but are present in the grandchild. */ def resolveAndFindMissing( ordering: Seq[SortOrder], child: LogicalPlan, grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { - // Find any attributes that remain unresolved in the sort. - val unresolved: Seq[Seq[String]] = - ordering.flatMap(_.collect { case UnresolvedAttribute(nameParts) => nameParts }) - - // Create a map from name, to resolved attributes, when the desired name can be found - // prior to the projection. - val resolved: Map[Seq[String], NamedExpression] = - unresolved.flatMap(u => grandchild.resolve(u, resolver).map(a => u -> a)).toMap - + val newOrdering = resolveSortOrders(ordering, grandchild, throws = true) // Construct a set that contains all of the attributes that we need to evaluate the // ordering. - val requiredAttributes = AttributeSet(resolved.values) - + val requiredAttributes = AttributeSet(newOrdering.filter(_.resolved)) // Figure out which ones are missing from the projection, so that we can add them and // remove them after the sort. val missingInProject = requiredAttributes -- child.output - - // Now that we have all the attributes we need, reconstruct a resolved ordering. - // It is important to do it here, instead of waiting for the standard resolved as adding - // attributes to the project below can actually introduce ambiquity that was not present - // before. - val resolvedOrdering = ordering.map(_ transform { - case u @ UnresolvedAttribute(name) => resolved.getOrElse(name, u) - }).asInstanceOf[Seq[SortOrder]] - - (resolvedOrdering, missingInProject.toSeq) + // It is important to return the new SortOrders here, instead of waiting for the standard + // resolving process as adding attributes to the project below can actually introduce + // ambiguity that was not present before. + (newOrdering, missingInProject.toSeq) } } @@ -460,8 +480,10 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressions { - case u @ UnresolvedFunction(name, children) if u.childrenResolved => - registry.lookupFunction(name, children) + case u @ UnresolvedFunction(name, children) => + withPosition(u) { + registry.lookupFunction(name, children) + } } } } @@ -492,20 +514,21 @@ class Analyzer( object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _)) - if aggregate.resolved && containsAggregate(havingCondition) => { + if aggregate.resolved && containsAggregate(havingCondition) => + val evaluatedCondition = Alias(havingCondition, "havingCondition")() val aggExprsWithHaving = evaluatedCondition +: originalAggExprs Project(aggregate.output, Filter(evaluatedCondition.toAttribute, aggregate.copy(aggregateExpressions = aggExprsWithHaving))) - } } - protected def containsAggregate(condition: Expression): Boolean = + protected def containsAggregate(condition: Expression): Boolean = { condition .collect { case ae: AggregateExpression => ae } .nonEmpty + } } /** @@ -559,23 +582,13 @@ class Analyzer( /** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */ private object AliasedGenerator { def unapply(e: Expression): Option[(Generator, Seq[String])] = e match { - case Alias(g: Generator, name) - if g.elementTypes.size > 1 && java.util.regex.Pattern.matches("_c[0-9]+", name) => { - // Assume the default name given by parser is "_c[0-9]+", - // TODO in long term, move the naming logic from Parser to Analyzer. - // In projection, Parser gave default name for TGF as does for normal UDF, - // but the TGF probably have multiple output columns/names. - // e.g. SELECT explode(map(key, value)) FROM src; - // Let's simply ignore the default given name for this case. - Some((g, Nil)) - } - case Alias(g: Generator, name) if g.elementTypes.size > 1 => + case Alias(g: Generator, name) if g.resolved && g.elementTypes.size > 1 => // If not given the default names, and the TGF with multiple output columns failAnalysis( s"""Expect multiple names given for ${g.getClass.getName}, |but only single name '${name}' specified""".stripMargin) - case Alias(g: Generator, name) => Some((g, name :: Nil)) - case MultiAlias(g: Generator, names) => Some(g, names) + case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil)) + case MultiAlias(g: Generator, names) if g.resolved => Some(g, names) case _ => None } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 3e240fd55e250..1541491608b24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -17,7 +17,11 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConversions._ import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.EmptyConf @@ -81,18 +85,18 @@ trait Catalog { } class SimpleCatalog(val conf: CatalystConf) extends Catalog { - val tables = new mutable.HashMap[String, LogicalPlan]() + val tables = new ConcurrentHashMap[String, LogicalPlan] override def registerTable( tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { val tableIdent = processTableIdentifier(tableIdentifier) - tables += ((getDbTableName(tableIdent), plan)) + tables.put(getDbTableName(tableIdent), plan) } override def unregisterTable(tableIdentifier: Seq[String]): Unit = { val tableIdent = processTableIdentifier(tableIdentifier) - tables -= getDbTableName(tableIdent) + tables.remove(getDbTableName(tableIdent)) } override def unregisterAllTables(): Unit = { @@ -101,10 +105,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { override def tableExists(tableIdentifier: Seq[String]): Boolean = { val tableIdent = processTableIdentifier(tableIdentifier) - tables.get(getDbTableName(tableIdent)) match { - case Some(_) => true - case None => false - } + tables.containsKey(getDbTableName(tableIdent)) } override def lookupRelation( @@ -112,7 +113,10 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { alias: Option[String] = None): LogicalPlan = { val tableIdent = processTableIdentifier(tableIdentifier) val tableFullName = getDbTableName(tableIdent) - val table = tables.getOrElse(tableFullName, sys.error(s"Table Not Found: $tableFullName")) + val table = tables.get(tableFullName) + if (table == null) { + sys.error(s"Table Not Found: $tableFullName") + } val tableWithQualifiers = Subquery(tableIdent.last, table) // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are @@ -121,9 +125,11 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { } override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { - tables.map { - case (name, _) => (name, true) - }.toSeq + val result = ArrayBuffer.empty[(String, Boolean)] + for (name <- tables.keySet()) { + result += ((name, true)) + } + result } override def refreshTable(databaseName: String, tableName: String): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index c0695ae369421..a069b4710f38c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -48,17 +48,10 @@ trait CheckAnalysis { // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. plan.foreachUp { + case operator: LogicalPlan => operator transformExpressionsUp { case a: Attribute if !a.resolved => - if (operator.childrenResolved) { - a match { - case UnresolvedAttribute(nameParts) => - // Throw errors for specific problems with get field. - operator.resolveChildren(nameParts, resolver, throwErrors = true) - } - } - val from = operator.inputSet.map(_.name).mkString(", ") a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") @@ -103,14 +96,7 @@ trait CheckAnalysis { case e => e.children.foreach(checkValidAggregateExpression) } - val cleaned = aggregateExprs.map(_.transform { - // Should trim aliases around `GetField`s. These aliases are introduced while - // resolving struct field accesses, because `GetField` is not a `NamedExpression`. - // (Should we just turn `GetField` into a `NamedExpression`?) - case Alias(g, _) => g - }) - - cleaned.foreach(checkValidAggregateExpression) + aggregateExprs.foreach(checkValidAggregateExpression) case _ => // Fallbacks to the following checks } @@ -136,6 +122,17 @@ trait CheckAnalysis { case _ => // Analysis successful! } + + // Special handling for cases when self-join introduce duplicate expression ids. + case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => + val conflictingAttributes = left.outputSet.intersect(right.outputSet) + failAnalysis( + s""" + |Failure when resolving conflicting references in Join: + |$plan + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + |""".stripMargin) + } extendedCheckRules.foreach(_(plan)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 0849faa9bfa7b..b17457d3094c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -17,44 +17,50 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.catalyst.expressions.Expression -import scala.collection.mutable +import scala.reflect.ClassTag +import scala.util.{Failure, Success, Try} + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.StringKeyHashMap + /** A catalog for looking up user defined functions, used by an [[Analyzer]]. */ trait FunctionRegistry { - type FunctionBuilder = Seq[Expression] => Expression def registerFunction(name: String, builder: FunctionBuilder): Unit + @throws[AnalysisException]("If function does not exist") def lookupFunction(name: String, children: Seq[Expression]): Expression - - def conf: CatalystConf } -trait OverrideFunctionRegistry extends FunctionRegistry { +class OverrideFunctionRegistry(underlying: FunctionRegistry) extends FunctionRegistry { - val functionBuilders = StringKeyHashMap[FunctionBuilder](conf.caseSensitiveAnalysis) + private val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive = false) override def registerFunction(name: String, builder: FunctionBuilder): Unit = { functionBuilders.put(name, builder) } - abstract override def lookupFunction(name: String, children: Seq[Expression]): Expression = { - functionBuilders.get(name).map(_(children)).getOrElse(super.lookupFunction(name, children)) + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + functionBuilders.get(name).map(_(children)).getOrElse(underlying.lookupFunction(name, children)) } } -class SimpleFunctionRegistry(val conf: CatalystConf) extends FunctionRegistry { +class SimpleFunctionRegistry extends FunctionRegistry { - val functionBuilders = StringKeyHashMap[FunctionBuilder](conf.caseSensitiveAnalysis) + private val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive = false) override def registerFunction(name: String, builder: FunctionBuilder): Unit = { functionBuilders.put(name, builder) } override def lookupFunction(name: String, children: Seq[Expression]): Expression = { - functionBuilders(name)(children) + val func = functionBuilders.get(name).getOrElse { + throw new AnalysisException(s"undefined function $name") + } + func(children) } } @@ -70,30 +76,117 @@ object EmptyFunctionRegistry extends FunctionRegistry { override def lookupFunction(name: String, children: Seq[Expression]): Expression = { throw new UnsupportedOperationException } - - override def conf: CatalystConf = throw new UnsupportedOperationException } -/** - * Build a map with String type of key, and it also supports either key case - * sensitive or insensitive. - * TODO move this into util folder? - */ -object StringKeyHashMap { - def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = caseSensitive match { - case false => new StringKeyHashMap[T](_.toLowerCase) - case true => new StringKeyHashMap[T](identity) - } -} -class StringKeyHashMap[T](normalizer: (String) => String) { - private val base = new collection.mutable.HashMap[String, T]() +object FunctionRegistry { - def apply(key: String): T = base(normalizer(key)) + type FunctionBuilder = Seq[Expression] => Expression - def get(key: String): Option[T] = base.get(normalizer(key)) - def put(key: String, value: T): Option[T] = base.put(normalizer(key), value) - def remove(key: String): Option[T] = base.remove(normalizer(key)) - def iterator: Iterator[(String, T)] = base.toIterator -} + val expressions: Map[String, FunctionBuilder] = Map( + // misc non-aggregate functions + expression[Abs]("abs"), + expression[CreateArray]("array"), + expression[Coalesce]("coalesce"), + expression[Explode]("explode"), + expression[If]("if"), + expression[IsNull]("isnull"), + expression[IsNotNull]("isnotnull"), + expression[Coalesce]("nvl"), + expression[Rand]("rand"), + expression[Randn]("randn"), + expression[CreateStruct]("struct"), + expression[Sqrt]("sqrt"), + + // math functions + expression[Acos]("acos"), + expression[Asin]("asin"), + expression[Atan]("atan"), + expression[Atan2]("atan2"), + expression[Bin]("bin"), + expression[Cbrt]("cbrt"), + expression[Ceil]("ceil"), + expression[Ceil]("ceiling"), + expression[Cos]("cos"), + expression[EulerNumber]("e"), + expression[Exp]("exp"), + expression[Expm1]("expm1"), + expression[Floor]("floor"), + expression[Hypot]("hypot"), + expression[Hex]("hex"), + expression[Logarithm]("log"), + expression[Log]("ln"), + expression[Log10]("log10"), + expression[Log1p]("log1p"), + expression[UnaryMinus]("negative"), + expression[Pi]("pi"), + expression[Log2]("log2"), + expression[Pow]("pow"), + expression[Pow]("power"), + expression[UnaryPositive]("positive"), + expression[Rint]("rint"), + expression[Signum]("sign"), + expression[Signum]("signum"), + expression[Sin]("sin"), + expression[Sinh]("sinh"), + expression[Tan]("tan"), + expression[Tanh]("tanh"), + expression[ToDegrees]("degrees"), + expression[ToRadians]("radians"), + + // misc functions + expression[Md5]("md5"), + expression[Sha2]("sha2"), + expression[Sha1]("sha1"), + expression[Sha1]("sha"), + + // aggregate functions + expression[Average]("avg"), + expression[Count]("count"), + expression[First]("first"), + expression[Last]("last"), + expression[Max]("max"), + expression[Min]("min"), + expression[Sum]("sum"), + + // string functions + expression[Lower]("lcase"), + expression[Lower]("lower"), + expression[StringLength]("length"), + expression[Substring]("substr"), + expression[Substring]("substring"), + expression[Upper]("ucase"), + expression[Upper]("upper") + ) + + val builtin: FunctionRegistry = { + val fr = new SimpleFunctionRegistry + expressions.foreach { case (name, builder) => fr.registerFunction(name, builder) } + fr + } + /** See usage above. */ + private def expression[T <: Expression](name: String) + (implicit tag: ClassTag[T]): (String, FunctionBuilder) = { + + // See if we can find a constructor that accepts Seq[Expression] + val varargCtor = Try(tag.runtimeClass.getDeclaredConstructor(classOf[Seq[_]])).toOption + val builder = (expressions: Seq[Expression]) => { + if (varargCtor.isDefined) { + // If there is an apply method that accepts Seq[Expression], use that one. + varargCtor.get.newInstance(expressions).asInstanceOf[Expression] + } else { + // Otherwise, find an ctor method that matches the number of arguments, and use that. + val params = Seq.fill(expressions.size)(classOf[Expression]) + val f = Try(tag.runtimeClass.getDeclaredConstructor(params : _*)) match { + case Success(e) => + e + case Failure(e) => + throw new AnalysisException(s"Invalid number of arguments for function $name") + } + f.newInstance(expressions : _*).asInstanceOf[Expression] + } + } + (name, builder) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index a42ffce0d26fa..c3d68197d64ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -58,6 +58,28 @@ object HiveTypeCoercion { case _ => None } + /** Similar to [[findTightestCommonType]], but can promote all the way to StringType. */ + private def findTightestCommonTypeToString(left: DataType, right: DataType): Option[DataType] = { + findTightestCommonTypeOfTwo(left, right).orElse((left, right) match { + case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType) + case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType) + case _ => None + }) + } + + /** + * Similar to [[findTightestCommonType]], if can not find the TightestCommonType, try to use + * [[findTightestCommonTypeToString]] to find the TightestCommonType. + */ + private def findTightestCommonTypeAndPromoteToString(types: Seq[DataType]): Option[DataType] = { + types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { + case None => None + case Some(d) => + findTightestCommonTypeOfTwo(d, c).orElse(findTightestCommonTypeToString(d, c)) + }) + } + + /** * Find the tightest common type of a set of types by continuously applying * `findTightestCommonTypeOfTwo` on these types. @@ -91,9 +113,10 @@ trait HiveTypeCoercion { StringToIntegralCasts :: FunctionArgumentConversion :: CaseWhenCoercion :: + IfCoercion :: Division :: PropagateTypes :: - ExpectedInputConversion :: + AddCastForAutoCastInputTypes :: Nil /** @@ -254,15 +277,26 @@ trait HiveTypeCoercion { case a @ BinaryArithmetic(left, right @ StringType()) => a.makeCopy(Array(left, Cast(right, DoubleType))) - // we should cast all timestamp/date/string compare into string compare + // For equality between string and timestamp we cast the string to a timestamp + // so that things like rounding of subsecond precision does not affect the comparison. + case p @ Equality(left @ StringType(), right @ TimestampType()) => + p.makeCopy(Array(Cast(left, TimestampType), right)) + case p @ Equality(left @ TimestampType(), right @ StringType()) => + p.makeCopy(Array(left, Cast(right, TimestampType))) + + // We should cast all relative timestamp/date/string comparison into string comparisions + // This behaves as a user would expect because timestamp strings sort lexicographically. + // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true case p @ BinaryComparison(left @ StringType(), right @ DateType()) => p.makeCopy(Array(left, Cast(right, StringType))) case p @ BinaryComparison(left @ DateType(), right @ StringType()) => p.makeCopy(Array(Cast(left, StringType), right)) case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) => - p.makeCopy(Array(Cast(left, TimestampType), right)) + p.makeCopy(Array(left, Cast(right, StringType))) case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) => - p.makeCopy(Array(left, Cast(right, TimestampType))) + p.makeCopy(Array(Cast(left, StringType), right)) + + // Comparisons between dates and timestamps. case p @ BinaryComparison(left @ TimestampType(), right @ DateType()) => p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) => @@ -283,8 +317,8 @@ trait HiveTypeCoercion { i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) + case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) - case Sqrt(e @ StringType()) => Sqrt(Cast(e, DoubleType)) } } @@ -445,12 +479,6 @@ trait HiveTypeCoercion { e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => val resultType = DecimalType(max(p1, p2), max(s1, s2)) b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType))) - case b @ BinaryComparison(e1 @ DecimalType.Fixed(_, _), e2) - if e2.dataType == DecimalType.Unlimited => - b.makeCopy(Array(Cast(e1, DecimalType.Unlimited), e2)) - case b @ BinaryComparison(e1, e2 @ DecimalType.Fixed(_, _)) - if e1.dataType == DecimalType.Unlimited => - b.makeCopy(Array(e1, Cast(e2, DecimalType.Unlimited))) // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles @@ -563,11 +591,12 @@ trait HiveTypeCoercion { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case a @ CreateArray(children) if !a.resolved => - val commonType = a.childTypes.reduce( - (a, b) => findTightestCommonTypeOfTwo(a, b).getOrElse(StringType)) - CreateArray( - children.map(c => if (c.dataType == commonType) c else Cast(c, commonType))) + case a @ CreateArray(children) if children.map(_.dataType).distinct.size > 1 => + val types = children.map(_.dataType) + findTightestCommonTypeAndPromoteToString(types) match { + case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) + case None => a + } // Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows. case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest. @@ -593,12 +622,11 @@ trait HiveTypeCoercion { // Coalesce should return the first non-null value, which could be any column // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. - case Coalesce(es) if es.map(_.dataType).distinct.size > 1 => + case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 => val types = es.map(_.dataType) - findTightestCommonType(types) match { + findTightestCommonTypeAndPromoteToString(types) match { case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) - case None => - sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}") + case None => c } } } @@ -630,7 +658,7 @@ trait HiveTypeCoercion { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual => logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}") - val maybeCommonType = findTightestCommonType(c.valueTypes) + val maybeCommonType = findTightestCommonTypeAndPromoteToString(c.valueTypes) maybeCommonType.map { commonType => val castedBranches = c.branches.grouped(2).map { case Seq(when, value) if value.dataType != commonType => @@ -646,11 +674,12 @@ trait HiveTypeCoercion { }.getOrElse(c) case c: CaseKeyWhen if c.childrenResolved && !c.resolved => - val maybeCommonType = findTightestCommonType((c.key +: c.whenList).map(_.dataType)) + val maybeCommonType = + findTightestCommonTypeAndPromoteToString((c.key +: c.whenList).map(_.dataType)) maybeCommonType.map { commonType => val castedBranches = c.branches.grouped(2).map { - case Seq(when, then) if when.dataType != commonType => - Seq(Cast(when, commonType), then) + case Seq(whenExpr, thenExpr) if whenExpr.dataType != commonType => + Seq(Cast(whenExpr, commonType), thenExpr) case other => other }.reduce(_ ++ _) CaseKeyWhen(Cast(c.key, commonType), castedBranches) @@ -658,17 +687,37 @@ trait HiveTypeCoercion { } } + /** + * Coerces the type of different branches of If statement to a common type. + */ + object IfCoercion extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Find tightest common type for If, if the true value and false value have different types. + case i @ If(pred, left, right) if left.dataType != right.dataType => + findTightestCommonTypeToString(left.dataType, right.dataType).map { widestType => + val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) + val newRight = if (right.dataType == widestType) right else Cast(right, widestType) + If(pred, newLeft, newRight) + }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. + + // Convert If(null literal, _, _) into boolean type. + // In the optimizer, we should short-circuit this directly into false value. + case If(pred, left, right) if pred.dataType == NullType => + If(Literal.create(null, BooleanType), left, right) + } + } + /** * Casts types according to the expected input types for Expressions that have the trait - * `ExpectsInputTypes`. + * [[AutoCastInputTypes]]. */ - object ExpectedInputConversion extends Rule[LogicalPlan] { + object AddCastForAutoCastInputTypes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case e: ExpectsInputTypes if e.children.map(_.dataType) != e.expectedChildTypes => + case e: AutoCastInputTypes if e.children.map(_.dataType) != e.expectedChildTypes => val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map { case (child, actual, expected) => if (actual == expected) child else Cast(child, expected) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index bbb150c1e83c7..ae3adbab05108 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -67,7 +67,7 @@ case class UnresolvedAttribute(nameParts: Seq[String]) override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) // Unresolved attributes are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): Any = + override def eval(input: InternalRow = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"'$name" @@ -85,7 +85,7 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E override lazy val resolved = false // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): Any = + override def eval(input: InternalRow = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"'$name(${children.mkString(",")})" @@ -107,7 +107,7 @@ trait Star extends NamedExpression with trees.LeafNode[Expression] { override lazy val resolved = false // Star gets expanded at runtime so we never evaluate a Star. - override def eval(input: Row = null): Any = + override def eval(input: InternalRow = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] @@ -166,7 +166,7 @@ case class MultiAlias(child: Expression, names: Seq[String]) override lazy val resolved = false - override def eval(input: Row = null): Any = + override def eval(input: InternalRow = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"$child AS $names" @@ -200,8 +200,27 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression) override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false - override def eval(input: Row = null): Any = + override def eval(input: InternalRow = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"$child[$extraction]" } + +/** + * Holds the expression that has yet to be aliased. + */ +case class UnresolvedAlias(child: Expression) extends NamedExpression + with trees.UnaryNode[Expression] { + + override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") + override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def name: String = throw new UnresolvedException(this, "name") + + override lazy val resolved = false + + override def eval(input: InternalRow = null): Any = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index fcadf9595e768..5db2fcfcb267b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.Logging import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} -import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.types._ /** * A bound reference points to a specific slot in the input tuple, allowing the actual value @@ -33,7 +33,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def toString: String = s"input[$ordinal]" - override def eval(input: Row): Any = input(ordinal) + override def eval(input: InternalRow): Any = input(ordinal) override def name: String = s"i[$ordinal]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 18102d1acb5b3..d69d490ad666a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -17,18 +17,27 @@ package org.apache.spark.sql.catalyst.expressions +import java.math.{BigDecimal => JavaBigDecimal} import java.sql.{Date, Timestamp} -import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** Cast the child expression to the target data type. */ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { - override lazy val resolved = childrenResolved && resolve(child.dataType, dataType) + override def checkInputDataTypes(): TypeCheckResult = { + if (resolve(child.dataType, dataType)) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"cannot cast ${child.dataType} to $dataType") + } + } override def foldable: Boolean = child.foldable @@ -111,10 +120,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // UDFToString private[this] def castToString(from: DataType): Any => Any = from match { - case BinaryType => buildCast[Array[Byte]](_, UTF8String(_)) - case DateType => buildCast[Int](_, d => UTF8String(DateUtils.toString(d))) - case TimestampType => buildCast[Timestamp](_, t => UTF8String(timestampToString(t))) - case _ => buildCast[Any](_, o => UTF8String(o.toString)) + case BinaryType => buildCast[Array[Byte]](_, UTF8String.fromBytes) + case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d))) + case TimestampType => buildCast[Long](_, + t => UTF8String.fromString(DateTimeUtils.timestampToString(t))) + case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } // BinaryConverter @@ -127,7 +137,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case StringType => buildCast[UTF8String](_, _.length() != 0) case TimestampType => - buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0) + buildCast[Long](_, t => t != 0) case DateType => // Hive would return null when cast from date to boolean buildCast[Int](_, d => null) @@ -140,7 +150,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case ByteType => buildCast[Byte](_, _ != 0) case DecimalType() => - buildCast[Decimal](_, _ != 0) + buildCast[Decimal](_, _ != Decimal(0)) case DoubleType => buildCast[Double](_, _ != 0) case FloatType => @@ -158,20 +168,21 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w if (periodIdx != -1 && n.length() - periodIdx > 9) { n = n.substring(0, periodIdx + 10) } - try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null } + try DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(n)) + catch { case _: java.lang.IllegalArgumentException => null } }) case BooleanType => - buildCast[Boolean](_, b => new Timestamp(if (b) 1 else 0)) + buildCast[Boolean](_, b => (if (b) 1L else 0)) case LongType => - buildCast[Long](_, l => new Timestamp(l)) + buildCast[Long](_, l => longToTimestamp(l)) case IntegerType => - buildCast[Int](_, i => new Timestamp(i)) + buildCast[Int](_, i => longToTimestamp(i.toLong)) case ShortType => - buildCast[Short](_, s => new Timestamp(s)) + buildCast[Short](_, s => longToTimestamp(s.toLong)) case ByteType => - buildCast[Byte](_, b => new Timestamp(b)) + buildCast[Byte](_, b => longToTimestamp(b.toLong)) case DateType => - buildCast[Int](_, d => new Timestamp(DateUtils.toJavaDate(d).getTime)) + buildCast[Int](_, d => DateTimeUtils.daysToMillis(d) * 10000) // TimestampWritable.decimalToTimestamp case DecimalType() => buildCast[Decimal](_, d => decimalToTimestamp(d)) @@ -191,50 +202,30 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w }) } - private[this] def decimalToTimestamp(d: Decimal) = { - val seconds = Math.floor(d.toDouble).toLong - val bd = (d.toBigDecimal - seconds) * 1000000000 - val nanos = bd.intValue() - - val millis = seconds * 1000 - val t = new Timestamp(millis) - - // remaining fractional portion as nanos - t.setNanos(nanos) - t + private[this] def decimalToTimestamp(d: Decimal): Long = { + (d.toBigDecimal * 10000000L).longValue() } - // Timestamp to long, converting milliseconds to seconds - private[this] def timestampToLong(ts: Timestamp) = Math.floor(ts.getTime / 1000.0).toLong - - private[this] def timestampToDouble(ts: Timestamp) = { - // First part is the seconds since the beginning of time, followed by nanosecs. - Math.floor(ts.getTime / 1000.0).toLong + ts.getNanos.toDouble / 1000000000 - } - - // Converts Timestamp to string according to Hive TimestampWritable convention - private[this] def timestampToString(ts: Timestamp): String = { - val timestampString = ts.toString - val formatted = Cast.threadLocalTimestampFormat.get.format(ts) - - if (timestampString.length > 19 && timestampString.substring(19) != ".0") { - formatted + timestampString.substring(19) - } else { - formatted - } + // converting milliseconds to 100ns + private[this] def longToTimestamp(t: Long): Long = t * 10000L + // converting 100ns to seconds + private[this] def timestampToLong(ts: Long): Long = math.floor(ts.toDouble / 10000000L).toLong + // converting 100ns to seconds in double + private[this] def timestampToDouble(ts: Long): Double = { + ts / 10000000.0 } // DateConverter private[this] def castToDate(from: DataType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => - try DateUtils.fromJavaDate(Date.valueOf(s.toString)) + try DateTimeUtils.fromJavaDate(Date.valueOf(s.toString)) catch { case _: java.lang.IllegalArgumentException => null } ) case TimestampType => // throw valid precision more than seconds, according to Hive. // Timestamp.nanos is in 0 to 999,999,999, no more than a second. - buildCast[Timestamp](_, t => DateUtils.millisToDays(t.getTime)) + buildCast[Long](_, t => DateTimeUtils.millisToDays(t / 10000L)) // Hive throws this exception as a Semantic Exception // It is never possible to compare result when hive return with exception, // so we can return null @@ -253,7 +244,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToLong(t)) + buildCast[Long](_, t => timestampToLong(t)) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) } @@ -269,7 +260,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToLong(t).toInt) + buildCast[Long](_, t => timestampToLong(t).toInt) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) } @@ -285,7 +276,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToLong(t).toShort) + buildCast[Long](_, t => timestampToLong(t).toShort) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort } @@ -301,7 +292,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToLong(t).toByte) + buildCast[Long](_, t => timestampToLong(t).toByte) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte } @@ -324,7 +315,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => try { - changePrecision(Decimal(s.toString.toDouble), target) + changePrecision(Decimal(new JavaBigDecimal(s.toString)), target) } catch { case _: NumberFormatException => null }) @@ -334,7 +325,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Int](_, d => null) // date can't cast to decimal in Hive case TimestampType => // Note that we lose precision here. - buildCast[Timestamp](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) + buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) case DecimalType() => b => changePrecision(b.asInstanceOf[Decimal].clone(), target) case LongType => @@ -358,7 +349,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToDouble(t)) + buildCast[Long](_, t => timestampToDouble(t)) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) } @@ -374,7 +365,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToDouble(t).toFloat) + buildCast[Long](_, t => timestampToDouble(t).toFloat) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) } @@ -398,7 +389,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } // TODO: Could be faster? val newRow = new GenericMutableRow(from.fields.size) - buildCast[Row](_, row => { + buildCast[InternalRow](_, row => { var i = 0 while (i < row.length) { val v = row(i) @@ -430,7 +421,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] lazy val cast: Any => Any = cast(child.dataType, dataType) - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val evaluated = child.eval(input) if (evaluated == null) null else cast(evaluated) } @@ -441,17 +432,17 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (BinaryType, StringType) => defineCodeGen (ctx, ev, c => - s"new ${ctx.stringType}().set($c)") + s"${ctx.stringType}.fromBytes($c)") case (DateType, StringType) => defineCodeGen(ctx, ev, c => - s"""new ${ctx.stringType}().set( - org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""") - // Special handling required for timestamps in hive test cases since the toString function - // does not match the expected output. + s"""${ctx.stringType}.fromString( + org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""") case (TimestampType, StringType) => - super.genCode(ctx, ev) + defineCodeGen(ctx, ev, c => + s"""${ctx.stringType}.fromString( + org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""") case (_, StringType) => - defineCodeGen(ctx, ev, c => s"new ${ctx.stringType}().set(String.valueOf($c))") + defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))") // fallback for DecimalType, this must be before other numeric types case (_, dt: DecimalType) => @@ -460,14 +451,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (BooleanType, dt: NumericType) => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)") case (dt: DecimalType, BooleanType) => - defineCodeGen(ctx, ev, c => s"$c.isZero()") + defineCodeGen(ctx, ev, c => s"!$c.isZero()") case (dt: NumericType, BooleanType) => defineCodeGen(ctx, ev, c => s"$c != 0") - - case (_: DecimalType, IntegerType) => - defineCodeGen(ctx, ev, c => s"($c).toInt()") case (_: DecimalType, dt: NumericType) => - defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()") + defineCodeGen(ctx, ev, c => s"($c).to${ctx.primitiveTypeName(dt)}()") case (_: NumericType, dt: NumericType) => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)") @@ -476,19 +464,3 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } } } - -object Cast { - // `SimpleDateFormat` is not thread-safe. - private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { - override def initialValue(): SimpleDateFormat = { - new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") - } - } - - // `SimpleDateFormat` is not thread-safe. - private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] { - override def initialValue(): SimpleDateFormat = { - new SimpleDateFormat("yyyy-MM-dd") - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index a9a9c0cfb7027..e5dc7b9b5c884 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -18,11 +18,19 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ + +/** + * If an expression wants to be exposed in the function registry (so users can call it with + * "name(arguments...)", the concrete implementation must be a case class whose constructor + * arguments are all Expressions types. + * + * See [[Substring]] for an example. + */ abstract class Expression extends TreeNode[Expression] { self: Product => @@ -50,7 +58,15 @@ abstract class Expression extends TreeNode[Expression] { def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator)) /** Returns the result of evaluating this expression on a given input Row */ - def eval(input: Row = null): Any + def eval(input: InternalRow = null): Any + + /** + * Return true if this expression is thread-safe, which means it could be used by multiple + * threads in the same time. + * + * An expression that is not thread-safe can not be cached and re-used, especially for codegen. + */ + def isThreadSafe: Boolean = true /** * Returns an [[GeneratedExpressionCode]], which contains Java source code that @@ -60,6 +76,9 @@ abstract class Expression extends TreeNode[Expression] { * @return [[GeneratedExpressionCode]] */ def gen(ctx: CodeGenContext): GeneratedExpressionCode = { + if (!isThreadSafe) { + throw new Exception(s"$this is not thread-safe, can not be used in codegen") + } val isNull = ctx.freshName("isNull") val primitive = ctx.freshName("primitive") val ve = GeneratedExpressionCode("", isNull, primitive) @@ -127,20 +146,23 @@ abstract class Expression extends TreeNode[Expression] { * cosmetically (i.e. capitalization of names in attributes may be different). */ def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && { + def checkSemantic(elements1: Seq[Any], elements2: Seq[Any]): Boolean = { + elements1.length == elements2.length && elements1.zip(elements2).forall { + case (e1: Expression, e2: Expression) => e1 semanticEquals e2 + case (Some(e1: Expression), Some(e2: Expression)) => e1 semanticEquals e2 + case (t1: Traversable[_], t2: Traversable[_]) => checkSemantic(t1.toSeq, t2.toSeq) + case (i1, i2) => i1 == i2 + } + } val elements1 = this.productIterator.toSeq val elements2 = other.asInstanceOf[Product].productIterator.toSeq - elements1.length == elements2.length && elements1.zip(elements2).forall { - case (e1: Expression, e2: Expression) => e1 semanticEquals e2 - case (i1, i2) => i1 == i2 - } + checkSemantic(elements1, elements2) } /** * Checks the input data types, returns `TypeCheckResult.success` if it's valid, * or returns a `TypeCheckResult` with an error message if invalid. - * Note: it's not valid to call this method until `childrenResolved == true` - * TODO: we should remove the default implementation and implement it for all - * expressions with proper error message. + * Note: it's not valid to call this method until `childrenResolved == true`. */ def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess } @@ -156,6 +178,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def toString: String = s"($left $symbol $right)" + override def isThreadSafe: Boolean = left.isThreadSafe && right.isThreadSafe /** * Short hand for generating binary evaluation code, which depends on two sub-evaluations of * the same type. If either of the sub-expressions is null, the result of this computation @@ -203,6 +226,10 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def isThreadSafe: Boolean = child.isThreadSafe + /** * Called by unary expressions to generate a code block that returns null if its parent returns * null, and if not not null, use `f` to generate the expression. @@ -230,23 +257,11 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio } } -// TODO Semantically we probably not need GroupExpression -// All we need is holding the Seq[Expression], and ONLY used in doing the -// expressions transformation correctly. Probably will be removed since it's -// not like a real expressions. -case class GroupExpression(children: Seq[Expression]) extends Expression { - self: Product => - override def eval(input: Row): Any = throw new UnsupportedOperationException - override def nullable: Boolean = false - override def foldable: Boolean = false - override def dataType: DataType = throw new UnsupportedOperationException -} - /** * Expressions that require a specific `DataType` as input should implement this trait * so that the proper type conversions can be performed in the analyzer. */ -trait ExpectsInputTypes { +trait AutoCastInputTypes { self: Expression => def expectedChildTypes: Seq[DataType] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala index a1e0819e8a433..4d7c95ffd1850 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala @@ -41,16 +41,22 @@ object ExtractValue { resolver: Resolver): ExtractValue = { (child.dataType, extraction) match { - case (StructType(fields), Literal(fieldName, StringType)) => - val ordinal = findField(fields, fieldName.toString, resolver) - GetStructField(child, fields(ordinal), ordinal) - case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) => - val ordinal = findField(fields, fieldName.toString, resolver) - GetArrayStructFields(child, fields(ordinal), ordinal, containsNull) + case (StructType(fields), NonNullLiteral(v, StringType)) => + val fieldName = v.toString + val ordinal = findField(fields, fieldName, resolver) + GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal) + + case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) => + val fieldName = v.toString + val ordinal = findField(fields, fieldName, resolver) + GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, containsNull) + case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] => GetArrayItem(child, extraction) + case (_: MapType, _) => GetMapValue(child, extraction) + case (otherType, _) => val errorMsg = otherType match { case StructType(_) | ArrayType(StructType(_), _) => @@ -90,23 +96,33 @@ object ExtractValue { } } +/** + * A common interface of all kinds of extract value expressions. + * Note: concrete extract value expressions are created only by `ExtractValue.apply`, + * we don't need to do type check for them. + */ trait ExtractValue extends UnaryExpression { self: Product => } +abstract class ExtractValueWithStruct extends ExtractValue { + self: Product => + + def field: StructField + override def toString: String = s"$child.${field.name}" +} + /** * Returns the value of fields in the Struct `child`. */ case class GetStructField(child: Expression, field: StructField, ordinal: Int) - extends ExtractValue { + extends ExtractValueWithStruct { override def dataType: DataType = field.dataType override def nullable: Boolean = child.nullable || field.nullable - override def foldable: Boolean = child.foldable - override def toString: String = s"$child.${field.name}" - override def eval(input: Row): Any = { - val baseValue = child.eval(input).asInstanceOf[Row] + override def eval(input: InternalRow): Any = { + val baseValue = child.eval(input).asInstanceOf[InternalRow] if (baseValue == null) null else baseValue(ordinal) } } @@ -118,15 +134,12 @@ case class GetArrayStructFields( child: Expression, field: StructField, ordinal: Int, - containsNull: Boolean) extends ExtractValue { + containsNull: Boolean) extends ExtractValueWithStruct { override def dataType: DataType = ArrayType(field.dataType, containsNull) - override def nullable: Boolean = child.nullable - override def foldable: Boolean = child.foldable - override def toString: String = s"$child.${field.name}" - override def eval(input: Row): Any = { - val baseValue = child.eval(input).asInstanceOf[Seq[Row]] + override def eval(input: InternalRow): Any = { + val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]] if (baseValue == null) null else { baseValue.map { row => if (row == null) null else row(ordinal) @@ -146,7 +159,7 @@ abstract class ExtractValueWithOrdinal extends ExtractValue { override def toString: String = s"$child[$ordinal]" override def children: Seq[Expression] = child :: ordinal :: Nil - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val value = child.eval(input) if (value == null) { null @@ -171,14 +184,11 @@ case class GetArrayItem(child: Expression, ordinal: Expression) override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType - override lazy val resolved = childrenResolved && - child.dataType.isInstanceOf[ArrayType] && ordinal.dataType.isInstanceOf[IntegralType] - protected def evalNotNull(value: Any, ordinal: Any) = { // TODO: consider using Array[_] for ArrayType child to avoid // boxing of primitives val baseValue = value.asInstanceOf[Seq[_]] - val index = ordinal.asInstanceOf[Int] + val index = ordinal.asInstanceOf[Number].intValue() if (index >= baseValue.size || index < 0) { null } else { @@ -195,8 +205,6 @@ case class GetMapValue(child: Expression, ordinal: Expression) override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType - override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[MapType] - protected def evalNotNull(value: Any, ordinal: Any) = { val baseValue = value.asInstanceOf[Map[Any, _]] baseValue.get(ordinal).orNull diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 8cae548279eb1..fcfe83ceb863a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions - /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. * @param expressions a sequence of expressions that determine the value of each column of the @@ -30,14 +29,14 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { // null check is required for when Kryo invokes the no-arg constructor. protected val exprArray = if (expressions != null) expressions.toArray else null - def apply(input: Row): Row = { + def apply(input: InternalRow): InternalRow = { val outputArray = new Array[Any](exprArray.length) var i = 0 while (i < exprArray.length) { outputArray(i) = exprArray(i).eval(input) i += 1 } - new GenericRow(outputArray) + new GenericInternalRow(outputArray) } override def toString: String = s"Row => [${exprArray.mkString(",")}]" @@ -55,14 +54,14 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu private[this] val exprArray = expressions.toArray private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.size) - def currentValue: Row = mutableRow + def currentValue: InternalRow = mutableRow override def target(row: MutableRow): MutableProjection = { mutableRow = row this } - override def apply(input: Row): Row = { + override def apply(input: InternalRow): InternalRow = { var i = 0 while (i < exprArray.length) { mutableRow(i) = exprArray(i).eval(input) @@ -76,31 +75,31 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu * A mutable wrapper that makes two rows appear as a single concatenated row. Designed to * be instantiated once per thread and reused. */ -class JoinedRow extends Row { - private[this] var row1: Row = _ - private[this] var row2: Row = _ +class JoinedRow extends InternalRow { + private[this] var row1: InternalRow = _ + private[this] var row2: InternalRow = _ - def this(left: Row, right: Row) = { + def this(left: InternalRow, right: InternalRow) = { this() row1 = left row2 = right } /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: Row, r2: Row): Row = { + def apply(r1: InternalRow, r2: InternalRow): InternalRow = { row1 = r1 row2 = r2 this } /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: Row): Row = { + def withLeft(newLeft: InternalRow): InternalRow = { row1 = newLeft this } /** Updates this JoinedRow by updating its right base row. Returns itself. */ - def withRight(newRight: Row): Row = { + def withRight(newRight: InternalRow): InternalRow = { row2 = newRight this } @@ -136,13 +135,7 @@ class JoinedRow extends Row { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - - override def copy(): Row = { + override def copy(): InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -150,7 +143,7 @@ class JoinedRow extends Row { copiedValues(i) = apply(i) i += 1 } - new GenericRow(copiedValues) + new GenericInternalRow(copiedValues) } override def toString: String = { @@ -176,31 +169,31 @@ class JoinedRow extends Row { * Row will be referenced, increasing the opportunity for the JIT to play tricks. This sounds * crazy but in benchmarks it had noticeable effects. */ -class JoinedRow2 extends Row { - private[this] var row1: Row = _ - private[this] var row2: Row = _ +class JoinedRow2 extends InternalRow { + private[this] var row1: InternalRow = _ + private[this] var row2: InternalRow = _ - def this(left: Row, right: Row) = { + def this(left: InternalRow, right: InternalRow) = { this() row1 = left row2 = right } /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: Row, r2: Row): Row = { + def apply(r1: InternalRow, r2: InternalRow): InternalRow = { row1 = r1 row2 = r2 this } /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: Row): Row = { + def withLeft(newLeft: InternalRow): InternalRow = { row1 = newLeft this } /** Updates this JoinedRow by updating its right base row. Returns itself. */ - def withRight(newRight: Row): Row = { + def withRight(newRight: InternalRow): InternalRow = { row2 = newRight this } @@ -236,13 +229,7 @@ class JoinedRow2 extends Row { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - - override def copy(): Row = { + override def copy(): InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -250,7 +237,7 @@ class JoinedRow2 extends Row { copiedValues(i) = apply(i) i += 1 } - new GenericRow(copiedValues) + new GenericInternalRow(copiedValues) } override def toString: String = { @@ -270,31 +257,31 @@ class JoinedRow2 extends Row { /** * JIT HACK: Replace with macros */ -class JoinedRow3 extends Row { - private[this] var row1: Row = _ - private[this] var row2: Row = _ +class JoinedRow3 extends InternalRow { + private[this] var row1: InternalRow = _ + private[this] var row2: InternalRow = _ - def this(left: Row, right: Row) = { + def this(left: InternalRow, right: InternalRow) = { this() row1 = left row2 = right } /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: Row, r2: Row): Row = { + def apply(r1: InternalRow, r2: InternalRow): InternalRow = { row1 = r1 row2 = r2 this } /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: Row): Row = { + def withLeft(newLeft: InternalRow): InternalRow = { row1 = newLeft this } /** Updates this JoinedRow by updating its right base row. Returns itself. */ - def withRight(newRight: Row): Row = { + def withRight(newRight: InternalRow): InternalRow = { row2 = newRight this } @@ -330,13 +317,7 @@ class JoinedRow3 extends Row { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - - override def copy(): Row = { + override def copy(): InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -344,7 +325,7 @@ class JoinedRow3 extends Row { copiedValues(i) = apply(i) i += 1 } - new GenericRow(copiedValues) + new GenericInternalRow(copiedValues) } override def toString: String = { @@ -364,31 +345,31 @@ class JoinedRow3 extends Row { /** * JIT HACK: Replace with macros */ -class JoinedRow4 extends Row { - private[this] var row1: Row = _ - private[this] var row2: Row = _ +class JoinedRow4 extends InternalRow { + private[this] var row1: InternalRow = _ + private[this] var row2: InternalRow = _ - def this(left: Row, right: Row) = { + def this(left: InternalRow, right: InternalRow) = { this() row1 = left row2 = right } /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: Row, r2: Row): Row = { + def apply(r1: InternalRow, r2: InternalRow): InternalRow = { row1 = r1 row2 = r2 this } /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: Row): Row = { + def withLeft(newLeft: InternalRow): InternalRow = { row1 = newLeft this } /** Updates this JoinedRow by updating its right base row. Returns itself. */ - def withRight(newRight: Row): Row = { + def withRight(newRight: InternalRow): InternalRow = { row2 = newRight this } @@ -424,13 +405,7 @@ class JoinedRow4 extends Row { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - - override def copy(): Row = { + override def copy(): InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -438,7 +413,7 @@ class JoinedRow4 extends Row { copiedValues(i) = apply(i) i += 1 } - new GenericRow(copiedValues) + new GenericInternalRow(copiedValues) } override def toString: String = { @@ -458,31 +433,31 @@ class JoinedRow4 extends Row { /** * JIT HACK: Replace with macros */ -class JoinedRow5 extends Row { - private[this] var row1: Row = _ - private[this] var row2: Row = _ +class JoinedRow5 extends InternalRow { + private[this] var row1: InternalRow = _ + private[this] var row2: InternalRow = _ - def this(left: Row, right: Row) = { + def this(left: InternalRow, right: InternalRow) = { this() row1 = left row2 = right } /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: Row, r2: Row): Row = { + def apply(r1: InternalRow, r2: InternalRow): InternalRow = { row1 = r1 row2 = r2 this } /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: Row): Row = { + def withLeft(newLeft: InternalRow): InternalRow = { row1 = newLeft this } /** Updates this JoinedRow by updating its right base row. Returns itself. */ - def withRight(newRight: Row): Row = { + def withRight(newRight: InternalRow): InternalRow = { row2 = newRight this } @@ -518,13 +493,7 @@ class JoinedRow5 extends Row { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - - override def copy(): Row = { + override def copy(): InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -532,7 +501,7 @@ class JoinedRow5 extends Row { copiedValues(i) = apply(i) i += 1 } - new GenericRow(copiedValues) + new GenericInternalRow(copiedValues) } override def toString: String = { @@ -552,31 +521,31 @@ class JoinedRow5 extends Row { /** * JIT HACK: Replace with macros */ -class JoinedRow6 extends Row { - private[this] var row1: Row = _ - private[this] var row2: Row = _ +class JoinedRow6 extends InternalRow { + private[this] var row1: InternalRow = _ + private[this] var row2: InternalRow = _ - def this(left: Row, right: Row) = { + def this(left: InternalRow, right: InternalRow) = { this() row1 = left row2 = right } /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: Row, r2: Row): Row = { + def apply(r1: InternalRow, r2: InternalRow): InternalRow = { row1 = r1 row2 = r2 this } /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: Row): Row = { + def withLeft(newLeft: InternalRow): InternalRow = { row1 = newLeft this } /** Updates this JoinedRow by updating its right base row. Returns itself. */ - def withRight(newRight: Row): Row = { + def withRight(newRight: InternalRow): InternalRow = { row2 = newRight this } @@ -612,13 +581,7 @@ class JoinedRow6 extends Row { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - - override def copy(): Row = { + override def copy(): InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -626,7 +589,7 @@ class JoinedRow6 extends Row { copiedValues(i) = apply(i) i += 1 } - new GenericRow(copiedValues) + new GenericInternalRow(copiedValues) } override def toString: String = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala similarity index 97% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 5b45347872cca..dbb4381d54c4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types.DataType * User-defined function. * @param dataType Return type of function. */ -case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression]) +case class ScalaUDF(function: AnyRef, dataType: DataType, children: Seq[Expression]) extends Expression { override def nullable: Boolean = true @@ -38,14 +38,14 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi (1 to 22).map { x => val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _) val childs = (0 to x - 1).map(x => s"val child$x = children($x)").reduce(_ + "\n " + _) - lazy val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n " + _) + val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n " + _) val evals = (0 to x - 1).map(x => s"converter$x(child$x.eval(input))").reduce(_ + ",\n " + _) s"""case $x => val func = function.asInstanceOf[($anys) => Any] $childs $converters - (input: Row) => { + (input: InternalRow) => { func( $evals) } @@ -57,7 +57,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi private[this] val f = children.size match { case 0 => val func = function.asInstanceOf[() => Any] - (input: Row) => { + (input: InternalRow) => { func() } @@ -65,7 +65,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val func = function.asInstanceOf[(Any) => Any] val child0 = children(0) lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input))) } @@ -76,7 +76,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child1 = children(1) lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input))) @@ -90,7 +90,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -107,7 +107,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -127,7 +127,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -150,7 +150,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -176,7 +176,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -205,7 +205,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -237,7 +237,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -272,7 +272,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -310,7 +310,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -351,7 +351,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -395,7 +395,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -442,7 +442,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -492,7 +492,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -545,7 +545,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -601,7 +601,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -660,7 +660,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -722,7 +722,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -787,7 +787,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) lazy val converter19 = CatalystTypeConverters.createToScalaConverter(child19.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -855,7 +855,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) lazy val converter19 = CatalystTypeConverters.createToScalaConverter(child19.dataType) lazy val converter20 = CatalystTypeConverters.createToScalaConverter(child20.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -926,7 +926,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter19 = CatalystTypeConverters.createToScalaConverter(child19.dataType) lazy val converter20 = CatalystTypeConverters.createToScalaConverter(child20.dataType) lazy val converter21 = CatalystTypeConverters.createToScalaConverter(child21.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -955,6 +955,8 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi // scalastyle:on private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) - override def eval(input: Row): Any = converter(f(input)) + override def eval(input: InternalRow): Any = converter(f(input)) + // TODO(davies): make ScalaUDF work with codegen + override def isThreadSafe: Boolean = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 99340a14c9ecc..4baae03b3a224 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -36,7 +36,7 @@ case class SortOrder(child: Expression, direction: SortDirection) extends Expres override def nullable: Boolean = child.nullable // SortOrder itself is never evaluated. - override def eval(input: Row = null): Any = + override def eval(input: InternalRow = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index aa4099e4d7bf9..3928c0f2ffdaf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * A parent class for mutable container objects that are reused when the values are changed, @@ -195,14 +196,15 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR def this(dataTypes: Seq[DataType]) = this( dataTypes.map { - case IntegerType => new MutableInt + case BooleanType => new MutableBoolean case ByteType => new MutableByte - case FloatType => new MutableFloat case ShortType => new MutableShort + // We use INT for DATE internally + case IntegerType | DateType => new MutableInt + // We use Long for Timestamp internally + case LongType | TimestampType => new MutableLong + case FloatType => new MutableFloat case DoubleType => new MutableDouble - case BooleanType => new MutableBoolean - case LongType => new MutableLong - case DateType => new MutableInt // We use INT for DATE internally case _ => new MutableAny }.toArray) @@ -220,7 +222,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def isNullAt(i: Int): Boolean = values(i).isNull - override def copy(): Row = { + override def copy(): InternalRow = { val newValues = new Array[Any](values.length) var i = 0 while (i < values.length) { @@ -228,7 +230,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR i += 1 } - new GenericRow(newValues) + new GenericInternalRow(newValues) } override def update(ordinal: Int, value: Any) { @@ -239,7 +241,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR } } - override def setString(ordinal: Int, value: String): Unit = update(ordinal, UTF8String(value)) + override def setString(ordinal: Int, value: String): Unit = + update(ordinal, UTF8String.fromString(value)) override def getString(ordinal: Int): String = apply(ordinal).toString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index 5b2c8572784bd..b11fc245c4af9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.util.ObjectPool import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.unsafe.types.UTF8String /** * Converts Rows into UnsafeRow format. This class is NOT thread-safe. @@ -32,6 +34,8 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { this(schema.fields.map(_.dataType)) } + def numFields: Int = fieldTypes.length + /** Re-used pointer to the unsafe row being written */ private[this] val unsafeRow = new UnsafeRow() @@ -47,7 +51,7 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { /** * Compute the amount of space, in bytes, required to encode the given row. */ - def getSizeRequirement(row: Row): Int = { + def getSizeRequirement(row: InternalRow): Int = { var fieldNumber = 0 var variableLengthFieldSize: Int = 0 while (fieldNumber < writers.length) { @@ -67,19 +71,32 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { * @param baseOffset the base offset of the destination address * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`. */ - def writeRow(row: Row, baseObject: Object, baseOffset: Long): Long = { - unsafeRow.pointTo(baseObject, baseOffset, writers.length, null) + def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long, pool: ObjectPool): Int = { + unsafeRow.pointTo(baseObject, baseOffset, writers.length, pool) + + if (writers.length > 0) { + // zero-out the bitset + var n = writers.length / 64 + while (n >= 0) { + PlatformDependent.UNSAFE.putLong( + unsafeRow.getBaseObject, + unsafeRow.getBaseOffset + n * 8, + 0L) + n -= 1 + } + } + var fieldNumber = 0 - var appendCursor: Int = fixedLengthSize + var cursor: Int = fixedLengthSize while (fieldNumber < writers.length) { if (row.isNullAt(fieldNumber)) { unsafeRow.setNullAt(fieldNumber) } else { - appendCursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, appendCursor) + cursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, cursor) } fieldNumber += 1 } - appendCursor + cursor } } @@ -94,16 +111,16 @@ private abstract class UnsafeColumnWriter { * @param source the row being converted * @param target a pointer to the converted unsafe row * @param column the column to write - * @param appendCursor the offset from the start of the unsafe row to the end of the row; + * @param cursor the offset from the start of the unsafe row to the end of the row; * used for calculating where variable-length data should be written * @return the number of variable-length bytes written */ - def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int + def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int /** * Return the number of bytes that are needed to write this variable-length value. */ - def getSize(source: Row, column: Int): Int + def getSize(source: InternalRow, column: Int): Int } private object UnsafeColumnWriter { @@ -114,13 +131,13 @@ private object UnsafeColumnWriter { case BooleanType => BooleanUnsafeColumnWriter case ByteType => ByteUnsafeColumnWriter case ShortType => ShortUnsafeColumnWriter - case IntegerType => IntUnsafeColumnWriter - case LongType => LongUnsafeColumnWriter + case IntegerType | DateType => IntUnsafeColumnWriter + case LongType | TimestampType => LongUnsafeColumnWriter case FloatType => FloatUnsafeColumnWriter case DoubleType => DoubleUnsafeColumnWriter case StringType => StringUnsafeColumnWriter - case t => - throw new UnsupportedOperationException(s"Do not know how to write columns of type $t") + case BinaryType => BinaryUnsafeColumnWriter + case t => ObjectUnsafeColumnWriter } } } @@ -136,88 +153,122 @@ private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter +private object BinaryUnsafeColumnWriter extends BinaryUnsafeColumnWriter +private object ObjectUnsafeColumnWriter extends ObjectUnsafeColumnWriter private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter { // Primitives don't write to the variable-length region: - def getSize(sourceRow: Row, column: Int): Int = 0 + def getSize(sourceRow: InternalRow, column: Int): Int = 0 } private class NullUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setNullAt(column) 0 } } private class BooleanUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setBoolean(column, source.getBoolean(column)) 0 } } private class ByteUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setByte(column, source.getByte(column)) 0 } } private class ShortUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setShort(column, source.getShort(column)) 0 } } private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setInt(column, source.getInt(column)) 0 } } private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setLong(column, source.getLong(column)) 0 } } private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setFloat(column, source.getFloat(column)) 0 } } private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setDouble(column, source.getDouble(column)) 0 } } -private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter { - def getSize(source: Row, column: Int): Int = { - val numBytes = source.get(column).asInstanceOf[UTF8String].getBytes.length - 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) +private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter { + + def getBytes(source: InternalRow, column: Int): Array[Byte] + + def getSize(source: InternalRow, column: Int): Int = { + val numBytes = getBytes(source, column).length + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) } - override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { - val value = source.get(column).asInstanceOf[UTF8String] - val baseObject = target.getBaseObject - val baseOffset = target.getBaseOffset - val numBytes = value.getBytes.length - PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes) + protected[this] def isString: Boolean + + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { + val offset = target.getBaseOffset + cursor + val bytes = getBytes(source, column) + val numBytes = bytes.length + if ((numBytes & 0x07) > 0) { + // zero-out the padding bytes + PlatformDependent.UNSAFE.putLong(target.getBaseObject, offset + ((numBytes >> 3) << 3), 0L) + } PlatformDependent.copyMemory( - value.getBytes, + bytes, PlatformDependent.BYTE_ARRAY_OFFSET, - baseObject, - baseOffset + appendCursor + 8, + target.getBaseObject, + offset, numBytes ) - target.setLong(column, appendCursor) - 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + val flag = if (isString) 1L << (UnsafeRow.OFFSET_BITS * 2) else 0 + target.setLong(column, flag | (cursor.toLong << UnsafeRow.OFFSET_BITS) | numBytes.toLong) + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + } +} + +private class StringUnsafeColumnWriter private() extends BytesUnsafeColumnWriter { + protected[this] def isString: Boolean = true + def getBytes(source: InternalRow, column: Int): Array[Byte] = { + source.getAs[UTF8String](column).getBytes + } +} + +private class BinaryUnsafeColumnWriter private() extends BytesUnsafeColumnWriter { + protected[this] def isString: Boolean = false + def getBytes(source: InternalRow, column: Int): Array[Byte] = { + source.getAs[Array[Byte]](column) + } +} + +private class ObjectUnsafeColumnWriter private() extends UnsafeColumnWriter { + def getSize(sourceRow: InternalRow, column: Int): Int = 0 + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { + val obj = source.get(column) + val idx = target.getPool.put(obj) + target.setLong(column, - idx) + 0 } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 0266084a6d174..a9fc54c548f49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.expressions import com.clearspring.analytics.stream.cardinality.HyperLogLog -import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet abstract class AggregateExpression extends Expression { @@ -37,7 +39,7 @@ abstract class AggregateExpression extends Expression { * [[AggregateExpression.eval]] should never be invoked because [[AggregateExpression]]'s are * replaced with a physical aggregate operator at runtime. */ - override def eval(input: Row = null): Any = + override def eval(input: InternalRow = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } @@ -80,7 +82,7 @@ abstract class AggregateFunction override def nullable: Boolean = base.nullable override def dataType: DataType = base.dataType - def update(input: Row): Unit + def update(input: InternalRow): Unit // Do we really need this? override def newInstance(): AggregateFunction = { @@ -100,6 +102,9 @@ case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[ } override def newInstance(): MinFunction = new MinFunction(child, this) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(child.dataType, "function min") } case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -108,7 +113,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType) val cmp = GreaterThan(currentMin, expr) - override def update(input: Row): Unit = { + override def update(input: InternalRow): Unit = { if (currentMin.value == null) { currentMin.value = expr.eval(input) } else if (cmp.eval(input) == true) { @@ -116,7 +121,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr } } - override def eval(input: Row): Any = currentMin.value + override def eval(input: InternalRow): Any = currentMin.value } case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { @@ -131,6 +136,9 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[ } override def newInstance(): MaxFunction = new MaxFunction(child, this) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(child.dataType, "function max") } case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -139,7 +147,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType) val cmp = LessThan(currentMax, expr) - override def update(input: Row): Unit = { + override def update(input: InternalRow): Unit = { if (currentMax.value == null) { currentMax.value = expr.eval(input) } else if (cmp.eval(input) == true) { @@ -147,7 +155,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr } } - override def eval(input: Row): Any = currentMax.value + override def eval(input: InternalRow): Any = currentMax.value } case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { @@ -164,6 +172,21 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod override def newInstance(): CountFunction = new CountFunction(child, this) } +case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { + def this() = this(null, null) // Required for serialization. + + var count: Long = _ + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + if (evaluatedExpr != null) { + count += 1L + } + } + + override def eval(input: InternalRow): Any = count +} + case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate { def this() = this(null) @@ -182,6 +205,28 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate } } +case class CountDistinctFunction( + @transient expr: Seq[Expression], + @transient base: AggregateExpression) + extends AggregateFunction { + + def this() = this(null, null) // Required for serialization. + + val seen = new OpenHashSet[Any]() + + @transient + val distinctValue = new InterpretedProjection(expr) + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = distinctValue(input) + if (!evaluatedExpr.anyNull) { + seen.add(evaluatedExpr) + } + } + + override def eval(input: InternalRow): Any = seen.size.toLong +} + case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression { def this() = this(null) @@ -205,14 +250,14 @@ case class CollectHashSetFunction( @transient val distinctValue = new InterpretedProjection(expr) - override def update(input: Row): Unit = { + override def update(input: InternalRow): Unit = { val evaluatedExpr = distinctValue(input) if (!evaluatedExpr.anyNull) { seen.add(evaluatedExpr) } } - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { seen } } @@ -238,7 +283,7 @@ case class CombineSetsAndCountFunction( val seen = new OpenHashSet[Any]() - override def update(input: Row): Unit = { + override def update(input: InternalRow): Unit = { val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] val inputIterator = inputSetEval.iterator while (inputIterator.hasNext) { @@ -246,7 +291,7 @@ case class CombineSetsAndCountFunction( } } - override def eval(input: Row): Any = seen.size.toLong + override def eval(input: InternalRow): Any = seen.size.toLong } /** The data type of ApproxCountDistinctPartition since its output is a HyperLogLog object. */ @@ -277,6 +322,25 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) } } +case class ApproxCountDistinctPartitionFunction( + expr: Expression, + base: AggregateExpression, + relativeSD: Double) + extends AggregateFunction { + def this() = this(null, null, 0) // Required for serialization. + + private val hyperLogLog = new HyperLogLog(relativeSD) + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + if (evaluatedExpr != null) { + hyperLogLog.offer(evaluatedExpr) + } + } + + override def eval(input: InternalRow): Any = hyperLogLog +} + case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { @@ -288,6 +352,23 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) } } +case class ApproxCountDistinctMergeFunction( + expr: Expression, + base: AggregateExpression, + relativeSD: Double) + extends AggregateFunction { + def this() = this(null, null, 0) // Required for serialization. + + private val hyperLogLog = new HyperLogLog(relativeSD) + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog]) + } + + override def eval(input: InternalRow): Any = hyperLogLog.cardinality() +} + case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) extends PartialAggregate with trees.UnaryNode[Expression] { @@ -348,159 +429,9 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN } override def newInstance(): AverageFunction = new AverageFunction(child, this) -} - -case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive - case DecimalType.Unlimited => - DecimalType.Unlimited - case _ => - child.dataType - } - - override def toString: String = s"SUM($child)" - override def asPartial: SplitEvaluation = { - child.dataType match { - case DecimalType.Fixed(_, _) => - val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() - SplitEvaluation( - Cast(CombineSum(partialSum.toAttribute), dataType), - partialSum :: Nil) - - case _ => - val partialSum = Alias(Sum(child), "PartialSum")() - SplitEvaluation( - CombineSum(partialSum.toAttribute), - partialSum :: Nil) - } - } - - override def newInstance(): SumFunction = new SumFunction(child, this) -} - -/** - * Sum should satisfy 3 cases: - * 1) sum of all null values = zero - * 2) sum for table column with no data = null - * 3) sum of column with null and not null values = sum of not null values - * Require separate CombineSum Expression and function as it has to distinguish "No data" case - * versus "data equals null" case, while aggregating results and at each partial expression.i.e., - * Combining PartitionLevel InputData - * <-- null - * Zero <-- Zero <-- null - * - * <-- null <-- no data - * null <-- null <-- no data - */ -case class CombineSum(child: Expression) extends AggregateExpression { - def this() = this(null) - - override def children: Seq[Expression] = child :: Nil - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"CombineSum($child)" - override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this) -} - -case class SumDistinct(child: Expression) - extends PartialAggregate with trees.UnaryNode[Expression] { - - def this() = this(null) - override def nullable: Boolean = true - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive - case DecimalType.Unlimited => - DecimalType.Unlimited - case _ => - child.dataType - } - override def toString: String = s"SUM(DISTINCT $child)" - override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this) - - override def asPartial: SplitEvaluation = { - val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")() - SplitEvaluation( - CombineSetsAndSum(partialSet.toAttribute, this), - partialSet :: Nil) - } -} - -case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression { - def this() = this(null, null) - - override def children: Seq[Expression] = inputSet :: Nil - override def nullable: Boolean = true - override def dataType: DataType = base.dataType - override def toString: String = s"CombineAndSum($inputSet)" - override def newInstance(): CombineSetsAndSumFunction = { - new CombineSetsAndSumFunction(inputSet, this) - } -} - -case class CombineSetsAndSumFunction( - @transient inputSet: Expression, - @transient base: AggregateExpression) - extends AggregateFunction { - - def this() = this(null, null) // Required for serialization. - - val seen = new OpenHashSet[Any]() - - override def update(input: Row): Unit = { - val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] - val inputIterator = inputSetEval.iterator - while (inputIterator.hasNext) { - seen.add(inputIterator.next) - } - } - - override def eval(input: Row): Any = { - val casted = seen.asInstanceOf[OpenHashSet[Row]] - if (casted.size == 0) { - null - } else { - Cast(Literal( - casted.iterator.map(f => f.apply(0)).reduceLeft( - base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), - base.dataType).eval(null) - } - } -} - -case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"FIRST($child)" - - override def asPartial: SplitEvaluation = { - val partialFirst = Alias(First(child), "PartialFirst")() - SplitEvaluation( - First(partialFirst.toAttribute), - partialFirst :: Nil) - } - override def newInstance(): FirstFunction = new FirstFunction(child, this) -} - -case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references: AttributeSet = child.references - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"LAST($child)" - - override def asPartial: SplitEvaluation = { - val partialLast = Alias(Last(child), "PartialLast")() - SplitEvaluation( - Last(partialLast.toAttribute), - partialLast :: Nil) - } - override def newInstance(): LastFunction = new LastFunction(child, this) + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function average") } case class AverageFunction(expr: Expression, base: AggregateExpression) @@ -524,7 +455,7 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) private def addFunction(value: Any) = Add(sum, Cast(Literal.create(value, expr.dataType), calcType)) - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { if (count == 0L) { null } else { @@ -541,7 +472,7 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) } } - override def update(input: Row): Unit = { + override def update(input: InternalRow): Unit = { val evaluatedExpr = expr.eval(input) if (evaluatedExpr != null) { count += 1 @@ -550,55 +481,41 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) } } -case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { - def this() = this(null, null) // Required for serialization. +case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - var count: Long = _ + override def nullable: Boolean = true - override def update(input: Row): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - count += 1L - } + override def dataType: DataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + child.dataType } - override def eval(input: Row): Any = count -} - -case class ApproxCountDistinctPartitionFunction( - expr: Expression, - base: AggregateExpression, - relativeSD: Double) - extends AggregateFunction { - def this() = this(null, null, 0) // Required for serialization. + override def toString: String = s"SUM($child)" - private val hyperLogLog = new HyperLogLog(relativeSD) + override def asPartial: SplitEvaluation = { + child.dataType match { + case DecimalType.Fixed(_, _) => + val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() + SplitEvaluation( + Cast(CombineSum(partialSum.toAttribute), dataType), + partialSum :: Nil) - override def update(input: Row): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - hyperLogLog.offer(evaluatedExpr) + case _ => + val partialSum = Alias(Sum(child), "PartialSum")() + SplitEvaluation( + CombineSum(partialSum.toAttribute), + partialSum :: Nil) } } - override def eval(input: Row): Any = hyperLogLog -} - -case class ApproxCountDistinctMergeFunction( - expr: Expression, - base: AggregateExpression, - relativeSD: Double) - extends AggregateFunction { - def this() = this(null, null, 0) // Required for serialization. - - private val hyperLogLog = new HyperLogLog(relativeSD) - - override def update(input: Row): Unit = { - val evaluatedExpr = expr.eval(input) - hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog]) - } + override def newInstance(): SumFunction = new SumFunction(child, this) - override def eval(input: Row): Any = hyperLogLog.cardinality() + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function sum") } case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -619,11 +536,11 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) - override def update(input: Row): Unit = { + override def update(input: InternalRow): Unit = { sum.update(addFunction, input) } - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { expr.dataType match { case DecimalType.Fixed(_, _) => Cast(sum, dataType).eval(null) @@ -632,6 +549,30 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr } } +/** + * Sum should satisfy 3 cases: + * 1) sum of all null values = zero + * 2) sum for table column with no data = null + * 3) sum of column with null and not null values = sum of not null values + * Require separate CombineSum Expression and function as it has to distinguish "No data" case + * versus "data equals null" case, while aggregating results and at each partial expression.i.e., + * Combining PartitionLevel InputData + * <-- null + * Zero <-- Zero <-- null + * + * <-- null <-- no data + * null <-- null <-- no data + */ +case class CombineSum(child: Expression) extends AggregateExpression { + def this() = this(null) + + override def children: Seq[Expression] = child :: Nil + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"CombineSum($child)" + override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this) +} + case class CombineSumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -652,7 +593,7 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression) private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) - override def update(input: Row): Unit = { + override def update(input: InternalRow): Unit = { val result = expr.eval(input) // partial sum result can be null only when no input rows present if(result != null) { @@ -660,7 +601,7 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression) } } - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { expr.dataType match { case DecimalType.Fixed(_, _) => Cast(sum, dataType).eval(null) @@ -669,6 +610,33 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression) } } +case class SumDistinct(child: Expression) + extends PartialAggregate with trees.UnaryNode[Expression] { + + def this() = this(null) + override def nullable: Boolean = true + override def dataType: DataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + child.dataType + } + override def toString: String = s"SUM(DISTINCT $child)" + override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this) + + override def asPartial: SplitEvaluation = { + val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")() + SplitEvaluation( + CombineSetsAndSum(partialSet.toAttribute, this), + partialSet :: Nil) + } + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct") +} + case class SumDistinctFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -676,14 +644,14 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression) private val seen = new scala.collection.mutable.HashSet[Any]() - override def update(input: Row): Unit = { + override def update(input: InternalRow): Unit = { val evaluatedExpr = expr.eval(input) if (evaluatedExpr != null) { seen += evaluatedExpr } } - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { if (seen.size == 0) { null } else { @@ -695,8 +663,20 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression) } } -case class CountDistinctFunction( - @transient expr: Seq[Expression], +case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression { + def this() = this(null, null) + + override def children: Seq[Expression] = inputSet :: Nil + override def nullable: Boolean = true + override def dataType: DataType = base.dataType + override def toString: String = s"CombineAndSum($inputSet)" + override def newInstance(): CombineSetsAndSumFunction = { + new CombineSetsAndSumFunction(inputSet, this) + } +} + +case class CombineSetsAndSumFunction( + @transient inputSet: Expression, @transient base: AggregateExpression) extends AggregateFunction { @@ -704,17 +684,39 @@ case class CountDistinctFunction( val seen = new OpenHashSet[Any]() - @transient - val distinctValue = new InterpretedProjection(expr) + override def update(input: InternalRow): Unit = { + val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] + val inputIterator = inputSetEval.iterator + while (inputIterator.hasNext) { + seen.add(inputIterator.next) + } + } - override def update(input: Row): Unit = { - val evaluatedExpr = distinctValue(input) - if (!evaluatedExpr.anyNull) { - seen.add(evaluatedExpr) + override def eval(input: InternalRow): Any = { + val casted = seen.asInstanceOf[OpenHashSet[InternalRow]] + if (casted.size == 0) { + null + } else { + Cast(Literal( + casted.iterator.map(f => f.apply(0)).reduceLeft( + base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), + base.dataType).eval(null) } } +} + +case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"FIRST($child)" - override def eval(input: Row): Any = seen.size.toLong + override def asPartial: SplitEvaluation = { + val partialFirst = Alias(First(child), "PartialFirst")() + SplitEvaluation( + First(partialFirst.toAttribute), + partialFirst :: Nil) + } + override def newInstance(): FirstFunction = new FirstFunction(child, this) } case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -722,13 +724,28 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag var result: Any = null - override def update(input: Row): Unit = { + override def update(input: InternalRow): Unit = { if (result == null) { result = expr.eval(input) } } - override def eval(input: Row): Any = result + override def eval(input: InternalRow): Any = result +} + +case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { + override def references: AttributeSet = child.references + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"LAST($child)" + + override def asPartial: SplitEvaluation = { + val partialLast = Alias(Last(child), "PartialLast")() + SplitEvaluation( + Last(partialLast.toAttribute), + partialLast :: Nil) + } + override def newInstance(): LastFunction = new LastFunction(child, this) } case class LastFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -736,11 +753,11 @@ case class LastFunction(expr: Expression, base: AggregateExpression) extends Agg var result: Any = null - override def update(input: Row): Unit = { + override def update(input: InternalRow): Unit = { result = input } - override def eval(input: Row): Any = { - if (result != null) expr.eval(result.asInstanceOf[Row]) else null + override def eval(input: InternalRow): Any = { + if (result != null) expr.eval(result.asInstanceOf[InternalRow]) else null } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 124274c94203c..3d4d9e2d798f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,18 +18,16 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ abstract class UnaryArithmetic extends UnaryExpression { self: Product => - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def dataType: DataType = child.dataType - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val evalE = child.eval(input) if (evalE == null) { null @@ -51,43 +49,20 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic { private lazy val numeric = TypeUtils.getNumeric(dataType) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { - case dt: DecimalType => defineCodeGen(ctx, ev, c => s"c.unary_$$minus()") - case dt: NumericType => defineCodeGen(ctx, ev, c => s"-($c)") + case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") + case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))") } protected override def evalInternal(evalE: Any) = numeric.negate(evalE) } -case class Sqrt(child: Expression) extends UnaryArithmetic { - override def dataType: DataType = DoubleType - override def nullable: Boolean = true - override def toString: String = s"SQRT($child)" - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function sqrt") +case class UnaryPositive(child: Expression) extends UnaryArithmetic { + override def toString: String = s"positive($child)" - private lazy val numeric = TypeUtils.getNumeric(child.dataType) - - protected override def evalInternal(evalE: Any) = { - val value = numeric.toDouble(evalE) - if (value < 0) null - else math.sqrt(value) - } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + defineCodeGen(ctx, ev, c => c) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val eval = child.gen(ctx) - eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - if (${eval.primitive} < 0.0) { - ${ev.isNull} = true; - } else { - ${ev.primitive} = java.lang.Math.sqrt(${eval.primitive}); - } - } - """ - } + protected override def evalInternal(evalE: Any) = evalE } /** @@ -124,7 +99,7 @@ abstract class BinaryArithmetic extends BinaryExpression { protected def checkTypesInternal(t: DataType): TypeCheckResult - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val evalE1 = left.eval(input) if(evalE1 == null) { null @@ -143,8 +118,8 @@ abstract class BinaryArithmetic extends BinaryExpression { defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => - defineCodeGen(ctx, ev, (eval1, eval2) => - s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + defineCodeGen(ctx, ev, + (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") case _ => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") } @@ -204,7 +179,7 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "/" - override def decimalMethod: String = "$divide" + override def decimalMethod: String = "$div" override def nullable: Boolean = true @@ -219,7 +194,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot } - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val evalE2 = right.eval(input) if (evalE2 == null || evalE2 == 0) { null @@ -244,11 +219,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } else { s"${eval2.primitive} == 0" } - val method = if (left.dataType.isInstanceOf[DecimalType]) { - s".$decimalMethod" - } else { - s"$symbol" - } + val method = if (left.dataType.isInstanceOf[DecimalType]) s".$decimalMethod" else s" $symbol " + val javaType = ctx.javaType(left.dataType) eval1.code + eval2.code + s""" boolean ${ev.isNull} = false; @@ -256,7 +228,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic if (${eval1.isNull} || ${eval2.isNull} || $test) { ${ev.isNull} = true; } else { - ${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive}); + ${ev.primitive} = ($javaType) (${eval1.primitive}$method(${eval2.primitive})); } """ } @@ -264,7 +236,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "%" - override def decimalMethod: String = "reminder" + override def decimalMethod: String = "remainder" override def nullable: Boolean = true @@ -279,7 +251,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]] } - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val evalE2 = right.eval(input) if (evalE2 == null || evalE2 == 0) { null @@ -304,11 +276,8 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } else { s"${eval2.primitive} == 0" } - val method = if (left.dataType.isInstanceOf[DecimalType]) { - s".$decimalMethod" - } else { - s"$symbol" - } + val method = if (left.dataType.isInstanceOf[DecimalType]) s".$decimalMethod" else s" $symbol " + val javaType = ctx.javaType(left.dataType) eval1.code + eval2.code + s""" boolean ${ev.isNull} = false; @@ -316,7 +285,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet if (${eval1.isNull} || ${eval2.isNull} || $test) { ${ev.isNull} = true; } else { - ${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive}); + ${ev.primitive} = ($javaType) (${eval1.primitive}$method(${eval2.primitive})); } """ } @@ -330,7 +299,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { private lazy val ordering = TypeUtils.getOrdering(dataType) - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val evalE1 = left.eval(input) val evalE2 = right.eval(input) if (evalE1 == null) { @@ -347,31 +316,29 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - if (ctx.isNativeType(left.dataType)) { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) - eval1.code + eval2.code + s""" - boolean ${ev.isNull} = false; - ${ctx.javaType(left.dataType)} ${ev.primitive} = - ${ctx.defaultValue(left.dataType)}; - - if (${eval1.isNull}) { - ${ev.isNull} = ${eval2.isNull}; - ${ev.primitive} = ${eval2.primitive}; - } else if (${eval2.isNull}) { - ${ev.isNull} = ${eval1.isNull}; + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val compCode = ctx.genComp(dataType, eval1.primitive, eval2.primitive) + + eval1.code + eval2.code + s""" + boolean ${ev.isNull} = false; + ${ctx.javaType(left.dataType)} ${ev.primitive} = + ${ctx.defaultValue(left.dataType)}; + + if (${eval1.isNull}) { + ${ev.isNull} = ${eval2.isNull}; + ${ev.primitive} = ${eval2.primitive}; + } else if (${eval2.isNull}) { + ${ev.isNull} = ${eval1.isNull}; + ${ev.primitive} = ${eval1.primitive}; + } else { + if ($compCode > 0) { ${ev.primitive} = ${eval1.primitive}; } else { - if (${eval1.primitive} > ${eval2.primitive}) { - ${ev.primitive} = ${eval1.primitive}; - } else { - ${ev.primitive} = ${eval2.primitive}; - } + ${ev.primitive} = ${eval2.primitive}; } - """ - } else { - super.genCode(ctx, ev) - } + } + """ } override def toString: String = s"MaxOf($left, $right)" } @@ -384,7 +351,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { private lazy val ordering = TypeUtils.getOrdering(dataType) - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val evalE1 = left.eval(input) val evalE2 = right.eval(input) if (evalE1 == null) { @@ -401,33 +368,29 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - if (ctx.isNativeType(left.dataType)) { - - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) - - eval1.code + eval2.code + s""" - boolean ${ev.isNull} = false; - ${ctx.javaType(left.dataType)} ${ev.primitive} = - ${ctx.defaultValue(left.dataType)}; + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val compCode = ctx.genComp(dataType, eval1.primitive, eval2.primitive) - if (${eval1.isNull}) { - ${ev.isNull} = ${eval2.isNull}; - ${ev.primitive} = ${eval2.primitive}; - } else if (${eval2.isNull}) { - ${ev.isNull} = ${eval1.isNull}; + eval1.code + eval2.code + s""" + boolean ${ev.isNull} = false; + ${ctx.javaType(left.dataType)} ${ev.primitive} = + ${ctx.defaultValue(left.dataType)}; + + if (${eval1.isNull}) { + ${ev.isNull} = ${eval2.isNull}; + ${ev.primitive} = ${eval2.primitive}; + } else if (${eval2.isNull}) { + ${ev.isNull} = ${eval1.isNull}; + ${ev.primitive} = ${eval1.primitive}; + } else { + if ($compCode < 0) { ${ev.primitive} = ${eval1.primitive}; } else { - if (${eval1.primitive} < ${eval2.primitive}) { - ${ev.primitive} = ${eval1.primitive}; - } else { - ${ev.primitive} = ${eval2.primitive}; - } + ${ev.primitive} = ${eval2.primitive}; } - """ - } else { - super.genCode(ctx, ev) - } + } + """ } override def toString: String = s"MinOf($left, $right)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e95682f952a7b..57e0bede5db20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -26,13 +26,15 @@ import org.codehaus.janino.ClassBodyEvaluator import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + // These classes are here to avoid issues with serialization and integration with quasiquotes. class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int] class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] /** - * Java source for evaluating an [[Expression]] given a [[Row]] of input. + * Java source for evaluating an [[Expression]] given a [[InternalRow]] of input. * * @param code The sequence of statements required to evaluate the expression. * @param isNull A term that holds a boolean value representing whether the expression evaluated @@ -57,6 +59,14 @@ class CodeGenContext { val stringType: String = classOf[UTF8String].getName val decimalType: String = classOf[Decimal].getName + final val JAVA_BOOLEAN = "boolean" + final val JAVA_BYTE = "byte" + final val JAVA_SHORT = "short" + final val JAVA_INT = "int" + final val JAVA_LONG = "long" + final val JAVA_FLOAT = "float" + final val JAVA_DOUBLE = "double" + private val curId = new java.util.concurrent.atomic.AtomicInteger() /** @@ -70,116 +80,128 @@ class CodeGenContext { } /** - * Return the code to access a column for given DataType + * Returns the code to access a column in Row for a given DataType. */ def getColumn(dataType: DataType, ordinal: Int): String = { - if (isNativeType(dataType)) { - s"i.${accessorForType(dataType)}($ordinal)" + val jt = javaType(dataType) + if (isPrimitiveType(jt)) { + s"i.get${primitiveTypeName(jt)}($ordinal)" } else { - s"(${boxedType(dataType)})i.apply($ordinal)" + s"($jt)i.apply($ordinal)" } } /** - * Return the code to update a column in Row for given DataType + * Returns the code to update a column in Row for a given DataType. */ def setColumn(dataType: DataType, ordinal: Int, value: String): String = { - if (isNativeType(dataType)) { - s"${mutatorForType(dataType)}($ordinal, $value)" + val jt = javaType(dataType) + if (isPrimitiveType(jt)) { + s"set${primitiveTypeName(jt)}($ordinal, $value)" } else { s"update($ordinal, $value)" } } /** - * Return the name of accessor in Row for a DataType + * Returns the name used in accessor and setter for a Java primitive type. */ - def accessorForType(dt: DataType): String = dt match { - case IntegerType => "getInt" - case other => s"get${boxedType(dt)}" + def primitiveTypeName(jt: String): String = jt match { + case JAVA_INT => "Int" + case _ => boxedType(jt) } - /** - * Return the name of mutator in Row for a DataType - */ - def mutatorForType(dt: DataType): String = dt match { - case IntegerType => "setInt" - case other => s"set${boxedType(dt)}" - } + def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt)) /** - * Return the Java type for a DataType + * Returns the Java type for a DataType. */ def javaType(dt: DataType): String = dt match { - case IntegerType => "int" - case LongType => "long" - case ShortType => "short" - case ByteType => "byte" - case DoubleType => "double" - case FloatType => "float" - case BooleanType => "boolean" + case BooleanType => JAVA_BOOLEAN + case ByteType => JAVA_BYTE + case ShortType => JAVA_SHORT + case IntegerType | DateType => JAVA_INT + case LongType | TimestampType => JAVA_LONG + case FloatType => JAVA_FLOAT + case DoubleType => JAVA_DOUBLE case dt: DecimalType => decimalType case BinaryType => "byte[]" case StringType => stringType - case DateType => "int" - case TimestampType => "java.sql.Timestamp" case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName case _ => "Object" } /** - * Return the boxed type in Java + * Returns the boxed type in Java. */ - def boxedType(dt: DataType): String = dt match { - case IntegerType => "Integer" - case LongType => "Long" - case ShortType => "Short" - case ByteType => "Byte" - case DoubleType => "Double" - case FloatType => "Float" - case BooleanType => "Boolean" - case DateType => "Integer" - case _ => javaType(dt) + def boxedType(jt: String): String = jt match { + case JAVA_BOOLEAN => "Boolean" + case JAVA_BYTE => "Byte" + case JAVA_SHORT => "Short" + case JAVA_INT => "Integer" + case JAVA_LONG => "Long" + case JAVA_FLOAT => "Float" + case JAVA_DOUBLE => "Double" + case other => other } + def boxedType(dt: DataType): String = boxedType(javaType(dt)) + /** - * Return the representation of default value for given DataType + * Returns the representation of default value for a given Java Type. */ - def defaultValue(dt: DataType): String = dt match { - case BooleanType => "false" - case FloatType => "-1.0f" - case ShortType => "(short)-1" - case LongType => "-1L" - case ByteType => "(byte)-1" - case DoubleType => "-1.0" - case IntegerType => "-1" - case DateType => "-1" + def defaultValue(jt: String): String = jt match { + case JAVA_BOOLEAN => "false" + case JAVA_BYTE => "(byte)-1" + case JAVA_SHORT => "(short)-1" + case JAVA_INT => "-1" + case JAVA_LONG => "-1L" + case JAVA_FLOAT => "-1.0f" + case JAVA_DOUBLE => "-1.0" case _ => "null" } + def defaultValue(dt: DataType): String = defaultValue(javaType(dt)) + + /** + * Generates code for equal expression in Java. + */ + def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match { + case BinaryType => s"java.util.Arrays.equals($c1, $c2)" + case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2" + case other => s"$c1.equals($c2)" + } + /** - * Returns a function to generate equal expression in Java + * Generates code for compare expression in Java. */ - def equalFunc(dataType: DataType): ((String, String) => String) = dataType match { - case BinaryType => { case (eval1, eval2) => - s"java.util.Arrays.equals($eval1, $eval2)" } - case IntegerType | BooleanType | LongType | DoubleType | FloatType | ShortType | ByteType => - { case (eval1, eval2) => s"$eval1 == $eval2" } - case other => - { case (eval1, eval2) => s"$eval1.equals($eval2)" } + def genComp(dataType: DataType, c1: String, c2: String): String = dataType match { + // java boolean doesn't support > or < operator + case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))" + // use c1 - c2 may overflow + case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)" + case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" + case other => s"$c1.compare($c2)" } /** - * List of data types that have special accessors and setters in [[Row]]. + * List of java data types that have special accessors and setters in [[InternalRow]]. */ - val nativeTypes = - Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType) + val primitiveTypes = + Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE) /** - * Returns true if the data type has a special accessor and setter in [[Row]]. + * Returns true if the Java type has a special accessor and setter in [[InternalRow]]. */ - def isNativeType(dt: DataType): Boolean = nativeTypes.contains(dt) + def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt) + + def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) +} + + +abstract class GeneratedClass { + def generate(expressions: Array[Expression]): Any } /** @@ -193,11 +215,6 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected val mutableRowType: String = classOf[MutableRow].getName protected val genericMutableRowType: String = classOf[GenericMutableRow].getName - /** - * Can be flipped on manually in the console to add (expensive) expression evaluation trace code. - */ - var debugLogging = false - /** * Generates a class for a given input expression. Called when there is not cached code * already available. @@ -218,10 +235,14 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * * It will track the time used to compile */ - protected def compile(code: String): Class[_] = { + protected def compile(code: String): GeneratedClass = { val startTime = System.nanoTime() - val clazz = try { - new ClassBodyEvaluator(code).getClazz() + val evaluator = new ClassBodyEvaluator() + evaluator.setParentClassLoader(getClass.getClassLoader) + evaluator.setDefaultImports(Array("org.apache.spark.sql.catalyst.InternalRow")) + evaluator.setExtendedClass(classOf[GeneratedClass]) + try { + evaluator.cook(code) } catch { case e: Exception => logError(s"failed to compile:\n $code", e) @@ -230,7 +251,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val endTime = System.nanoTime() def timeMs: Double = (endTime - startTime).toDouble / 1000000 logDebug(s"Code (${code.size} bytes) compiled in $timeMs ms") - clazz + evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass] } /** @@ -243,7 +264,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * weak keys/values and thus does not respond to memory pressure. */ protected val cache = CacheBuilder.newBuilder() - .maximumSize(1000) + .maximumSize(100) .build( new CacheLoader[InType, OutType]() { override def load(in: InType): OutType = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index e5ee2accd8a84..64ef357a4f954 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ // MutableProjection is not accessible in Java -abstract class BaseMutableProjection extends MutableProjection {} +abstract class BaseMutableProjection extends MutableProjection /** * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new - * input [[Row]] for a fixed set of [[Expression Expressions]]. + * input [[InternalRow]] for a fixed set of [[Expression Expressions]]. */ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] { @@ -47,9 +47,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu """ }.mkString("\n") val code = s""" - import org.apache.spark.sql.Row; - - public SpecificProjection generate($exprType[] expr) { + public Object generate($exprType[] expr) { return new SpecificProjection(expr); } @@ -69,12 +67,12 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu } /* Provide immutable access to the last projected row. */ - public Row currentValue() { - return mutableRow; + public InternalRow currentValue() { + return (InternalRow) mutableRow; } public Object apply(Object _i) { - Row i = (Row) _i; + InternalRow i = (InternalRow) _i; $projectionCode return mutableRow; @@ -82,14 +80,11 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu } """ - logDebug(s"code for ${expressions.mkString(",")}:\n$code") val c = compile(code) - // fetch the only one method `generate(Expression[])` - val m = c.getDeclaredMethods()(0) () => { - m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[BaseMutableProjection] + c.generate(ctx.references.toArray).asInstanceOf[MutableProjection] } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 36e155d164a40..7ed2c5addec9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -21,14 +21,13 @@ import org.apache.spark.Logging import org.apache.spark.annotation.Private import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{BinaryType, NumericType} /** * Inherits some default implementation for Java from `Ordering[Row]` */ @Private -class BaseOrdering extends Ordering[Row] { - def compare(a: Row, b: Row): Int = { +class BaseOrdering extends Ordering[InternalRow] { + def compare(a: InternalRow, b: InternalRow): Int = { throw new UnsupportedOperationException } } @@ -37,7 +36,8 @@ class BaseOrdering extends Ordering[Row] { * Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of * [[Expression Expressions]]. */ -object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] with Logging { +object GenerateOrdering + extends CodeGenerator[Seq[SortOrder], Ordering[InternalRow]] with Logging { import scala.reflect.runtime.universe._ protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = @@ -46,7 +46,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] = in.map(BindReferences.bindReference(_, inputSchema)) - protected def create(ordering: Seq[SortOrder]): Ordering[Row] = { + protected def create(ordering: Seq[SortOrder]): Ordering[InternalRow] = { val a = newTermName("a") val b = newTermName("b") val ctx = newCodeGenContext() @@ -55,39 +55,6 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit val evalA = order.child.gen(ctx) val evalB = order.child.gen(ctx) val asc = order.direction == Ascending - val compare = order.child.dataType match { - case BinaryType => - s""" - { - byte[] x = ${if (asc) evalA.primitive else evalB.primitive}; - byte[] y = ${if (!asc) evalB.primitive else evalA.primitive}; - int j = 0; - while (j < x.length && j < y.length) { - if (x[j] != y[j]) return x[j] - y[j]; - j = j + 1; - } - int d = x.length - y.length; - if (d != 0) { - return d; - } - }""" - case _: NumericType => - s""" - if (${evalA.primitive} != ${evalB.primitive}) { - if (${evalA.primitive} > ${evalB.primitive}) { - return ${if (asc) "1" else "-1"}; - } else { - return ${if (asc) "-1" else "1"}; - } - }""" - case _ => - s""" - int comp = ${evalA.primitive}.compare(${evalB.primitive}); - if (comp != 0) { - return ${if (asc) "comp" else "-comp"}; - }""" - } - s""" i = $a; ${evalA.code} @@ -100,14 +67,15 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit } else if (${evalB.isNull}) { return ${if (order.direction == Ascending) "1" else "-1"}; } else { - $compare + int comp = ${ctx.genComp(order.child.dataType, evalA.primitive, evalB.primitive)}; + if (comp != 0) { + return ${if (asc) "comp" else "-comp"}; + } } """ }.mkString("\n") val code = s""" - import org.apache.spark.sql.Row; - public SpecificOrdering generate($exprType[] expr) { return new SpecificOrdering(expr); } @@ -121,8 +89,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit } @Override - public int compare(Row a, Row b) { - Row i = null; // Holds current row being evaluated. + public int compare(InternalRow a, InternalRow b) { + InternalRow i = null; // Holds current row being evaluated. $comparisons return 0; } @@ -130,9 +98,6 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit logDebug(s"Generated Ordering: $code") - val c = compile(code) - // fetch the only one method `generate(Expression[])` - val m = c.getDeclaredMethods()(0) - m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[BaseOrdering] + compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 4a547b5ce9543..3ebc2c147579b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -23,25 +23,23 @@ import org.apache.spark.sql.catalyst.expressions._ * Interface for generated predicate */ abstract class Predicate { - def eval(r: Row): Boolean + def eval(r: InternalRow): Boolean } /** - * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[Row]]. + * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[InternalRow]]. */ -object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { +object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Boolean] { protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in) protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression = BindReferences.bindReference(in, inputSchema) - protected def create(predicate: Expression): ((Row) => Boolean) = { + protected def create(predicate: Expression): ((InternalRow) => Boolean) = { val ctx = newCodeGenContext() val eval = predicate.gen(ctx) val code = s""" - import org.apache.spark.sql.Row; - public SpecificPredicate generate($exprType[] expr) { return new SpecificPredicate(expr); } @@ -53,7 +51,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { } @Override - public boolean eval(Row i) { + public boolean eval(InternalRow i) { ${eval.code} return !${eval.isNull} && ${eval.primitive}; } @@ -61,10 +59,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { logDebug(s"Generated predicate '$predicate':\n$code") - val c = compile(code) - // fetch the only one method `generate(Expression[])` - val m = c.getDeclaredMethods()(0) - val p = m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[Predicate] - (r: Row) => p.eval(r) + val p = compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] + (r: InternalRow) => p.eval(r) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 7caf4aaab88bb..39d32b78cc14a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.sql.BaseMutableRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -27,9 +26,10 @@ import org.apache.spark.sql.types._ abstract class BaseProject extends Projection {} /** - * Generates bytecode that produces a new [[Row]] object based on a fixed set of input - * [[Expression Expressions]] and a given input [[Row]]. The returned [[Row]] object is custom - * generated based on the output types of the [[Expression]] to avoid boxing of primitive values. + * Generates bytecode that produces a new [[InternalRow]] object based on a fixed set of input + * [[Expression Expressions]] and a given input [[InternalRow]]. The returned [[InternalRow]] + * object is custom generated based on the output types of the [[Expression]] to avoid boxing of + * primitive values. */ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { import scala.reflect.runtime.universe._ @@ -71,58 +71,64 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}" }.mkString("\n ") - val specificAccessorFunctions = ctx.nativeTypes.map { dataType => - val cases = expressions.zipWithIndex.map { - case (e, i) if e.dataType == dataType => - s"case $i: return c$i;" - case _ => "" + val specificAccessorFunctions = ctx.primitiveTypes.map { jt => + val cases = expressions.zipWithIndex.flatMap { + case (e, i) if ctx.javaType(e.dataType) == jt => + Some(s"case $i: return c$i;") + case _ => None }.mkString("\n ") - if (cases.count(_ != '\n') > 0) { + if (cases.length > 0) { + val getter = "get" + ctx.primitiveTypeName(jt) s""" @Override - public ${ctx.javaType(dataType)} ${ctx.accessorForType(dataType)}(int i) { + public $jt $getter(int i) { if (isNullAt(i)) { - return ${ctx.defaultValue(dataType)}; + return ${ctx.defaultValue(jt)}; } switch (i) { $cases } - return ${ctx.defaultValue(dataType)}; + throw new IllegalArgumentException("Invalid index: " + i + + " in $getter"); }""" } else { "" } - }.mkString("\n") + }.filter(_.length > 0).mkString("\n") - val specificMutatorFunctions = ctx.nativeTypes.map { dataType => - val cases = expressions.zipWithIndex.map { - case (e, i) if e.dataType == dataType => - s"case $i: { c$i = value; return; }" - case _ => "" - }.mkString("\n") - if (cases.count(_ != '\n') > 0) { + val specificMutatorFunctions = ctx.primitiveTypes.map { jt => + val cases = expressions.zipWithIndex.flatMap { + case (e, i) if ctx.javaType(e.dataType) == jt => + Some(s"case $i: { c$i = value; return; }") + case _ => None + }.mkString("\n ") + if (cases.length > 0) { + val setter = "set" + ctx.primitiveTypeName(jt) s""" @Override - public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.javaType(dataType)} value) { + public void $setter(int i, $jt value) { nullBits[i] = false; switch (i) { $cases } + throw new IllegalArgumentException("Invalid index: " + i + + " in $setter}"); }""" } else { "" } - }.mkString("\n") + }.filter(_.length > 0).mkString("\n") val hashValues = expressions.zipWithIndex.map { case (e, i) => - val col = newTermName(s"c$i") + val col = s"c$i" val nonNull = e.dataType match { case BooleanType => s"$col ? 0 : 1" case ByteType | ShortType | IntegerType | DateType => s"$col" - case LongType => s"$col ^ ($col >>> 32)" + case LongType | TimestampType => s"$col ^ ($col >>> 32)" case FloatType => s"Float.floatToIntBits($col)" case DoubleType => s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))" + case BinaryType => s"java.util.Arrays.hashCode($col)" case _ => s"$col.hashCode()" } s"isNullAt($i) ? 0 : ($nonNull)" @@ -135,15 +141,18 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val columnChecks = expressions.zipWithIndex.map { case (e, i) => s""" - if (isNullAt($i) != row.isNullAt($i) || !isNullAt($i) && !get($i).equals(row.get($i))) { - return false; - } + if (nullBits[$i] != row.nullBits[$i] || + (!nullBits[$i] && !(${ctx.genEqual(e.dataType, s"c$i", s"row.c$i")}))) { + return false; + } """ }.mkString("\n") - val code = s""" - import org.apache.spark.sql.Row; + val copyColumns = expressions.zipWithIndex.map { case (e, i) => + s"""arr[$i] = c$i;""" + }.mkString("\n ") + val code = s""" public SpecificProjection generate($exprType[] expr) { return new SpecificProjection(expr); } @@ -157,20 +166,20 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { @Override public Object apply(Object r) { - return new SpecificRow(expressions, (Row) r); + return new SpecificRow(expressions, (InternalRow) r); } } - final class SpecificRow extends ${typeOf[BaseMutableRow]} { + final class SpecificRow extends ${typeOf[MutableRow]} { $columns - public SpecificRow($exprType[] expressions, Row i) { + public SpecificRow($exprType[] expressions, InternalRow i) { $initColumns } - public int size() { return ${expressions.length};} - private boolean[] nullBits = new boolean[${expressions.length}]; + public int length() { return ${expressions.length};} + protected boolean[] nullBits = new boolean[${expressions.length}]; public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } @@ -203,22 +212,25 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { @Override public boolean equals(Object other) { - if (other instanceof Row) { - Row row = (Row) other; - if (row.length() != size()) return false; + if (other instanceof SpecificRow) { + SpecificRow row = (SpecificRow) other; $columnChecks return true; } return super.equals(other); } + + @Override + public InternalRow copy() { + Object[] arr = new Object[${expressions.length}]; + ${copyColumns} + return new ${typeOf[GenericInternalRow]}(arr); + } } """ logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${code}") - val c = compile(code) - // fetch the only one method `generate(Expression[])` - val m = c.getDeclaredMethods()(0) - m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[Projection] + compile(code).generate(ctx.references.toArray).asInstanceOf[Projection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala similarity index 66% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 6398b8f9e4ed7..5def57b067424 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ - /** * Returns an Array containing the evaluation of all children expressions. */ @@ -27,21 +28,18 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) - lazy val childTypes = children.map(_.dataType).distinct - - override lazy val resolved = - childrenResolved && childTypes.size <= 1 + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array") override def dataType: DataType = { - assert(resolved, s"Invalid dataType of mixed ArrayType ${childTypes.mkString(",")}") ArrayType( - childTypes.headOption.getOrElse(NullType), + children.headOption.map(_.dataType).getOrElse(NullType), containsNull = children.exists(_.nullable)) } override def nullable: Boolean = false - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { children.map(_.eval(input)) } @@ -52,24 +50,25 @@ case class CreateArray(children: Seq[Expression]) extends Expression { * Returns a Row containing the evaluation of all children expressions. * TODO: [[CreateStruct]] does not support codegen. */ -case class CreateStruct(children: Seq[NamedExpression]) extends Expression { +case class CreateStruct(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) - override lazy val resolved: Boolean = childrenResolved - override lazy val dataType: StructType = { - assert(resolved, - s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.") - val fields = children.map { child => - StructField(child.name, child.dataType, child.nullable, child.metadata) + val fields = children.zipWithIndex.map { case (child, idx) => + child match { + case ne: NamedExpression => + StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) + case _ => + StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) + } } StructType(fields) } override def nullable: Boolean = false - override def eval(input: Row): Any = { - Row(children.map(_.eval(input)): _*) + override def eval(input: InternalRow): Any = { + InternalRow(children.map(_.eval(input)): _*) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala index 1a5cde26c9b13..1d7393d3d91f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types.{BooleanType, DataType} @@ -42,7 +43,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def dataType: DataType = trueValue.dataType - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { if (true == predicate.eval(input)) { trueValue.eval(input) } else { @@ -137,7 +138,7 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { } /** Written in imperative fashion for performance considerations. */ - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val len = branchesArr.length var i = 0 // If all branches fail and an elseVal is not provided, the whole statement @@ -229,7 +230,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW } /** Written in imperative fashion for performance considerations. */ - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val evaluatedKey = key.eval(input) val len = branchesArr.length var i = 0 @@ -261,7 +262,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW ${cond.code} if (${keyEval.isNull} && ${cond.isNull} || !${keyEval.isNull} && !${cond.isNull} - && ${ctx.equalFunc(key.dataType)(keyEval.primitive, cond.primitive)}) { + && ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) { $got = true; ${res.code} ${ev.isNull} = ${res.isNull}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 8ab6d977dd3a6..f5c2dde191cf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -17,18 +17,20 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.types._ -/** Return the unscaled Long value of a Decimal, assuming it fits in a Long */ +/** + * Return the unscaled Long value of a Decimal, assuming it fits in a Long. + * Note: this expression is internal and created only by the optimizer, + * we don't need to do type check for it. + */ case class UnscaledValue(child: Expression) extends UnaryExpression { override def dataType: DataType = LongType - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def toString: String = s"UnscaledValue($child)" - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val childResult = child.eval(input) if (childResult == null) { null @@ -42,15 +44,17 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { } } -/** Create a Decimal from an unscaled Long value */ +/** + * Create a Decimal from an unscaled Long value. + * Note: this expression is internal and created only by the optimizer, + * we don't need to do type check for it. + */ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression { override def dataType: DataType = DecimalType(precision, scale) - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def toString: String = s"MakeDecimal($child,$precision,$scale)" - override def eval(input: Row): Decimal = { + override def eval(input: InternalRow): Decimal = { val childResult = child.eval(input) if (childResult == null) { null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index b6191eafba71b..7a42a1d310581 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} import org.apache.spark.sql.types._ @@ -53,13 +55,13 @@ abstract class Generator extends Expression { def elementTypes: Seq[(DataType, Boolean)] /** Should be implemented by child classes to perform specific Generators. */ - override def eval(input: Row): TraversableOnce[Row] + override def eval(input: InternalRow): TraversableOnce[InternalRow] /** * Notifies that there are no more rows to process, clean up code, and additional * rows can be made here. */ - def terminate(): TraversableOnce[Row] = Nil + def terminate(): TraversableOnce[InternalRow] = Nil } /** @@ -67,22 +69,22 @@ abstract class Generator extends Expression { */ case class UserDefinedGenerator( elementTypes: Seq[(DataType, Boolean)], - function: Row => TraversableOnce[Row], + function: Row => TraversableOnce[InternalRow], children: Seq[Expression]) extends Generator { @transient private[this] var inputRow: InterpretedProjection = _ - @transient private[this] var convertToScala: (Row) => Row = _ + @transient private[this] var convertToScala: (InternalRow) => Row = _ private def initializeConverters(): Unit = { inputRow = new InterpretedProjection(children) convertToScala = { val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) CatalystTypeConverters.createToScalaConverter(inputSchema) - }.asInstanceOf[(Row => Row)] + }.asInstanceOf[InternalRow => Row] } - override def eval(input: Row): TraversableOnce[Row] = { + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { if (inputRow == null) { initializeConverters() } @@ -99,23 +101,29 @@ case class UserDefinedGenerator( case class Explode(child: Expression) extends Generator with trees.UnaryNode[Expression] { - override lazy val resolved = - child.resolved && - (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) + override def checkInputDataTypes(): TypeCheckResult = { + if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"input to function explode should be array or map type, not ${child.dataType}") + } + } override def elementTypes: Seq[(DataType, Boolean)] = child.dataType match { case ArrayType(et, containsNull) => (et, containsNull) :: Nil case MapType(kt, vt, valueContainsNull) => (kt, false) :: (vt, valueContainsNull) :: Nil } - override def eval(input: Row): TraversableOnce[Row] = { + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { child.dataType match { case ArrayType(_, _) => val inputArray = child.eval(input).asInstanceOf[Seq[Any]] - if (inputArray == null) Nil else inputArray.map(v => new GenericRow(Array(v))) + if (inputArray == null) Nil else inputArray.map(v => InternalRow(v)) case MapType(_, _, _) => val inputMap = child.eval(input).asInstanceOf[Map[Any, Any]] - if (inputMap == null) Nil else inputMap.map { case (k, v) => new GenericRow(Array(k, v)) } + if (inputMap == null) Nil + else inputMap.map { case (k, v) => InternalRow(k, v) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 297b35b4da94c..479224af5627a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String object Literal { def apply(v: Any): Literal = v match { @@ -32,13 +34,13 @@ object Literal { case f: Float => Literal(f, FloatType) case b: Byte => Literal(b, ByteType) case s: Short => Literal(s, ShortType) - case s: String => Literal(UTF8String(s), StringType) + case s: String => Literal(UTF8String.fromString(s), StringType) case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) case d: Decimal => Literal(d, DecimalType.Unlimited) - case t: Timestamp => Literal(t, TimestampType) - case d: Date => Literal(DateUtils.fromJavaDate(d), DateType) + case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType) + case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType) case a: Array[Byte] => Literal(a, BinaryType) case null => Literal(null, NullType) case _ => @@ -86,21 +88,20 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres case _ => false } - override def eval(input: Row): Any = value + override def eval(input: InternalRow): Any = value override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { // change the isNull and primitive to consts, to inline them if (value == null) { ev.isNull = "true" - ev.primitive = ctx.defaultValue(dataType) - "" + s"final ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};" } else { dataType match { case BooleanType => ev.isNull = "false" ev.primitive = value.toString "" - case FloatType => // This must go before NumericType + case FloatType => val v = value.asInstanceOf[Float] if (v.isNaN || v.isInfinite) { super.genCode(ctx, ev) @@ -109,7 +110,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres ev.primitive = s"${value}f" "" } - case DoubleType => // This must go before NumericType + case DoubleType => val v = value.asInstanceOf[Double] if (v.isNaN || v.isInfinite) { super.genCode(ctx, ev) @@ -118,15 +119,18 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres ev.primitive = s"${value}" "" } - - case ByteType | ShortType => // This must go before NumericType + case ByteType | ShortType => ev.isNull = "false" ev.primitive = s"(${ctx.javaType(dataType)})$value" "" - case dt: NumericType if !dt.isInstanceOf[DecimalType] => + case IntegerType | DateType => ev.isNull = "false" ev.primitive = value.toString "" + case TimestampType | LongType => + ev.isNull = "false" + ev.primitive = s"${value}L" + "" // eval() version may be faster for non-primitive types case other => super.genCode(ctx, ev) @@ -139,9 +143,9 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true) extends LeafExpression { - def update(expression: Expression, input: Row): Unit = { + def update(expression: Expression, input: InternalRow): Unit = { value = expression.eval(input) } - override def eval(input: Row): Any = value + override def eval(input: InternalRow): Any = value } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 7dacb6a9b47b6..a022f3727bd58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -17,25 +17,54 @@ package org.apache.spark.sql.catalyst.expressions +import java.lang.{Long => JLong} +import java.util.Arrays + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.types.{DataType, DoubleType} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * A leaf expression specifically for math constants. Math constants expect no input. + * @param c The math constant. + * @param name The short name of the function + */ +abstract class LeafMathExpression(c: Double, name: String) + extends LeafExpression with Serializable { + self: Product => + + override def dataType: DataType = DoubleType + override def foldable: Boolean = true + override def nullable: Boolean = false + override def toString: String = s"$name()" + + override def eval(input: InternalRow): Any = c + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + s""" + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.primitive} = java.lang.Math.$name; + """ + } +} /** * A unary expression specifically for math functions. Math Functions expect a specific type of * input format, therefore these functions extend `ExpectsInputTypes`. + * @param f The math function. * @param name The short name of the function */ abstract class UnaryMathExpression(f: Double => Double, name: String) - extends UnaryExpression with Serializable with ExpectsInputTypes { + extends UnaryExpression with Serializable with AutoCastInputTypes { self: Product => override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) override def dataType: DataType = DoubleType - override def foldable: Boolean = child.foldable override def nullable: Boolean = true override def toString: String = s"$name($child)" - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val evalE = child.eval(input) if (evalE == null) { null @@ -70,7 +99,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) * @param name The short name of the function */ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) - extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => + extends BinaryExpression with Serializable with AutoCastInputTypes { self: Product => override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) @@ -78,7 +107,7 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) override def dataType: DataType = DoubleType - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val evalE1 = left.eval(input) if (evalE1 == null) { null @@ -98,6 +127,16 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) } } +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Leaf math functions +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +case class EulerNumber() extends LeafMathExpression(math.E, "E") + +case class Pi() extends LeafMathExpression(math.Pi, "PI") + //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// // Unary math functions @@ -126,6 +165,23 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG") +case class Log2(child: Expression) + extends UnaryMathExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = java.lang.Math.log(${eval.primitive}) / java.lang.Math.log(2); + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + } + """ + } +} + case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG10") case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P") @@ -140,6 +196,8 @@ case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH") +case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT") + case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH") @@ -152,6 +210,102 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia override def funcName: String = "toRadians" } +case class Bin(child: Expression) + extends UnaryExpression with Serializable with AutoCastInputTypes { + + override def expectedChildTypes: Seq[DataType] = Seq(LongType) + override def dataType: DataType = StringType + + override def eval(input: InternalRow): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + UTF8String.fromString(JLong.toBinaryString(evalE.asInstanceOf[Long])) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (c) => + s"${ctx.stringType}.fromString(java.lang.Long.toBinaryString($c))") + } +} + + +/** + * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format. + * Otherwise if the number is a STRING, it converts each character into its hex representation + * and returns the resulting STRING. Negative numbers would be treated as two's complement. + */ +case class Hex(child: Expression) extends UnaryExpression with Serializable { + + override def dataType: DataType = StringType + + override def checkInputDataTypes(): TypeCheckResult = { + if (child.dataType.isInstanceOf[StringType] + || child.dataType.isInstanceOf[IntegerType] + || child.dataType.isInstanceOf[LongType] + || child.dataType.isInstanceOf[BinaryType] + || child.dataType == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"hex doesn't accepts ${child.dataType} type") + } + } + + override def eval(input: InternalRow): Any = { + val num = child.eval(input) + if (num == null) { + null + } else { + child.dataType match { + case LongType => hex(num.asInstanceOf[Long]) + case IntegerType => hex(num.asInstanceOf[Integer].toLong) + case BinaryType => hex(num.asInstanceOf[Array[Byte]]) + case StringType => hex(num.asInstanceOf[UTF8String]) + } + } + } + + /** + * Converts every character in s to two hex digits. + */ + private def hex(str: UTF8String): UTF8String = { + hex(str.getBytes) + } + + private def hex(bytes: Array[Byte]): UTF8String = { + doHex(bytes, bytes.length) + } + + private def doHex(bytes: Array[Byte], length: Int): UTF8String = { + val value = new Array[Byte](length * 2) + var i = 0 + while (i < length) { + value(i * 2) = Character.toUpperCase(Character.forDigit( + (bytes(i) & 0xF0) >>> 4, 16)).toByte + value(i * 2 + 1) = Character.toUpperCase(Character.forDigit( + bytes(i) & 0x0F, 16)).toByte + i += 1 + } + UTF8String.fromBytes(value) + } + + private def hex(num: Long): UTF8String = { + // Extract the hex digits of num into value[] from right to left + val value = new Array[Byte](16) + var numBuf = num + var len = 0 + do { + len += 1 + value(value.length - len) = Character.toUpperCase(Character + .forDigit((numBuf & 0xF).toInt, 16)).toByte + numBuf >>>= 4 + } while (numBuf != 0) + UTF8String.fromBytes(Arrays.copyOfRange(value, value.length - len, value.length)) + } +} + //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -163,7 +317,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia case class Atan2(left: Expression, right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val evalE1 = left.eval(input) if (evalE1 == null) { null @@ -189,9 +343,6 @@ case class Atan2(left: Expression, right: Expression) } } -case class Hypot(left: Expression, right: Expression) - extends BinaryMathExpression(math.hypot, "HYPOT") - case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { @@ -202,3 +353,30 @@ case class Pow(left: Expression, right: Expression) """ } } + +case class Hypot(left: Expression, right: Expression) + extends BinaryMathExpression(math.hypot, "HYPOT") + +case class Logarithm(left: Expression, right: Expression) + extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") { + + /** + * Natural log, i.e. using e as the base. + */ + def this(child: Expression) = { + this(EulerNumber(), child) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val logCode = if (left.isInstanceOf[EulerNumber]) { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2)") + } else { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2) / java.lang.Math.log($c1)") + } + logCode + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala new file mode 100644 index 0000000000000..27805bff293f4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.security.MessageDigest +import java.security.NoSuchAlgorithmException + +import org.apache.commons.codec.digest.DigestUtils +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * A function that calculates an MD5 128-bit checksum and returns it as a hex string + * For input of type [[BinaryType]] + */ +case class Md5(child: Expression) + extends UnaryExpression with AutoCastInputTypes { + + override def dataType: DataType = StringType + + override def expectedChildTypes: Seq[DataType] = Seq(BinaryType) + + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + null + } else { + UTF8String.fromString(DigestUtils.md5Hex(value.asInstanceOf[Array[Byte]])) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => + s"${ctx.stringType}.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))") + } +} + +/** + * A function that calculates the SHA-2 family of functions (SHA-224, SHA-256, SHA-384, and SHA-512) + * and returns it as a hex string. The first argument is the string or binary to be hashed. The + * second argument indicates the desired bit length of the result, which must have a value of 224, + * 256, 384, 512, or 0 (which is equivalent to 256). SHA-224 is supported starting from Java 8. If + * asking for an unsupported SHA function, the return value is NULL. If either argument is NULL or + * the hash length is not one of the permitted values, the return value is NULL. + */ +case class Sha2(left: Expression, right: Expression) + extends BinaryExpression with Serializable with AutoCastInputTypes { + + override def dataType: DataType = StringType + + override def toString: String = s"SHA2($left, $right)" + + override def expectedChildTypes: Seq[DataType] = Seq(BinaryType, IntegerType) + + override def eval(input: InternalRow): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + val bitLength = evalE2.asInstanceOf[Int] + val input = evalE1.asInstanceOf[Array[Byte]] + bitLength match { + case 224 => + // DigestUtils doesn't support SHA-224 now + try { + val md = MessageDigest.getInstance("SHA-224") + md.update(input) + UTF8String.fromBytes(md.digest()) + } catch { + // SHA-224 is not supported on the system, return null + case noa: NoSuchAlgorithmException => null + } + case 256 | 0 => + UTF8String.fromString(DigestUtils.sha256Hex(input)) + case 384 => + UTF8String.fromString(DigestUtils.sha384Hex(input)) + case 512 => + UTF8String.fromString(DigestUtils.sha512Hex(input)) + case _ => null + } + } + } + } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val digestUtils = "org.apache.commons.codec.digest.DigestUtils" + + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval2.code} + if (!${eval2.isNull}) { + if (${eval2.primitive} == 224) { + try { + java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224"); + md.update(${eval1.primitive}); + ${ev.primitive} = ${ctx.stringType}.fromBytes(md.digest()); + } catch (java.security.NoSuchAlgorithmException e) { + ${ev.isNull} = true; + } + } else if (${eval2.primitive} == 256 || ${eval2.primitive} == 0) { + ${ev.primitive} = + ${ctx.stringType}.fromString(${digestUtils}.sha256Hex(${eval1.primitive})); + } else if (${eval2.primitive} == 384) { + ${ev.primitive} = + ${ctx.stringType}.fromString(${digestUtils}.sha384Hex(${eval1.primitive})); + } else if (${eval2.primitive} == 512) { + ${ev.primitive} = + ${ctx.stringType}.fromString(${digestUtils}.sha512Hex(${eval1.primitive})); + } else { + ${ev.isNull} = true; + } + } else { + ${ev.isNull} = true; + } + } + """ + } +} + +/** + * A function that calculates a sha1 hash value and returns it as a hex string + * For input of type [[BinaryType]] or [[StringType]] + */ +case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTypes { + + override def dataType: DataType = StringType + + override def expectedChildTypes: Seq[DataType] = Seq(BinaryType) + + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + null + } else { + UTF8String.fromString(DigestUtils.shaHex(value.asInstanceOf[Array[Byte]])) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => + "org.apache.spark.unsafe.types.UTF8String.fromString" + + s"(org.apache.commons.codec.digest.DigestUtils.shaHex($c))" + ) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 2e4b9ba678433..6f56a9ec7beb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -112,9 +112,12 @@ case class Alias(child: Expression, name: String)( extends NamedExpression with trees.UnaryNode[Expression] { // Alias(Generator, xx) need to be transformed into Generate(generator, ...) - override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator] + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !child.isInstanceOf[Generator] - override def eval(input: Row): Any = child.eval(input) + override def eval(input: InternalRow): Any = child.eval(input) + + override def isThreadSafe: Boolean = child.isThreadSafe override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) @@ -230,7 +233,7 @@ case class AttributeReference( } // Unresolved attributes are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): Any = + override def eval(input: InternalRow = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"$name#${exprId.id}$typeSuffix" @@ -252,12 +255,12 @@ case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[E override def withName(newName: String): Attribute = throw new UnsupportedOperationException override def qualifiers: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException - override def eval(input: Row): Any = throw new UnsupportedOperationException + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException override def nullable: Boolean = throw new UnsupportedOperationException override def dataType: DataType = NullType } object VirtualColumn { val groupingIdName: String = "grouping__id" - def newGroupingId: AttributeReference = AttributeReference(groupingIdName, IntegerType, false)() + val groupingIdAttribute: UnresolvedAttribute = UnresolvedAttribute(groupingIdName) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index c2d1a4eadae29..5d5911403ece1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -17,34 +17,32 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} -import org.apache.spark.sql.catalyst.trees -import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types.DataType case class Coalesce(children: Seq[Expression]) extends Expression { /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ - override def nullable: Boolean = !children.exists(!_.nullable) + override def nullable: Boolean = children.forall(_.nullable) // Coalesce is foldable if all children are foldable. - override def foldable: Boolean = !children.exists(!_.foldable) + override def foldable: Boolean = children.forall(_.foldable) - // Only resolved if all the children are of the same type. - override lazy val resolved = childrenResolved && (children.map(_.dataType).distinct.size == 1) + override def checkInputDataTypes(): TypeCheckResult = { + if (children == Nil) { + TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty") + } else { + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function coalesce") + } + } override def toString: String = s"Coalesce(${children.mkString(",")})" - override def dataType: DataType = if (resolved) { - children.head.dataType - } else { - val childTypes = children.map(c => s"$c: ${c.dataType}").mkString(", ") - throw new UnresolvedException( - this, s"Coalesce cannot have children of different types. $childTypes") - } + override def dataType: DataType = children.head.dataType - override def eval(input: Row): Any = { - var i = 0 + override def eval(input: InternalRow): Any = { var result: Any = null val childIterator = children.iterator while (childIterator.hasNext && result == null) { @@ -53,6 +51,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { result } + override def isThreadSafe: Boolean = children.forall(_.isThreadSafe) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { s""" boolean ${ev.isNull} = true; @@ -73,11 +73,10 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } } -case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { - override def foldable: Boolean = child.foldable +case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { child.eval(input) == null } @@ -91,12 +90,11 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr override def toString: String = s"IS NULL $child" } -case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { - override def foldable: Boolean = child.foldable +case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false override def toString: String = s"IS NOT NULL $child" - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { child.eval(input) != null } @@ -118,7 +116,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate private[this] val childrenArray = children.toArray - override def eval(input: Row): Boolean = { + override def eval(input: InternalRow): Boolean = { var numNonNulls = 0 var i = 0 while (i < childrenArray.length && numNonNulls < n) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index fbc97b2e75312..d24d74e7b82ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -49,30 +49,30 @@ package org.apache.spark.sql.catalyst */ package object expressions { - type Row = org.apache.spark.sql.Row + type InternalRow = org.apache.spark.sql.catalyst.InternalRow - val Row = org.apache.spark.sql.Row + val InternalRow = org.apache.spark.sql.catalyst.InternalRow /** - * Converts a [[Row]] to another Row given a sequence of expression that define each column of the - * new row. If the schema of the input row is specified, then the given expression will be bound - * to that schema. + * Converts a [[InternalRow]] to another Row given a sequence of expression that define each + * column of the new row. If the schema of the input row is specified, then the given expression + * will be bound to that schema. */ - abstract class Projection extends (Row => Row) + abstract class Projection extends (InternalRow => InternalRow) /** - * Converts a [[Row]] to another Row given a sequence of expression that define each column of the - * new row. If the schema of the input row is specified, then the given expression will be bound - * to that schema. + * Converts a [[InternalRow]] to another Row given a sequence of expression that define each + * column of the new row. If the schema of the input row is specified, then the given expression + * will be bound to that schema. * * In contrast to a normal projection, a MutableProjection reuses the same underlying row object * each time an input row is added. This significantly reduces the cost of calculating the - * projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()` - * has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()` - * and hold on to the returned [[Row]] before calling `next()`. + * projection, but means that it is not safe to hold on to a reference to a [[InternalRow]] after + * `next()` has been called on the [[Iterator]] that produced it. Instead, the user must call + * `InternalRow.copy()` and hold on to the returned [[InternalRow]] before calling `next()`. */ abstract class MutableProjection extends Projection { - def currentValue: Row + def currentValue: InternalRow /** Uses the given row to store the output of the projection. */ def target(row: MutableRow): MutableProjection diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 3cbdfdfb13847..386cf6a8df6df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -24,11 +24,11 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ object InterpretedPredicate { - def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = + def create(expression: Expression, inputSchema: Seq[Attribute]): (InternalRow => Boolean) = create(BindReferences.bindReference(expression, inputSchema)) - def create(expression: Expression): (Row => Boolean) = { - (r: Row) => expression.eval(r).asInstanceOf[Boolean] + def create(expression: Expression): (InternalRow => Boolean) = { + (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean] } } @@ -70,14 +70,14 @@ trait PredicateHelper { } -case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes { +case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes { override def foldable: Boolean = child.foldable override def nullable: Boolean = child.nullable override def toString: String = s"NOT $child" override def expectedChildTypes: Seq[DataType] = Seq(BooleanType) - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { child.eval(input) match { case null => null case b: Boolean => !b @@ -98,7 +98,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}" - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val evaluatedValue = value.eval(input) list.exists(e => e.eval(input) == evaluatedValue) } @@ -117,19 +117,19 @@ case class InSet(value: Expression, hset: Set[Any]) override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. override def toString: String = s"$value INSET ${hset.mkString("(", ",", ")")}" - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { hset.contains(value.eval(input)) } } case class And(left: Expression, right: Expression) - extends BinaryExpression with Predicate with ExpectsInputTypes { + extends BinaryExpression with Predicate with AutoCastInputTypes { override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType) override def symbol: String = "&&" - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val l = left.eval(input) if (l == false) { false @@ -172,13 +172,13 @@ case class And(left: Expression, right: Expression) } case class Or(left: Expression, right: Expression) - extends BinaryExpression with Predicate with ExpectsInputTypes { + extends BinaryExpression with Predicate with AutoCastInputTypes { override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType) override def symbol: String = "||" - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val l = left.eval(input) if (l == true) { true @@ -235,7 +235,7 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { protected def checkTypesInternal(t: DataType): TypeCheckResult - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val evalE1 = left.eval(input) if (evalE1 == null) { null @@ -250,16 +250,11 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - left.dataType match { - case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, { - (c1, c3) => s"$c1 $symbol $c3" - }) - case TimestampType => - // java.sql.Timestamp does not have compare() - super.genCode(ctx, ev) - case other => defineCodeGen (ctx, ev, { - (c1, c2) => s"$c1.compare($c2) $symbol 0" - }) + if (ctx.isPrimitiveType(left.dataType)) { + // faster version + defineCodeGen(ctx, ev, (c1, c2) => s"$c1 $symbol $c2") + } else { + defineCodeGen(ctx, ev, (c1, c2) => s"${ctx.genComp(left.dataType, c1, c2)} $symbol 0") } } @@ -271,6 +266,15 @@ private[sql] object BinaryComparison { def unapply(e: BinaryComparison): Option[(Expression, Expression)] = Some((e.left, e.right)) } +/** An extractor that matches both standard 3VL equality and null-safe equality. */ +private[sql] object Equality { + def unapply(e: BinaryComparison): Option[(Expression, Expression)] = e match { + case EqualTo(l, r) => Some((l, r)) + case EqualNullSafe(l, r) => Some((l, r)) + case _ => None + } +} + case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = "=" @@ -280,8 +284,9 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison if (left.dataType != BinaryType) l == r else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, ctx.equalFunc(left.dataType)) + defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2)) } } @@ -292,7 +297,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val l = left.eval(input) val r = right.eval(input) if (l == null && r == null) { @@ -307,7 +312,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val equalCode = ctx.equalFunc(left.dataType)(eval1.primitive, eval2.primitive) + val equalCode = ctx.genEqual(left.dataType, eval1.primitive, eval2.primitive) ev.isNull = "false" eval1.code + eval2.code + s""" boolean ${ev.primitive} = (${eval1.isNull} && ${eval2.isNull}) || diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index b2647124c4e49..45588bacd2e45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.types.{DataType, DoubleType} import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -36,7 +37,11 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable { * Record ID within each partition. By being transient, the Random Number Generator is * reset every time we serialize and deserialize it. */ - @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.get().partitionId()) + @transient protected lazy val partitionId = TaskContext.get() match { + case null => 0 + case _ => TaskContext.get().partitionId() + } + @transient protected lazy val rng = new XORShiftRandom(seed + partitionId) override def deterministic: Boolean = false @@ -46,11 +51,25 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable { } /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */ -case class Rand(seed: Long = Utils.random.nextLong()) extends RDG(seed) { - override def eval(input: Row): Double = rng.nextDouble() +case class Rand(seed: Long) extends RDG(seed) { + override def eval(input: InternalRow): Double = rng.nextDouble() + + def this() = this(Utils.random.nextLong()) + + def this(seed: Expression) = this(seed match { + case IntegerLiteral(s) => s + case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") + }) } /** Generate a random column with i.i.d. gaussian random distribution. */ -case class Randn(seed: Long = Utils.random.nextLong()) extends RDG(seed) { - override def eval(input: Row): Double = rng.nextGaussian() +case class Randn(seed: Long) extends RDG(seed) { + override def eval(input: InternalRow): Double = rng.nextGaussian() + + def this() = this(Utils.random.nextLong()) + + def this(seed: Expression) = this(seed match { + case IntegerLiteral(s) => s + case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") + }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 5fd892c42e69c..dd5f2ed2d382e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -17,32 +17,46 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.types.{UTF8String, DataType, StructType, AtomicType} +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{DataType, StructType, AtomicType} +import org.apache.spark.unsafe.types.UTF8String /** - * An extended interface to [[Row]] that allows the values for each column to be updated. Setting - * a value through a primitive function implicitly marks that column as not null. + * An extended interface to [[InternalRow]] that allows the values for each column to be updated. + * Setting a value through a primitive function implicitly marks that column as not null. */ -trait MutableRow extends Row { +abstract class MutableRow extends InternalRow { def setNullAt(i: Int): Unit - def update(ordinal: Int, value: Any) - - def setInt(ordinal: Int, value: Int) - def setLong(ordinal: Int, value: Long) - def setDouble(ordinal: Int, value: Double) - def setBoolean(ordinal: Int, value: Boolean) - def setShort(ordinal: Int, value: Short) - def setByte(ordinal: Int, value: Byte) - def setFloat(ordinal: Int, value: Float) - def setString(ordinal: Int, value: String) - // TODO(davies): add setDate() and setDecimal() + def update(i: Int, value: Any) + + // default implementation (slow) + def setInt(i: Int, value: Int): Unit = { update(i, value) } + def setLong(i: Int, value: Long): Unit = { update(i, value) } + def setDouble(i: Int, value: Double): Unit = { update(i, value) } + def setBoolean(i: Int, value: Boolean): Unit = { update(i, value) } + def setShort(i: Int, value: Short): Unit = { update(i, value) } + def setByte(i: Int, value: Byte): Unit = { update(i, value) } + def setFloat(i: Int, value: Float): Unit = { update(i, value) } + def setString(i: Int, value: String): Unit = { + update(i, UTF8String.fromString(value)) + } + + override def copy(): InternalRow = { + val arr = new Array[Any](length) + var i = 0 + while (i < length) { + arr(i) = get(i) + i += 1 + } + new GenericInternalRow(arr) + } } /** * A row with no data. Calling any methods will result in an error. Can be used as a placeholder. */ -object EmptyRow extends Row { +object EmptyRow extends InternalRow { override def apply(i: Int): Any = throw new UnsupportedOperationException override def toSeq: Seq[Any] = Seq.empty override def length: Int = 0 @@ -56,120 +70,57 @@ object EmptyRow extends Row { override def getByte(i: Int): Byte = throw new UnsupportedOperationException override def getString(i: Int): String = throw new UnsupportedOperationException override def getAs[T](i: Int): T = throw new UnsupportedOperationException - override def copy(): Row = this + override def copy(): InternalRow = this } /** - * A row implementation that uses an array of objects as the underlying storage. Note that, while - * the array is not copied, and thus could technically be mutated after creation, this is not - * allowed. + * A row implementation that uses an array of objects as the underlying storage. */ -class GenericRow(protected[sql] val values: Array[Any]) extends Row { - /** No-arg constructor for serialization. */ - protected def this() = this(null) +trait ArrayBackedRow { + self: Row => - def this(size: Int) = this(new Array[Any](size)) + protected val values: Array[Any] override def toSeq: Seq[Any] = values.toSeq - override def length: Int = values.length + def length: Int = values.length override def apply(i: Int): Any = values(i) - override def isNullAt(i: Int): Boolean = values(i) == null - - override def getInt(i: Int): Int = { - if (values(i) == null) sys.error("Failed to check null bit for primitive int value.") - values(i).asInstanceOf[Int] - } - - override def getLong(i: Int): Long = { - if (values(i) == null) sys.error("Failed to check null bit for primitive long value.") - values(i).asInstanceOf[Long] - } - - override def getDouble(i: Int): Double = { - if (values(i) == null) sys.error("Failed to check null bit for primitive double value.") - values(i).asInstanceOf[Double] - } - - override def getFloat(i: Int): Float = { - if (values(i) == null) sys.error("Failed to check null bit for primitive float value.") - values(i).asInstanceOf[Float] - } + def setNullAt(i: Int): Unit = { values(i) = null} - override def getBoolean(i: Int): Boolean = { - if (values(i) == null) sys.error("Failed to check null bit for primitive boolean value.") - values(i).asInstanceOf[Boolean] - } - - override def getShort(i: Int): Short = { - if (values(i) == null) sys.error("Failed to check null bit for primitive short value.") - values(i).asInstanceOf[Short] - } - - override def getByte(i: Int): Byte = { - if (values(i) == null) sys.error("Failed to check null bit for primitive byte value.") - values(i).asInstanceOf[Byte] - } - - override def getString(i: Int): String = { - values(i) match { - case null => null - case s: String => s - case utf8: UTF8String => utf8.toString - } - } - - // TODO(davies): add getDate and getDecimal + def update(i: Int, value: Any): Unit = { values(i) = value } +} - // Custom hashCode function that matches the efficient code generated version. - override def hashCode: Int = { - var result: Int = 37 +/** + * A row implementation that uses an array of objects as the underlying storage. Note that, while + * the array is not copied, and thus could technically be mutated after creation, this is not + * allowed. + */ +class GenericRow(protected[sql] val values: Array[Any]) extends Row with ArrayBackedRow { + /** No-arg constructor for serialization. */ + protected def this() = this(null) - var i = 0 - while (i < values.length) { - val update: Int = - if (isNullAt(i)) { - 0 - } else { - apply(i) match { - case b: Boolean => if (b) 0 else 1 - case b: Byte => b.toInt - case s: Short => s.toInt - case i: Int => i - case l: Long => (l ^ (l >>> 32)).toInt - case f: Float => java.lang.Float.floatToIntBits(f) - case d: Double => - val b = java.lang.Double.doubleToLongBits(d) - (b ^ (b >>> 32)).toInt - case other => other.hashCode() - } - } - result = 37 * result + update - i += 1 - } - result - } + def this(size: Int) = this(new Array[Any](size)) + // This is used by test or outside override def equals(o: Any): Boolean = o match { - case other: Row => - if (values.length != other.length) { - return false - } - + case other: Row if other.length == length => var i = 0 - while (i < values.length) { + while (i < length) { if (isNullAt(i) != other.isNullAt(i)) { return false } - if (apply(i) != other.apply(i)) { + val equal = (apply(i), other.apply(i)) match { + case (a: Array[Byte], b: Array[Byte]) => java.util.Arrays.equals(a, b) + case (a, b) => a == b + } + if (!equal) { return false } i += 1 } true - case _ => false } @@ -185,34 +136,35 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType) override def fieldIndex(name: String): Int = schema.fieldIndex(name) } -class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { +/** + * A internal row implementation that uses an array of objects as the underlying storage. + * Note that, while the array is not copied, and thus could technically be mutated after creation, + * this is not allowed. + */ +class GenericInternalRow(protected[sql] val values: Array[Any]) + extends InternalRow with ArrayBackedRow { /** No-arg constructor for serialization. */ protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) - override def setBoolean(ordinal: Int, value: Boolean): Unit = { values(ordinal) = value } - override def setByte(ordinal: Int, value: Byte): Unit = { values(ordinal) = value } - override def setDouble(ordinal: Int, value: Double): Unit = { values(ordinal) = value } - override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value } - override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value } - override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value } - override def setString(ordinal: Int, value: String) { values(ordinal) = UTF8String(value)} - override def setNullAt(i: Int): Unit = { values(i) = null } + override def copy(): InternalRow = this +} - override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value } +class GenericMutableRow(val values: Array[Any]) extends MutableRow with ArrayBackedRow { + /** No-arg constructor for serialization. */ + protected def this() = this(null) - override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value } + def this(size: Int) = this(new Array[Any](size)) - override def copy(): Row = new GenericRow(values.clone()) + override def copy(): InternalRow = new GenericInternalRow(values.clone()) } - -class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { +class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) = this(ordering.map(BindReferences.bindReference(_, inputSchema))) - def compare(a: Row, b: Row): Int = { + def compare(a: InternalRow, b: InternalRow): Int = { var i = 0 while (i < ordering.size) { val order = ordering(i) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 2bcb960e9177e..efc6f50b78943 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -57,7 +57,7 @@ case class NewSet(elementType: DataType) extends LeafExpression { override def dataType: OpenHashSetUDT = new OpenHashSetUDT(elementType) - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { new OpenHashSet[Any]() } @@ -78,6 +78,8 @@ case class NewSet(elementType: DataType) extends LeafExpression { /** * Adds an item to a set. * For performance, this expression mutates its input during evaluation. + * Note: this expression is internal and created only by the GeneratedAggregate, + * we don't need to do type check for it. */ case class AddItemToSet(item: Expression, set: Expression) extends Expression { @@ -85,9 +87,9 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { override def nullable: Boolean = set.nullable - override def dataType: OpenHashSetUDT = set.dataType.asInstanceOf[OpenHashSetUDT] + override def dataType: DataType = set.dataType - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val itemEval = item.eval(input) val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]] @@ -128,16 +130,18 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { /** * Combines the elements of two sets. * For performance, this expression mutates its left input set during evaluation. + * Note: this expression is internal and created only by the GeneratedAggregate, + * we don't need to do type check for it. */ case class CombineSets(left: Expression, right: Expression) extends BinaryExpression { override def nullable: Boolean = left.nullable || right.nullable - override def dataType: OpenHashSetUDT = left.dataType.asInstanceOf[OpenHashSetUDT] + override def dataType: DataType = left.dataType override def symbol: String = "++=" - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]] if(leftEval != null) { val rightEval = right.eval(input).asInstanceOf[OpenHashSet[Any]] @@ -176,6 +180,8 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres /** * Returns the number of elements in the input set. + * Note: this expression is internal and created only by the GeneratedAggregate, + * we don't need to do type check for it. */ case class CountSet(child: Expression) extends UnaryExpression { @@ -183,7 +189,7 @@ case class CountSet(child: Expression) extends UnaryExpression { override def dataType: DataType = LongType - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val childEval = child.eval(input).asInstanceOf[OpenHashSet[Any]] if (childEval != null) { childEval.size.toLong diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index aae122a981e47..ce184e4f32f18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -22,8 +22,9 @@ import java.util.regex.Pattern import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String -trait StringRegexExpression extends ExpectsInputTypes { +trait StringRegexExpression extends AutoCastInputTypes { self: BinaryExpression => def escape(v: String): String @@ -48,7 +49,7 @@ trait StringRegexExpression extends ExpectsInputTypes { protected def pattern(str: String) = if (cache == null) compile(str) else cache - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val l = left.eval(input) if (l == null) { null @@ -110,17 +111,15 @@ case class RLike(left: Expression, right: Expression) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) } -trait CaseConversionExpression extends ExpectsInputTypes { +trait CaseConversionExpression extends AutoCastInputTypes { self: UnaryExpression => def convert(v: UTF8String): UTF8String - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def dataType: DataType = StringType override def expectedChildTypes: Seq[DataType] = Seq(StringType) - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val evaluated = child.eval(input) if (evaluated == null) { null @@ -159,7 +158,7 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE } /** A base trait for functions that compare two strings, returning a boolean. */ -trait StringComparison extends ExpectsInputTypes { +trait StringComparison extends AutoCastInputTypes { self: BinaryExpression => def compare(l: UTF8String, r: UTF8String): Boolean @@ -168,7 +167,7 @@ trait StringComparison extends ExpectsInputTypes { override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType) - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val leftEval = left.eval(input) if(leftEval == null) { null @@ -222,11 +221,16 @@ case class EndsWith(left: Expression, right: Expression) * Defined for String and Binary types. */ case class Substring(str: Expression, pos: Expression, len: Expression) - extends Expression with ExpectsInputTypes { + extends Expression with AutoCastInputTypes { + + def this(str: Expression, pos: Expression) = { + this(str, pos, Literal(Integer.MAX_VALUE)) + } override def foldable: Boolean = str.foldable && pos.foldable && len.foldable override def nullable: Boolean = str.nullable || pos.nullable || len.nullable + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved") @@ -260,7 +264,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) (start, end) } - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val string = str.eval(input) val po = pos.eval(input) val ln = len.eval(input) @@ -276,7 +280,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) ba.slice(st, end) case s: UTF8String => val (st, end) = slicePos(start, length, () => s.length()) - s.slice(st, end) + s.substring(st, end) } } } @@ -287,3 +291,22 @@ case class Substring(str: Expression, pos: Expression, len: Expression) case _ => s"SUBSTR($str, $pos, $len)" } } + +/** + * A function that return the length of the given string expression. + */ +case class StringLength(child: Expression) extends UnaryExpression with AutoCastInputTypes { + override def dataType: DataType = IntegerType + override def expectedChildTypes: Seq[DataType] = Seq(StringType) + + override def eval(input: InternalRow): Any = { + val string = child.eval(input) + if (string == null) null else string.asInstanceOf[UTF8String].length + } + + override def toString: String = s"length($child)" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).length()") + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 82c4d462cc322..12023ad311dc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{NumericType, DataType} +import org.apache.spark.sql.types.{DataType, NumericType} /** * The trait of the Window Specification (specified in the OVER clause or WINDOW clause) for @@ -69,12 +68,13 @@ case class WindowSpecDefinition( override def children: Seq[Expression] = partitionSpec ++ orderSpec override lazy val resolved: Boolean = - childrenResolved && frameSpecification.isInstanceOf[SpecifiedWindowFrame] + childrenResolved && checkInputDataTypes().isSuccess && + frameSpecification.isInstanceOf[SpecifiedWindowFrame] override def toString: String = simpleString - override def eval(input: Row): Any = throw new UnsupportedOperationException + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException override def nullable: Boolean = true override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException @@ -259,7 +259,7 @@ trait WindowFunction extends Expression { def reset(): Unit - def prepareInputParameters(input: Row): AnyRef + def prepareInputParameters(input: InternalRow): AnyRef def update(input: AnyRef): Unit @@ -286,7 +286,7 @@ case class UnresolvedWindowFunction( throw new UnresolvedException(this, "init") override def reset(): Unit = throw new UnresolvedException(this, "reset") - override def prepareInputParameters(input: Row): AnyRef = + override def prepareInputParameters(input: InternalRow): AnyRef = throw new UnresolvedException(this, "prepareInputParameters") override def update(input: AnyRef): Unit = throw new UnresolvedException(this, "update") @@ -297,7 +297,7 @@ case class UnresolvedWindowFunction( override def get(index: Int): Any = throw new UnresolvedException(this, "get") // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): Any = + override def eval(input: InternalRow = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"'$name(${children.mkString(",")})" @@ -316,7 +316,7 @@ case class UnresolvedWindowExpression( override lazy val resolved = false // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): Any = + override def eval(input: InternalRow = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } @@ -327,7 +327,7 @@ case class WindowExpression( override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil - override def eval(input: Row): Any = + override def eval(input: InternalRow): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def dataType: DataType = windowFunction.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c16f08d389955..bfd24287c9645 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -38,21 +38,24 @@ object DefaultOptimizer extends Optimizer { EliminateSubQueries) :: Batch("Distinct", FixedPoint(100), ReplaceDistinctWithAggregate) :: - Batch("Operator Reordering", FixedPoint(100), - UnionPushdown, - CombineFilters, - PushPredicateThroughProject, + Batch("Operator Optimizations", FixedPoint(100), + // Operator push down + UnionPushDown, PushPredicateThroughJoin, + PushPredicateThroughProject, PushPredicateThroughGenerate, ColumnPruning, + // Operator combine ProjectCollapsing, - CombineLimits) :: - Batch("ConstantFolding", FixedPoint(100), + CombineFilters, + CombineLimits, + // Constant folding NullPropagation, OptimizeIn, ConstantFolding, LikeSimplification, BooleanSimplification, + RemovePositive, SimplifyFilters, SimplifyCasts, SimplifyCaseConversionExpressions) :: @@ -63,25 +66,25 @@ object DefaultOptimizer extends Optimizer { } /** - * Pushes operations to either side of a Union. - */ -object UnionPushdown extends Rule[LogicalPlan] { + * Pushes operations to either side of a Union. + */ +object UnionPushDown extends Rule[LogicalPlan] { /** - * Maps Attributes from the left side to the corresponding Attribute on the right side. - */ - def buildRewrites(union: Union): AttributeMap[Attribute] = { + * Maps Attributes from the left side to the corresponding Attribute on the right side. + */ + private def buildRewrites(union: Union): AttributeMap[Attribute] = { assert(union.left.output.size == union.right.output.size) AttributeMap(union.left.output.zip(union.right.output)) } /** - * Rewrites an expression so that it can be pushed to the right side of a Union operator. - * This method relies on the fact that the output attributes of a union are always equal - * to the left child's output. - */ - def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]): A = { + * Rewrites an expression so that it can be pushed to the right side of a Union operator. + * This method relies on the fact that the output attributes of a union are always equal + * to the left child's output. + */ + private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = { val result = e transform { case a: Attribute => rewrites(a) } @@ -108,7 +111,6 @@ object UnionPushdown extends Rule[LogicalPlan] { } } - /** * Attempts to eliminate the reading of unneeded columns from the query plan using the following * transformations: @@ -117,10 +119,13 @@ object UnionPushdown extends Rule[LogicalPlan] { * - Aggregate * - Project <- Join * - LeftSemiJoin - * - Performing alias substitution. */ object ColumnPruning extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case a @ Aggregate(_, _, e @ Expand(_, groupByExprs, _, child)) + if (child.outputSet -- AttributeSet(groupByExprs) -- a.references).nonEmpty => + a.copy(child = e.copy(child = prunedChild(child, AttributeSet(groupByExprs) ++ a.references))) + // Eliminate attributes that are not needed to calculate the specified aggregates. case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => a.copy(child = Project(a.references.toSeq, child)) @@ -155,10 +160,11 @@ object ColumnPruning extends Rule[LogicalPlan] { Join(left, prunedChild(right, allReferences), LeftSemi, condition) + // Push down project through limit, so that we may have chance to push it further. case Project(projectList, Limit(exp, child)) => Limit(exp, Project(projectList, child)) - // push down project if possible when the child is sort + // Push down project if possible when the child is sort case p @ Project(projectList, s @ Sort(_, _, grandChild)) if s.references.subsetOf(p.outputSet) => s.copy(child = Project(projectList, grandChild)) @@ -177,8 +183,8 @@ object ColumnPruning extends Rule[LogicalPlan] { } /** - * Combines two adjacent [[Project]] operators into one, merging the - * expressions into one single expression. + * Combines two adjacent [[Project]] operators into one and perform alias substitution, + * merging the expressions into one single expression. */ object ProjectCollapsing extends Rule[LogicalPlan] { @@ -218,10 +224,10 @@ object ProjectCollapsing extends Rule[LogicalPlan] { object LikeSimplification extends Rule[LogicalPlan] { // if guards below protect from escapes on trailing %. // Cases like "something\%" are not optimized, but this does not affect correctness. - val startsWith = "([^_%]+)%".r - val endsWith = "%([^_%]+)".r - val contains = "%([^_%]+)%".r - val equalTo = "([^_%]*)".r + private val startsWith = "([^_%]+)%".r + private val endsWith = "%([^_%]+)".r + private val contains = "%([^_%]+)%".r + private val equalTo = "([^_%]*)".r def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case Like(l, Literal(utf, StringType)) => @@ -493,7 +499,7 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] { grandChild)) } - def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]): Expression = { + private def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]) = { condition transform { case a: AttributeReference => sourceAliases.getOrElse(a, a) } @@ -633,6 +639,15 @@ object SimplifyCasts extends Rule[LogicalPlan] { } } +/** + * Removes [[UnaryPositive]] identify function + */ +object RemovePositive extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case UnaryPositive(child) => child + } +} + /** * Combines two adjacent [[Limit]] operators into one, merging the * expressions into one single expression. @@ -669,7 +684,7 @@ object DecimalAggregates extends Rule[LogicalPlan] { import Decimal.MAX_LONG_DIGITS /** Maximum number of decimal digits representable precisely in a Double */ - val MAX_DOUBLE_DIGITS = 15 + private val MAX_DOUBLE_DIGITS = 15 def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 1dd75a8846303..179a348d5baac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -143,11 +143,11 @@ object PartialAggregation { // We need to pass all grouping expressions though so the grouping can happen a second // time. However some of them might be unnamed so we alias them allowing them to be // referenced in the second aggregation. - val namedGroupingExpressions: Map[Expression, NamedExpression] = + val namedGroupingExpressions: Seq[(Expression, NamedExpression)] = groupingExpressions.filter(!_.isInstanceOf[Literal]).map { case n: NamedExpression => (n, n) case other => (other, Alias(other, "PartialGroup")()) - }.toMap + } // Replace aggregations with a new expression that computes the result from the already // computed partial evaluations and grouping values. @@ -156,21 +156,15 @@ object PartialAggregation { partialEvaluations(new TreeNodeRef(e)).finalEvaluation case e: Expression => - // Should trim aliases around `GetField`s. These aliases are introduced while - // resolving struct field accesses, because `GetField` is not a `NamedExpression`. - // (Should we just turn `GetField` into a `NamedExpression`?) - val trimmed = e.transform { case Alias(g: ExtractValue, _) => g } - namedGroupingExpressions - .find { case (k, v) => k semanticEquals trimmed } - .map(_._2.toAttribute) - .getOrElse(e) + namedGroupingExpressions.collectFirst { + case (expr, ne) if expr semanticEquals e => ne.toAttribute + }.getOrElse(e) }).asInstanceOf[Seq[NamedExpression]] - val partialComputation = - (namedGroupingExpressions.values ++ - partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq + val partialComputation = namedGroupingExpressions.map(_._2) ++ + partialEvaluations.values.flatMap(_.partialEvaluations) - val namedGroupingAttributes = namedGroupingExpressions.values.map(_.toAttribute).toSeq + val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) Some( (namedGroupingAttributes, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index eff5c61644944..2f545bb432165 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -81,17 +81,16 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy } } - val newArgs = productIterator.map { + def recursiveTransform(arg: Any): AnyRef = arg match { case e: Expression => transformExpressionDown(e) case Some(e: Expression) => Some(transformExpressionDown(e)) case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs - case seq: Traversable[_] => seq.map { - case e: Expression => transformExpressionDown(e) - case other => other - } + case seq: Traversable[_] => seq.map(recursiveTransform) case other: AnyRef => other - }.toArray + } + + val newArgs = productIterator.map(recursiveTransform).toArray if (changed) makeCopy(newArgs) else this } @@ -114,17 +113,16 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy } } - val newArgs = productIterator.map { + def recursiveTransform(arg: Any): AnyRef = arg match { case e: Expression => transformExpressionUp(e) case Some(e: Expression) => Some(transformExpressionUp(e)) case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs - case seq: Traversable[_] => seq.map { - case e: Expression => transformExpressionUp(e) - case other => other - } + case seq: Traversable[_] => seq.map(recursiveTransform) case other: AnyRef => other - }.toArray + } + + val newArgs = productIterator.map(recursiveTransform).toArray if (changed) makeCopy(newArgs) else this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index e3e070f0ff307..1868f119f0e97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, analysis} import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.types.{StructType, StructField} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, analysis} +import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { def apply(output: Attribute*): LocalRelation = new LocalRelation(output) @@ -32,11 +31,11 @@ object LocalRelation { def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = { val schema = StructType.fromAttributes(output) val converter = CatalystTypeConverters.createToCatalystConverter(schema) - LocalRelation(output, data.map(converter(_).asInstanceOf[Row])) + LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow])) } } -case class LocalRelation(output: Seq[Attribute], data: Seq[Row] = Nil) +case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) extends LeafNode with analysis.MultiInstanceRelation { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index dba69659afc80..b009a200b920f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, EliminateSubQueries, Resolver} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.TreeNode @@ -50,19 +50,19 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * [[org.apache.spark.sql.catalyst.analysis.UnresolvedRelation UnresolvedRelation]] * should return `false`). */ - lazy val resolved: Boolean = !expressions.exists(!_.resolved) && childrenResolved + lazy val resolved: Boolean = expressions.forall(_.resolved) && childrenResolved override protected def statePrefix = if (!resolved) "'" else super.statePrefix /** * Returns true if all its children of this query plan have been resolved. */ - def childrenResolved: Boolean = !children.exists(!_.resolved) + def childrenResolved: Boolean = children.forall(_.resolved) /** * Returns true when the given logical plan will return the same results as this logical plan. * - * Since its likely undecideable to generally determine if two given plans will produce the same + * Since its likely undecidable to generally determine if two given plans will produce the same * results, it is okay for this function to return false, even if the results are actually * the same. Such behavior will not affect correctness, only the application of performance * enhancements like caching. However, it is not acceptable to return true if the results could @@ -90,7 +90,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { val input = children.flatMap(_.output) productIterator.map { // Children are checked using sameResult above. - case tn: TreeNode[_] if children contains tn => null + case tn: TreeNode[_] if containsChild(tn) => null case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) case s: Option[_] => s.map { case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) @@ -111,9 +111,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ def resolveChildren( nameParts: Seq[String], - resolver: Resolver, - throwErrors: Boolean = false): Option[NamedExpression] = - resolve(nameParts, children.flatMap(_.output), resolver, throwErrors) + resolver: Resolver): Option[NamedExpression] = + resolve(nameParts, children.flatMap(_.output), resolver) /** * Optionally resolves the given strings to a [[NamedExpression]] based on the output of this @@ -122,9 +121,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ def resolve( nameParts: Seq[String], - resolver: Resolver, - throwErrors: Boolean = false): Option[NamedExpression] = - resolve(nameParts, output, resolver, throwErrors) + resolver: Resolver): Option[NamedExpression] = + resolve(nameParts, output, resolver) /** * Given an attribute name, split it to name parts by dot, but @@ -134,7 +132,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { def resolveQuoted( name: String, resolver: Resolver): Option[NamedExpression] = { - resolve(parseAttributeName(name), resolver, true) + resolve(parseAttributeName(name), output, resolver) } /** @@ -219,8 +217,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { protected def resolve( nameParts: Seq[String], input: Seq[Attribute], - resolver: Resolver, - throwErrors: Boolean): Option[NamedExpression] = { + resolver: Resolver): Option[NamedExpression] = { // A sequence of possible candidate matches. // Each candidate is a tuple. The first element is a resolved attribute, followed by a list @@ -254,19 +251,14 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // One match, but we also need to extract the requested nested field. case Seq((a, nestedFields)) => - try { - // The foldLeft adds GetFields for every remaining parts of the identifier, - // and aliases it with the last part of the identifier. - // For example, consider "a.b.c", where "a" is resolved to an existing attribute. - // Then this will add GetField("c", GetField("b", a)), and alias - // the final expression as "c". - val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) => - ExtractValue(expr, Literal(fieldName), resolver)) - val aliasName = nestedFields.last - Some(Alias(fieldExprs, aliasName)()) - } catch { - case a: AnalysisException if !throwErrors => None - } + // The foldLeft adds ExtractValues for every remaining parts of the identifier, + // and wrap it with UnresolvedAlias which will be removed later. + // For example, consider "a.b.c", where "a" is resolved to an existing attribute. + // Then this will add ExtractValue("c", ExtractValue("b", a)), and wrap it as + // UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))). + val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) => + ExtractValue(expr, Literal(fieldName), resolver)) + Some(UnresolvedAlias(fieldExprs)) // No matches. case Seq() => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index e77e5c27b687a..fae339808c233 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ +import org.apache.spark.util.collection.OpenHashSet case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) @@ -130,6 +131,14 @@ case class Join( } } +/** + * A hint for the optimizer that we should broadcast the `child` if used in a join operator. + */ +case class BroadcastHint(child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + + case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output } @@ -220,28 +229,82 @@ case class Window( /** * Apply the all of the GroupExpressions to every input row, hence we will get * multiple output rows for a input row. - * @param projections The group of expressions, all of the group expressions should - * output the same schema specified by the parameter `output` - * @param output The output Schema + * @param bitmasks The bitmask set represents the grouping sets + * @param groupByExprs The grouping by expressions * @param child Child operator */ case class Expand( - projections: Seq[GroupExpression], - output: Seq[Attribute], + bitmasks: Seq[Int], + groupByExprs: Seq[Expression], + gid: Attribute, child: LogicalPlan) extends UnaryNode { override def statistics: Statistics = { val sizeInBytes = child.statistics.sizeInBytes * projections.length Statistics(sizeInBytes = sizeInBytes) } + + val projections: Seq[Seq[Expression]] = expand() + + /** + * Extract attribute set according to the grouping id + * @param bitmask bitmask to represent the selected of the attribute sequence + * @param exprs the attributes in sequence + * @return the attributes of non selected specified via bitmask (with the bit set to 1) + */ + private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression]) + : OpenHashSet[Expression] = { + val set = new OpenHashSet[Expression](2) + + var bit = exprs.length - 1 + while (bit >= 0) { + if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit)) + bit -= 1 + } + + set + } + + /** + * Create an array of Projections for the child projection, and replace the projections' + * expressions which equal GroupBy expressions with Literal(null), if those expressions + * are not set for this grouping set (according to the bit mask). + */ + private[this] def expand(): Seq[Seq[Expression]] = { + val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]] + + bitmasks.foreach { bitmask => + // get the non selected grouping attributes according to the bit mask + val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs) + + val substitution = (child.output :+ gid).map(expr => expr transformDown { + case x: Expression if nonSelectedGroupExprSet.contains(x) => + // if the input attribute in the Invalid Grouping Expression set of for this group + // replace it with constant null + Literal.create(null, expr.dataType) + case x if x == gid => + // replace the groupingId with concrete value (the bit mask) + Literal.create(bitmask, IntegerType) + }) + + result += substitution + } + + result.toSeq + } + + override def output: Seq[Attribute] = { + child.output :+ gid + } } trait GroupingAnalytics extends UnaryNode { self: Product => - def gid: AttributeReference def groupByExprs: Seq[Expression] def aggregations: Seq[NamedExpression] override def output: Seq[Attribute] = aggregations.map(_.toAttribute) + + def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics } /** @@ -256,17 +319,16 @@ trait GroupingAnalytics extends UnaryNode { * @param child Child operator * @param aggregations The Aggregation expressions, those non selected group by expressions * will be considered as constant null if it appears in the expressions - * @param gid The attribute represents the virtual column GROUPING__ID, and it's also - * the bitmask indicates the selected GroupBy Expressions for each - * aggregating output row. - * The associated output will be one of the value in `bitmasks` */ case class GroupingSets( bitmasks: Seq[Int], groupByExprs: Seq[Expression], child: LogicalPlan, - aggregations: Seq[NamedExpression], - gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics + aggregations: Seq[NamedExpression]) extends GroupingAnalytics { + + def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = + this.copy(aggregations = aggs) +} /** * Cube is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets, @@ -276,15 +338,15 @@ case class GroupingSets( * @param child Child operator * @param aggregations The Aggregation expressions, those non selected group by expressions * will be considered as constant null if it appears in the expressions - * @param gid The attribute represents the virtual column GROUPING__ID, and it's also - * the bitmask indicates the selected GroupBy Expressions for each - * aggregating output row. */ case class Cube( groupByExprs: Seq[Expression], child: LogicalPlan, - aggregations: Seq[NamedExpression], - gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics + aggregations: Seq[NamedExpression]) extends GroupingAnalytics { + + def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = + this.copy(aggregations = aggs) +} /** * Rollup is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets, @@ -295,15 +357,15 @@ case class Cube( * @param child Child operator * @param aggregations The Aggregation expressions, those non selected group by expressions * will be considered as constant null if it appears in the expressions - * @param gid The attribute represents the virtual column GROUPING__ID, and it's also - * the bitmask indicates the selected GroupBy Expressions for each - * aggregating output row. */ case class Rollup( groupByExprs: Seq[Expression], child: LogicalPlan, - aggregations: Seq[NamedExpression], - gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics + aggregations: Seq[NamedExpression]) extends GroupingAnalytics { + + def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = + this.copy(aggregations = aggs) +} case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 80ba57a082a60..42dead7c28425 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.{Expression, Row, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder} import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -169,7 +170,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def keyExpressions: Seq[Expression] = expressions - override def eval(input: Row = null): Any = + override def eval(input: InternalRow = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } @@ -213,6 +214,6 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def keyExpressions: Seq[Expression] = ordering.map(_.child) - override def eval(input: Row): Any = + override def eval(input: InternalRow): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 36d005d0e1684..09f6c6b0ec423 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -59,9 +59,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { val origin: Origin = CurrentOrigin.get - /** Returns a Seq of the children of this node */ + /** + * Returns a Seq of the children of this node. + * Children should not change. Immutability required for containsChild optimization + */ def children: Seq[BaseType] + lazy val containsChild: Set[TreeNode[_]] = children.toSet + /** * Faster version of equality which short-circuits when two treeNodes are the same instance. * We don't just override Object.equals, as doing so prevents the scala compiler from @@ -147,7 +152,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { def mapChildren(f: BaseType => BaseType): this.type = { var changed = false val newArgs = productIterator.map { - case arg: TreeNode[_] if children contains arg => + case arg: TreeNode[_] if containsChild(arg) => val newChild = f(arg.asInstanceOf[BaseType]) if (newChild fastEquals arg) { arg @@ -173,7 +178,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { val newArgs = productIterator.map { // Handle Seq[TreeNode] in TreeNode parameters. case s: Seq[_] => s.map { - case arg: TreeNode[_] if children contains arg => + case arg: TreeNode[_] if containsChild(arg) => val newChild = remainingNewChildren.remove(0) val oldChild = remainingOldChildren.remove(0) if (newChild fastEquals oldChild) { @@ -185,7 +190,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { case nonChild: AnyRef => nonChild case null => null } - case arg: TreeNode[_] if children contains arg => + case arg: TreeNode[_] if containsChild(arg) => val newChild = remainingNewChildren.remove(0) val oldChild = remainingOldChildren.remove(0) if (newChild fastEquals oldChild) { @@ -238,7 +243,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { def transformChildrenDown(rule: PartialFunction[BaseType, BaseType]): this.type = { var changed = false val newArgs = productIterator.map { - case arg: TreeNode[_] if children contains arg => + case arg: TreeNode[_] if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformDown(rule) if (!(newChild fastEquals arg)) { changed = true @@ -246,7 +251,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { arg } - case Some(arg: TreeNode[_]) if children contains arg => + case Some(arg: TreeNode[_]) if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformDown(rule) if (!(newChild fastEquals arg)) { changed = true @@ -257,7 +262,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { - case arg: TreeNode[_] if children contains arg => + case arg: TreeNode[_] if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformDown(rule) if (!(newChild fastEquals arg)) { changed = true @@ -280,7 +285,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { * @param rule the function use to transform this nodes children */ def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = { - val afterRuleOnChildren = transformChildrenUp(rule); + val afterRuleOnChildren = transformChildrenUp(rule) if (this fastEquals afterRuleOnChildren) { CurrentOrigin.withOrigin(origin) { rule.applyOrElse(this, identity[BaseType]) @@ -295,7 +300,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { def transformChildrenUp(rule: PartialFunction[BaseType, BaseType]): this.type = { var changed = false val newArgs = productIterator.map { - case arg: TreeNode[_] if children contains arg => + case arg: TreeNode[_] if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformUp(rule) if (!(newChild fastEquals arg)) { changed = true @@ -303,7 +308,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { arg } - case Some(arg: TreeNode[_]) if children contains arg => + case Some(arg: TreeNode[_]) if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformUp(rule) if (!(newChild fastEquals arg)) { changed = true @@ -314,7 +319,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { - case arg: TreeNode[_] if children contains arg => + case arg: TreeNode[_] if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformUp(rule) if (!(newChild fastEquals arg)) { changed = true @@ -344,11 +349,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { * @param newArgs the new product arguments. */ def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") { - val defaultCtor = - getClass.getConstructors - .find(_.getParameterTypes.size != 0) - .headOption - .getOrElse(sys.error(s"No valid constructor for $nodeName")) + val ctors = getClass.getConstructors.filter(_.getParameterTypes.size != 0) + if (ctors.isEmpty) { + sys.error(s"No valid constructor for $nodeName") + } + val defaultCtor = ctors.maxBy(_.getParameterTypes.size) try { CurrentOrigin.withOrigin(origin) { @@ -383,7 +388,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { /** Returns a string representing the arguments to this node, minus any children */ def argString: String = productIterator.flatMap { - case tn: TreeNode[_] if children contains tn => Nil + case tn: TreeNode[_] if containsChild(tn) => Nil case tn: TreeNode[_] if tn.toString contains "\n" => s"(${tn.simpleString})" :: Nil case seq: Seq[BaseType] if seq.toSet.subsetOf(children.toSet) => Nil case seq: Seq[_] => seq.mkString("[", ",", "]") :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala new file mode 100644 index 0000000000000..640e67e2ecd76 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import java.sql.{Date, Timestamp} +import java.text.{DateFormat, SimpleDateFormat} +import java.util.{Calendar, TimeZone} + +/** + * Helper functions for converting between internal and external date and time representations. + * Dates are exposed externally as java.sql.Date and are represented internally as the number of + * dates since the Unix epoch (1970-01-01). Timestamps are exposed externally as java.sql.Timestamp + * and are stored internally as longs, which are capable of storing timestamps with 100 nanosecond + * precision. + */ +object DateTimeUtils { + final val MILLIS_PER_DAY = SECONDS_PER_DAY * 1000L + + // see http://stackoverflow.com/questions/466321/convert-unix-timestamp-to-julian + final val JULIAN_DAY_OF_EPOCH = 2440587 // and .5 + final val SECONDS_PER_DAY = 60 * 60 * 24L + final val HUNDRED_NANOS_PER_SECOND = 1000L * 1000L * 10L + final val NANOS_PER_SECOND = HUNDRED_NANOS_PER_SECOND * 100 + + + // Java TimeZone has no mention of thread safety. Use thread local instance to be safe. + private val threadLocalLocalTimeZone = new ThreadLocal[TimeZone] { + override protected def initialValue: TimeZone = { + Calendar.getInstance.getTimeZone + } + } + + // `SimpleDateFormat` is not thread-safe. + private val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { + override def initialValue(): SimpleDateFormat = { + new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + } + } + + // `SimpleDateFormat` is not thread-safe. + private val threadLocalDateFormat = new ThreadLocal[DateFormat] { + override def initialValue(): SimpleDateFormat = { + new SimpleDateFormat("yyyy-MM-dd") + } + } + + + // we should use the exact day as Int, for example, (year, month, day) -> day + def millisToDays(millisLocal: Long): Int = { + ((millisLocal + threadLocalLocalTimeZone.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt + } + + // reverse of millisToDays + def daysToMillis(days: Int): Long = { + val millisUtc = days.toLong * MILLIS_PER_DAY + millisUtc - threadLocalLocalTimeZone.get().getOffset(millisUtc) + } + + def dateToString(days: Int): String = + threadLocalDateFormat.get.format(toJavaDate(days)) + + // Converts Timestamp to string according to Hive TimestampWritable convention. + def timestampToString(num100ns: Long): String = { + val ts = toJavaTimestamp(num100ns) + val timestampString = ts.toString + val formatted = threadLocalTimestampFormat.get.format(ts) + + if (timestampString.length > 19 && timestampString.substring(19) != ".0") { + formatted + timestampString.substring(19) + } else { + formatted + } + } + + def stringToTime(s: String): java.util.Date = { + if (!s.contains('T')) { + // JDBC escape string + if (s.contains(' ')) { + Timestamp.valueOf(s) + } else { + Date.valueOf(s) + } + } else if (s.endsWith("Z")) { + // this is zero timezone of ISO8601 + stringToTime(s.substring(0, s.length - 1) + "GMT-00:00") + } else if (s.indexOf("GMT") == -1) { + // timezone with ISO8601 + val inset = "+00.00".length + val s0 = s.substring(0, s.length - inset) + val s1 = s.substring(s.length - inset, s.length) + if (s0.substring(s0.lastIndexOf(':')).contains('.')) { + stringToTime(s0 + "GMT" + s1) + } else { + stringToTime(s0 + ".0GMT" + s1) + } + } else { + // ISO8601 with GMT insert + val ISO8601GMT: SimpleDateFormat = new SimpleDateFormat( "yyyy-MM-dd'T'HH:mm:ss.SSSz" ) + ISO8601GMT.parse(s) + } + } + + /** + * Returns the number of days since epoch from from java.sql.Date. + */ + def fromJavaDate(date: Date): Int = { + millisToDays(date.getTime) + } + + /** + * Returns a java.sql.Date from number of days since epoch. + */ + def toJavaDate(daysSinceEpoch: Int): Date = { + new Date(daysToMillis(daysSinceEpoch)) + } + + /** + * Returns a java.sql.Timestamp from number of 100ns since epoch. + */ + def toJavaTimestamp(num100ns: Long): Timestamp = { + // setNanos() will overwrite the millisecond part, so the milliseconds should be + // cut off at seconds + var seconds = num100ns / HUNDRED_NANOS_PER_SECOND + var nanos = num100ns % HUNDRED_NANOS_PER_SECOND + // setNanos() can not accept negative value + if (nanos < 0) { + nanos += HUNDRED_NANOS_PER_SECOND + seconds -= 1 + } + val t = new Timestamp(seconds * 1000) + t.setNanos(nanos.toInt * 100) + t + } + + /** + * Returns the number of 100ns since epoch from java.sql.Timestamp. + */ + def fromJavaTimestamp(t: Timestamp): Long = { + if (t != null) { + t.getTime() * 10000L + (t.getNanos().toLong / 100) % 10000L + } else { + 0L + } + } + + /** + * Returns the number of 100ns (hundred of nanoseconds) since epoch from Julian day + * and nanoseconds in a day + */ + def fromJulianDay(day: Int, nanoseconds: Long): Long = { + // use Long to avoid rounding errors + val seconds = (day - JULIAN_DAY_OF_EPOCH).toLong * SECONDS_PER_DAY - SECONDS_PER_DAY / 2 + seconds * HUNDRED_NANOS_PER_SECOND + nanoseconds / 100L + } + + /** + * Returns Julian day and nanoseconds in a day from the number of 100ns (hundred of nanoseconds) + */ + def toJulianDay(num100ns: Long): (Int, Long) = { + val seconds = num100ns / HUNDRED_NANOS_PER_SECOND + SECONDS_PER_DAY / 2 + val day = seconds / SECONDS_PER_DAY + JULIAN_DAY_OF_EPOCH + val secondsInDay = seconds % SECONDS_PER_DAY + val nanos = (num100ns % HUNDRED_NANOS_PER_SECOND) * 100L + (day.toInt, secondsInDay * NANOS_PER_SECOND + nanos) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala deleted file mode 100644 index ad649acf536f9..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.util - -import java.sql.Date -import java.text.SimpleDateFormat -import java.util.{Calendar, TimeZone} - -import org.apache.spark.sql.catalyst.expressions.Cast - -/** - * Helper function to convert between Int value of days since 1970-01-01 and java.sql.Date - */ -object DateUtils { - private val MILLIS_PER_DAY = 86400000 - - // Java TimeZone has no mention of thread safety. Use thread local instance to be safe. - private val LOCAL_TIMEZONE = new ThreadLocal[TimeZone] { - override protected def initialValue: TimeZone = { - Calendar.getInstance.getTimeZone - } - } - - private def javaDateToDays(d: Date): Int = { - millisToDays(d.getTime) - } - - // we should use the exact day as Int, for example, (year, month, day) -> day - def millisToDays(millisLocal: Long): Int = { - ((millisLocal + LOCAL_TIMEZONE.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt - } - - private def toMillisSinceEpoch(days: Int): Long = { - val millisUtc = days.toLong * MILLIS_PER_DAY - millisUtc - LOCAL_TIMEZONE.get().getOffset(millisUtc) - } - - def fromJavaDate(date: java.sql.Date): Int = { - javaDateToDays(date) - } - - def toJavaDate(daysSinceEpoch: Int): java.sql.Date = { - new java.sql.Date(toMillisSinceEpoch(daysSinceEpoch)) - } - - def toString(days: Int): String = Cast.threadLocalDateFormat.get.format(toJavaDate(days)) - - def stringToTime(s: String): java.util.Date = { - if (!s.contains('T')) { - // JDBC escape string - if (s.contains(' ')) { - java.sql.Timestamp.valueOf(s) - } else { - java.sql.Date.valueOf(s) - } - } else if (s.endsWith("Z")) { - // this is zero timezone of ISO8601 - stringToTime(s.substring(0, s.length - 1) + "GMT-00:00") - } else if (s.indexOf("GMT") == -1) { - // timezone with ISO8601 - val inset = "+00.00".length - val s0 = s.substring(0, s.length - inset) - val s1 = s.substring(s.length - inset, s.length) - if (s0.substring(s0.lastIndexOf(':')).contains('.')) { - stringToTime(s0 + "GMT" + s1) - } else { - stringToTime(s0 + ".0GMT" + s1) - } - } else { - // ISO8601 with GMT insert - val ISO8601GMT: SimpleDateFormat = new SimpleDateFormat( "yyyy-MM-dd'T'HH:mm:ss.SSSz" ) - ISO8601GMT.parse(s) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala new file mode 100644 index 0000000000000..191d5e6399fc9 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +/** + * Build a map with String type of key, and it also supports either key case + * sensitive or insensitive. + */ +object StringKeyHashMap { + def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = caseSensitive match { + case false => new StringKeyHashMap[T](_.toLowerCase) + case true => new StringKeyHashMap[T](identity) + } +} + + +class StringKeyHashMap[T](normalizer: (String) => String) { + private val base = new collection.mutable.HashMap[String, T]() + + def apply(key: String): T = base(normalizer(key)) + + def get(key: String): Option[T] = base.get(normalizer(key)) + + def put(key: String, value: T): Option[T] = base.put(normalizer(key), value) + + def remove(key: String): Option[T] = base.remove(normalizer(key)) + + def iterator: Iterator[(String, T)] = base.toIterator +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 0bb12d2039ffc..8656cc334d09f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -48,9 +48,26 @@ object TypeUtils { } } + def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = { + if (types.distinct.size > 1) { + TypeCheckResult.TypeCheckFailure( + s"input to $caller should all be the same type, but it's ${types.mkString("[", ", ", "]")}") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + def getNumeric(t: DataType): Numeric[Any] = t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] def getOrdering(t: DataType): Ordering[Any] = t.asInstanceOf[AtomicType].ordering.asInstanceOf[Ordering[Any]] + + def compareBinary(x: Array[Byte], y: Array[Byte]): Int = { + for (i <- 0 until x.length; if i < y.length) { + val res = x(i).compareTo(y(i)) + if (res != 0) return res + } + x.length - y.length + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala index a581a9e9468ef..9b58601e5e6ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala @@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.sql.catalyst.util.TypeUtils /** @@ -43,11 +44,7 @@ class BinaryType private() extends AtomicType { private[sql] val ordering = new Ordering[InternalType] { def compare(x: Array[Byte], y: Array[Byte]): Int = { - for (i <- 0 until x.length; if i < y.length) { - val res = x(i).compareTo(y(i)) - if (res != 0) return res - } - x.length - y.length + TypeUtils.compareBinary(x, y) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala index 04f3379afb38d..6b43224feb1f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala @@ -44,7 +44,7 @@ private[sql] trait DataTypeParser extends StandardTokenParsers { "(?i)tinyint".r ^^^ ByteType | "(?i)smallint".r ^^^ ShortType | "(?i)double".r ^^^ DoubleType | - "(?i)bigint".r ^^^ LongType | + "(?i)(?:bigint|long)".r ^^^ LongType | "(?i)binary".r ^^^ BinaryType | "(?i)boolean".r ^^^ BooleanType | fixedDecimalType | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index eb3c58c37f308..5a169488c97eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types +import java.math.{MathContext, RoundingMode} + import org.apache.spark.annotation.DeveloperApi /** @@ -86,7 +88,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (precision < 19) { return null // Requested precision is too low to represent this value } - this.decimalVal = BigDecimal(longVal) + this.decimalVal = BigDecimal(unscaled) this.longVal = 0L } else { val p = POW_10(math.min(precision, MAX_LONG_DIGITS)) @@ -137,9 +139,9 @@ final class Decimal extends Ordered[Decimal] with Serializable { def toBigDecimal: BigDecimal = { if (decimalVal.ne(null)) { - decimalVal + decimalVal(MathContext.UNLIMITED) } else { - BigDecimal(longVal, _scale) + BigDecimal(longVal, _scale)(MathContext.UNLIMITED) } } @@ -263,8 +265,15 @@ final class Decimal extends Ordered[Decimal] with Serializable { def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal) - def / (that: Decimal): Decimal = - if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal) + def / (that: Decimal): Decimal = { + if (that.isZero) { + null + } else { + // To avoid non-terminating decimal expansion problem, we turn to Java BigDecimal's divide + // with specified ROUNDING_MODE. + Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal, ROUNDING_MODE.id)) + } + } def % (that: Decimal): Decimal = if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 407dc27326c2e..18cdfa7238f39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -20,13 +20,18 @@ package org.apache.spark.sql.types import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.Expression /** Precision parameters for a Decimal */ -case class PrecisionInfo(precision: Int, scale: Int) - +case class PrecisionInfo(precision: Int, scale: Int) { + if (scale > precision) { + throw new AnalysisException( + s"Decimal scale ($scale) cannot be greater than precision ($precision).") + } +} /** * :: DeveloperApi :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala index 134ab0af4e0de..1e9476ad06656 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.unsafe.types.UTF8String /** * :: DeveloperApi :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 193c08a4d0df7..2db0a359e9db5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -94,7 +94,7 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { /** No-arg constructor for kryo. */ - protected def this() = this(null) + def this() = this(Array.empty[StructField]) /** Returns all field names in an array. */ def fieldNames: Array[String] = fields.map(_.name) @@ -103,6 +103,108 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap + /** + * Creates a new [[StructType]] by adding a new field. + * {{{ + * val struct = (new StructType) + * .add(StructField("a", IntegerType, true)) + * .add(StructField("b", LongType, false)) + * .add(StructField("c", StringType, true)) + *}}} + */ + def add(field: StructField): StructType = { + StructType(fields :+ field) + } + + /** + * Creates a new [[StructType]] by adding a new nullable field with no metadata. + * + * val struct = (new StructType) + * .add("a", IntegerType) + * .add("b", LongType) + * .add("c", StringType) + */ + def add(name: String, dataType: DataType): StructType = { + StructType(fields :+ new StructField(name, dataType, nullable = true, Metadata.empty)) + } + + /** + * Creates a new [[StructType]] by adding a new field with no metadata. + * + * val struct = (new StructType) + * .add("a", IntegerType, true) + * .add("b", LongType, false) + * .add("c", StringType, true) + */ + def add(name: String, dataType: DataType, nullable: Boolean): StructType = { + StructType(fields :+ new StructField(name, dataType, nullable, Metadata.empty)) + } + + /** + * Creates a new [[StructType]] by adding a new field and specifying metadata. + * {{{ + * val struct = (new StructType) + * .add("a", IntegerType, true, Metadata.empty) + * .add("b", LongType, false, Metadata.empty) + * .add("c", StringType, true, Metadata.empty) + * }}} + */ + def add( + name: String, + dataType: DataType, + nullable: Boolean, + metadata: Metadata): StructType = { + StructType(fields :+ new StructField(name, dataType, nullable, metadata)) + } + + /** + * Creates a new [[StructType]] by adding a new nullable field with no metadata where the + * dataType is specified as a String. + * + * {{{ + * val struct = (new StructType) + * .add("a", "int") + * .add("b", "long") + * .add("c", "string") + * }}} + */ + def add(name: String, dataType: String): StructType = { + add(name, DataTypeParser.parse(dataType), nullable = true, Metadata.empty) + } + + /** + * Creates a new [[StructType]] by adding a new field with no metadata where the + * dataType is specified as a String. + * + * {{{ + * val struct = (new StructType) + * .add("a", "int", true) + * .add("b", "long", false) + * .add("c", "string", true) + * }}} + */ + def add(name: String, dataType: String, nullable: Boolean): StructType = { + add(name, DataTypeParser.parse(dataType), nullable, Metadata.empty) + } + + /** + * Creates a new [[StructType]] by adding a new field and specifying metadata where the + * dataType is specified as a String. + * {{{ + * val struct = (new StructType) + * .add("a", "int", true, Metadata.empty) + * .add("b", "long", false, Metadata.empty) + * .add("c", "string", true, Metadata.empty) + * }}} + */ + def add( + name: String, + dataType: String, + nullable: Boolean, + metadata: Metadata): StructType = { + add(name, DataTypeParser.parse(dataType), nullable, metadata) + } + /** * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not * have a name matching the given name, `null` will be returned. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala index aebabfc475925..a558641fcfed7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.types -import java.sql.Timestamp - import scala.math.Ordering import scala.reflect.runtime.universe.typeTag @@ -38,18 +36,16 @@ class TimestampType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "TimestampType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = Timestamp + private[sql] type InternalType = Long @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val ordering = new Ordering[InternalType] { - def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y) - } + private[sql] val ordering = implicitly[Ordering[InternalType]] /** * The default size of a value of the TimestampType is 12 bytes. */ - override def defaultSize: Int = 12 + override def defaultSize: Int = 8 private[spark] override def asNullable: TimestampType = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala deleted file mode 100644 index f5d8fcced362b..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala +++ /dev/null @@ -1,221 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.types - -import java.util.Arrays - -import org.apache.spark.annotation.DeveloperApi - -/** - * :: DeveloperApi :: - * A UTF-8 String, as internal representation of StringType in SparkSQL - * - * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison, - * search, see http://en.wikipedia.org/wiki/UTF-8 for details. - * - * Note: This is not designed for general use cases, should not be used outside SQL. - */ -@DeveloperApi -final class UTF8String extends Ordered[UTF8String] with Serializable { - - private[this] var bytes: Array[Byte] = _ - - /** - * Update the UTF8String with String. - */ - def set(str: String): UTF8String = { - bytes = str.getBytes("utf-8") - this - } - - /** - * Update the UTF8String with Array[Byte], which should be encoded in UTF-8 - */ - def set(bytes: Array[Byte]): UTF8String = { - this.bytes = bytes - this - } - - /** - * Return the number of bytes for a code point with the first byte as `b` - * @param b The first byte of a code point - */ - @inline - private[this] def numOfBytes(b: Byte): Int = { - val offset = (b & 0xFF) - 192 - if (offset >= 0) UTF8String.bytesOfCodePointInUTF8(offset) else 1 - } - - /** - * Return the number of code points in it. - * - * This is only used by Substring() when `start` is negative. - */ - def length(): Int = { - var len = 0 - var i: Int = 0 - while (i < bytes.length) { - i += numOfBytes(bytes(i)) - len += 1 - } - len - } - - def getBytes: Array[Byte] = { - bytes - } - - /** - * Return a substring of this, - * @param start the position of first code point - * @param until the position after last code point - */ - def slice(start: Int, until: Int): UTF8String = { - if (until <= start || start >= bytes.length || bytes == null) { - new UTF8String - } - - var c = 0 - var i: Int = 0 - while (c < start && i < bytes.length) { - i += numOfBytes(bytes(i)) - c += 1 - } - var j = i - while (c < until && j < bytes.length) { - j += numOfBytes(bytes(j)) - c += 1 - } - UTF8String(Arrays.copyOfRange(bytes, i, j)) - } - - def contains(sub: UTF8String): Boolean = { - val b = sub.getBytes - if (b.length == 0) { - return true - } - var i: Int = 0 - while (i <= bytes.length - b.length) { - // In worst case, it's O(N*K), but should works fine with SQL - if (bytes(i) == b(0) && Arrays.equals(Arrays.copyOfRange(bytes, i, i + b.length), b)) { - return true - } - i += 1 - } - false - } - - def startsWith(prefix: UTF8String): Boolean = { - val b = prefix.getBytes - if (b.length > bytes.length) { - return false - } - Arrays.equals(Arrays.copyOfRange(bytes, 0, b.length), b) - } - - def endsWith(suffix: UTF8String): Boolean = { - val b = suffix.getBytes - if (b.length > bytes.length) { - return false - } - Arrays.equals(Arrays.copyOfRange(bytes, bytes.length - b.length, bytes.length), b) - } - - def toUpperCase(): UTF8String = { - // upper case depends on locale, fallback to String. - UTF8String(toString().toUpperCase) - } - - def toLowerCase(): UTF8String = { - // lower case depends on locale, fallback to String. - UTF8String(toString().toLowerCase) - } - - override def toString(): String = { - new String(bytes, "utf-8") - } - - override def clone(): UTF8String = new UTF8String().set(this.bytes) - - override def compare(other: UTF8String): Int = { - var i: Int = 0 - val b = other.getBytes - while (i < bytes.length && i < b.length) { - val res = bytes(i).compareTo(b(i)) - if (res != 0) return res - i += 1 - } - bytes.length - b.length - } - - override def compareTo(other: UTF8String): Int = { - compare(other) - } - - override def equals(other: Any): Boolean = other match { - case s: UTF8String => - Arrays.equals(bytes, s.getBytes) - case s: String => - // This is only used for Catalyst unit tests - // fail fast - bytes.length >= s.length && length() == s.length && toString() == s - case _ => - false - } - - override def hashCode(): Int = { - Arrays.hashCode(bytes) - } -} - -/** - * :: DeveloperApi :: - */ -@DeveloperApi -object UTF8String { - // number of tailing bytes in a UTF8 sequence for a code point - // see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1 - private[types] val bytesOfCodePointInUTF8: Array[Int] = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, - 4, 4, 4, 4, 4, 4, 4, 4, - 5, 5, 5, 5, - 6, 6, 6, 6) - - /** - * Create a UTF-8 String from String - */ - def apply(s: String): UTF8String = { - if (s != null) { - new UTF8String().set(s) - } else { - null - } - } - - /** - * Create a UTF-8 String from Array[Byte], which should be encoded in UTF-8 - */ - def apply(bytes: Array[Byte]): UTF8String = { - if (bytes != null) { - new UTF8String().set(bytes) - } else { - null - } - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 9a24b23024e18..b4b00f558463f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -21,7 +21,6 @@ import java.math.BigInteger import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.types._ case class PrimitiveData( @@ -75,7 +74,7 @@ case class MultipleConstructorsData(a: Int, b: String, c: Double) { } class ScalaReflectionSuite extends SparkFunSuite { - import ScalaReflection._ + import org.apache.spark.sql.catalyst.ScalaReflection._ test("primitive data") { val schema = schemaFor[PrimitiveData] @@ -257,9 +256,9 @@ class ScalaReflectionSuite extends SparkFunSuite { test("convert PrimitiveData to catalyst") { val data = PrimitiveData(1, 1, 1, 1, 1, 1, true) - val convertedData = Row(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) + val convertedData = InternalRow(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) val dataType = schemaFor[PrimitiveData].dataType - assert(CatalystTypeConverters.convertToCatalyst(data, dataType) === convertedData) + assert(CatalystTypeConverters.createToCatalystConverter(dataType)(data) === convertedData) } test("convert Option[Product] to catalyst") { @@ -267,9 +266,9 @@ class ScalaReflectionSuite extends SparkFunSuite { val data = OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), Some(primitiveData)) val dataType = schemaFor[OptionalData].dataType - val convertedData = Row(2, 2.toLong, 2.toDouble, 2.toFloat, 2.toShort, 2.toByte, true, - Row(1, 1, 1, 1, 1, 1, true)) - assert(CatalystTypeConverters.convertToCatalyst(data, dataType) === convertedData) + val convertedData = InternalRow(2, 2.toLong, 2.toDouble, 2.toFloat, 2.toShort, 2.toByte, true, + InternalRow(1, 1, 1, 1, 1, 1, true)) + assert(CatalystTypeConverters.createToCatalystConverter(dataType)(data) === convertedData) } test("infer schema from case class with multiple constructors") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index e09cd790a7187..77ca080f366cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -193,7 +193,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { errorTest( "bad casts", testRelation.select(Literal(1).cast(BinaryType).as('badCast)), - "invalid cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) + "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) errorTest( "non-boolean filters", @@ -264,9 +264,9 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { val plan = Aggregate( Nil, - Alias(Sum(AttributeReference("a", StringType)(exprId = ExprId(1))), "b")() :: Nil, + Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil, LocalRelation( - AttributeReference("a", StringType)(exprId = ExprId(2)))) + AttributeReference("a", IntegerType)(exprId = ExprId(2)))) assert(plan.resolved) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala similarity index 84% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index dcb3635c5ccae..bc1537b0715b5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.expressions +package org.apache.spark.sql.catalyst.analysis import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types.StringType @@ -54,8 +54,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { test("check types for unary arithmetic") { assertError(UnaryMinus('stringField), "operator - accepts numeric type") - assertSuccess(Sqrt('stringField)) // We will cast String to Double for sqrt - assertError(Sqrt('booleanField), "function sqrt accepts numeric type") assertError(Abs('stringField), "function abs accepts numeric type") assertError(BitwiseNot('stringField), "operator ~ accepts integral type") } @@ -138,6 +136,28 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError( CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)), "WHEN expressions in CaseWhen should all be boolean type") + } + + test("check types for aggregates") { + // We will cast String to Double for sum and average + assertSuccess(Sum('stringField)) + assertSuccess(SumDistinct('stringField)) + assertSuccess(Average('stringField)) + + assertError(Min('complexField), "function min accepts non-complex type") + assertError(Max('complexField), "function max accepts non-complex type") + assertError(Sum('booleanField), "function sum accepts numeric type") + assertError(SumDistinct('booleanField), "function sumDistinct accepts numeric type") + assertError(Average('booleanField), "function average accepts numeric type") + } + test("check types for others") { + assertError(CreateArray(Seq('intField, 'booleanField)), + "input to function array should all be the same type") + assertError(Coalesce(Seq('intField, 'booleanField)), + "input to function coalesce should all be the same type") + assertError(Coalesce(Nil), "input to function coalesce cannot be empty") + assertError(Explode('intField), + "input to function explode should be array or map type") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 9977f7af00f6b..f7b8e21bed490 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -134,6 +134,19 @@ class HiveTypeCoercionSuite extends PlanTest { :: Nil)) } + test("type coercion for If") { + val rule = new HiveTypeCoercion { }.IfCoercion + ruleTest(rule, + If(Literal(true), Literal(1), Literal(1L)), + If(Literal(true), Cast(Literal(1), LongType), Literal(1L)) + ) + + ruleTest(rule, + If(Literal.create(null, NullType), Literal(1), Literal(1)), + If(Literal.create(null, BooleanType), Literal(1), Literal(1)) + ) + } + test("type coercion for CaseKeyWhen") { val cwc = new HiveTypeCoercion {}.CaseWhenCoercion ruleTest(cwc, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index e1afa81a7a82f..6c93698f8017b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -17,128 +17,145 @@ package org.apache.spark.sql.catalyst.expressions -import org.scalatest.Matchers._ - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{DoubleType, IntegerType} +import org.apache.spark.sql.types.Decimal class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { - test("arithmetic") { - val row = create_row(1, 2, 3, null) - val c1 = 'a.int.at(0) - val c2 = 'a.int.at(1) - val c3 = 'a.int.at(2) - val c4 = 'a.int.at(3) - - checkEvaluation(UnaryMinus(c1), -1, row) - checkEvaluation(UnaryMinus(Literal.create(100, IntegerType)), -100) - - checkEvaluation(Add(c1, c4), null, row) - checkEvaluation(Add(c1, c2), 3, row) - checkEvaluation(Add(c1, Literal.create(null, IntegerType)), null, row) - checkEvaluation(Add(Literal.create(null, IntegerType), c2), null, row) - checkEvaluation( - Add(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) - - checkEvaluation(-c1, -1, row) - checkEvaluation(c1 + c2, 3, row) - checkEvaluation(c1 - c2, -1, row) - checkEvaluation(c1 * c2, 2, row) - checkEvaluation(c1 / c2, 0, row) - checkEvaluation(c1 % c2, 1, row) + /** + * Runs through the testFunc for all numeric data types. + * + * @param testFunc a test function that accepts a conversion function to convert an integer + * into another data type. + */ + private def testNumericDataTypes(testFunc: (Int => Any) => Unit): Unit = { + testFunc(_.toByte) + testFunc(_.toShort) + testFunc(identity) + testFunc(_.toLong) + testFunc(_.toFloat) + testFunc(_.toDouble) + testFunc(Decimal(_)) } - test("fractional arithmetic") { - val row = create_row(1.1, 2.0, 3.1, null) - val c1 = 'a.double.at(0) - val c2 = 'a.double.at(1) - val c3 = 'a.double.at(2) - val c4 = 'a.double.at(3) - - checkEvaluation(UnaryMinus(c1), -1.1, row) - checkEvaluation(UnaryMinus(Literal.create(100.0, DoubleType)), -100.0) - checkEvaluation(Add(c1, c4), null, row) - checkEvaluation(Add(c1, c2), 3.1, row) - checkEvaluation(Add(c1, Literal.create(null, DoubleType)), null, row) - checkEvaluation(Add(Literal.create(null, DoubleType), c2), null, row) - checkEvaluation( - Add(Literal.create(null, DoubleType), Literal.create(null, DoubleType)), null, row) - - checkEvaluation(-c1, -1.1, row) - checkEvaluation(c1 + c2, 3.1, row) - checkDoubleEvaluation(c1 - c2, (-0.9 +- 0.001), row) - checkDoubleEvaluation(c1 * c2, (2.2 +- 0.001), row) - checkDoubleEvaluation(c1 / c2, (0.55 +- 0.001), row) - checkDoubleEvaluation(c3 % c2, (1.1 +- 0.001), row) + test("+ (Add)") { + testNumericDataTypes { convert => + val left = Literal(convert(1)) + val right = Literal(convert(2)) + checkEvaluation(Add(left, right), convert(3)) + checkEvaluation(Add(Literal.create(null, left.dataType), right), null) + checkEvaluation(Add(left, Literal.create(null, right.dataType)), null) + } } - test("Divide") { - checkEvaluation(Divide(Literal(2), Literal(1)), 2) - checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5) - checkEvaluation(Divide(Literal(1), Literal(2)), 0) - checkEvaluation(Divide(Literal(1), Literal(0)), null) - checkEvaluation(Divide(Literal(1.0), Literal(0.0)), null) - checkEvaluation(Divide(Literal(0.0), Literal(0.0)), null) - checkEvaluation(Divide(Literal(0), Literal.create(null, IntegerType)), null) - checkEvaluation(Divide(Literal(1), Literal.create(null, IntegerType)), null) - checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(0)), null) - checkEvaluation(Divide(Literal.create(null, DoubleType), Literal(0.0)), null) - checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(1)), null) - checkEvaluation(Divide(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), - null) + test("- (UnaryMinus)") { + testNumericDataTypes { convert => + val input = Literal(convert(1)) + val dataType = input.dataType + checkEvaluation(UnaryMinus(input), convert(-1)) + checkEvaluation(UnaryMinus(Literal.create(null, dataType)), null) + } } - test("Remainder") { - checkEvaluation(Remainder(Literal(2), Literal(1)), 0) - checkEvaluation(Remainder(Literal(1.0), Literal(2.0)), 1.0) - checkEvaluation(Remainder(Literal(1), Literal(2)), 1) - checkEvaluation(Remainder(Literal(1), Literal(0)), null) - checkEvaluation(Remainder(Literal(1.0), Literal(0.0)), null) - checkEvaluation(Remainder(Literal(0.0), Literal(0.0)), null) - checkEvaluation(Remainder(Literal(0), Literal.create(null, IntegerType)), null) - checkEvaluation(Remainder(Literal(1), Literal.create(null, IntegerType)), null) - checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(0)), null) - checkEvaluation(Remainder(Literal.create(null, DoubleType), Literal(0.0)), null) - checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(1)), null) - checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), - null) + test("- (Minus)") { + testNumericDataTypes { convert => + val left = Literal(convert(1)) + val right = Literal(convert(2)) + checkEvaluation(Subtract(left, right), convert(-1)) + checkEvaluation(Subtract(Literal.create(null, left.dataType), right), null) + checkEvaluation(Subtract(left, Literal.create(null, right.dataType)), null) + } } - test("MaxOf") { - checkEvaluation(MaxOf(1, 2), 2) - checkEvaluation(MaxOf(2, 1), 2) - checkEvaluation(MaxOf(1L, 2L), 2L) - checkEvaluation(MaxOf(2L, 1L), 2L) + test("* (Multiply)") { + testNumericDataTypes { convert => + val left = Literal(convert(1)) + val right = Literal(convert(2)) + checkEvaluation(Multiply(left, right), convert(2)) + checkEvaluation(Multiply(Literal.create(null, left.dataType), right), null) + checkEvaluation(Multiply(left, Literal.create(null, right.dataType)), null) + } + } - checkEvaluation(MaxOf(Literal.create(null, IntegerType), 2), 2) - checkEvaluation(MaxOf(2, Literal.create(null, IntegerType)), 2) + test("/ (Divide) basic") { + testNumericDataTypes { convert => + val left = Literal(convert(2)) + val right = Literal(convert(1)) + val dataType = left.dataType + checkEvaluation(Divide(left, right), convert(2)) + checkEvaluation(Divide(Literal.create(null, dataType), right), null) + checkEvaluation(Divide(left, Literal.create(null, right.dataType)), null) + checkEvaluation(Divide(left, Literal(convert(0))), null) // divide by zero + } } - test("MinOf") { - checkEvaluation(MinOf(1, 2), 1) - checkEvaluation(MinOf(2, 1), 1) - checkEvaluation(MinOf(1L, 2L), 1L) - checkEvaluation(MinOf(2L, 1L), 1L) + test("/ (Divide) for integral type") { + checkEvaluation(Divide(Literal(1.toByte), Literal(2.toByte)), 0.toByte) + checkEvaluation(Divide(Literal(1.toShort), Literal(2.toShort)), 0.toShort) + checkEvaluation(Divide(Literal(1), Literal(2)), 0) + checkEvaluation(Divide(Literal(1.toLong), Literal(2.toLong)), 0.toLong) + } - checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1) - checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1) + test("/ (Divide) for floating point") { + checkEvaluation(Divide(Literal(1.0f), Literal(2.0f)), 0.5f) + checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5) + checkEvaluation(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))), Decimal(0.5)) } - test("SQRT") { - val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24)) - val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble)) - val rowSequence = inputSequence.map(l => create_row(l.toDouble)) - val d = 'a.double.at(0) + test("% (Remainder)") { + testNumericDataTypes { convert => + val left = Literal(convert(1)) + val right = Literal(convert(2)) + checkEvaluation(Remainder(left, right), convert(1)) + checkEvaluation(Remainder(Literal.create(null, left.dataType), right), null) + checkEvaluation(Remainder(left, Literal.create(null, right.dataType)), null) + checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0 + } + } + + test("Abs") { + testNumericDataTypes { convert => + checkEvaluation(Abs(Literal(convert(0))), convert(0)) + checkEvaluation(Abs(Literal(convert(1))), convert(1)) + checkEvaluation(Abs(Literal(convert(-1))), convert(1)) + } + } - for ((row, expected) <- rowSequence zip expectedResults) { - checkEvaluation(Sqrt(d), expected, row) + test("MaxOf basic") { + testNumericDataTypes { convert => + val small = Literal(convert(1)) + val large = Literal(convert(2)) + checkEvaluation(MaxOf(small, large), convert(2)) + checkEvaluation(MaxOf(large, small), convert(2)) + checkEvaluation(MaxOf(Literal.create(null, small.dataType), large), convert(2)) + checkEvaluation(MaxOf(large, Literal.create(null, small.dataType)), convert(2)) } + } + + test("MaxOf for atomic type") { + checkEvaluation(MaxOf(true, false), true) + checkEvaluation(MaxOf("abc", "bcd"), "bcd") + checkEvaluation(MaxOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)), + Array(1.toByte, 3.toByte)) + } + + test("MinOf basic") { + testNumericDataTypes { convert => + val small = Literal(convert(1)) + val large = Literal(convert(2)) + checkEvaluation(MinOf(small, large), convert(1)) + checkEvaluation(MinOf(large, small), convert(1)) + checkEvaluation(MinOf(Literal.create(null, small.dataType), large), convert(2)) + checkEvaluation(MinOf(small, Literal.create(null, small.dataType)), convert(1)) + } + } - checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) - checkEvaluation(Sqrt(-1), null, EmptyRow) - checkEvaluation(Sqrt(-1.5), null, EmptyRow) + test("MinOf for atomic type") { + checkEvaluation(MinOf(true, false), false) + checkEvaluation(MinOf("abc", "bcd"), "abc") + checkEvaluation(MinOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)), + Array(1.toByte, 2.toByte)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 5bc7c30eee1b6..f3809be722a84 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Timestamp, Date} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ /** @@ -42,7 +43,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from int") { checkCast(0, false) checkCast(1, true) - checkCast(5, true) + checkCast(-5, true) checkCast(1, 1.toByte) checkCast(1, 1.toShort) checkCast(1, 1) @@ -60,7 +61,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from long") { checkCast(0L, false) checkCast(1L, true) - checkCast(5L, true) + checkCast(-5L, true) checkCast(1L, 1.toByte) checkCast(1L, 1.toShort) checkCast(1L, 1) @@ -98,10 +99,28 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("cast from float") { - + checkCast(0.0f, false) + checkCast(0.5f, true) + checkCast(-5.0f, true) + checkCast(1.5f, 1.toByte) + checkCast(1.5f, 1.toShort) + checkCast(1.5f, 1) + checkCast(1.5f, 1.toLong) + checkCast(1.5f, 1.5) + checkCast(1.5f, "1.5") } test("cast from double") { + checkCast(0.0, false) + checkCast(0.5, true) + checkCast(-5.0, true) + checkCast(1.5, 1.toByte) + checkCast(1.5, 1.toShort) + checkCast(1.5, 1) + checkCast(1.5, 1.toLong) + checkCast(1.5, 1.5f) + checkCast(1.5, "1.5") + checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) } @@ -137,7 +156,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(cast(sd, DateType), StringType), sd) checkEvaluation(cast(cast(d, StringType), DateType), 0) checkEvaluation(cast(cast(nts, TimestampType), StringType), nts) - checkEvaluation(cast(cast(ts, StringType), TimestampType), ts) + checkEvaluation(cast(cast(ts, StringType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts)) // all convert to string type to check checkEvaluation(cast(cast(cast(nts, TimestampType), DateType), StringType), sd) @@ -182,6 +201,19 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Add(Literal(23.toShort), cast(true, ShortType)), 24.toShort) } + test("from decimal") { + checkCast(Decimal(0.0), false) + checkCast(Decimal(0.5), true) + checkCast(Decimal(-5.0), true) + checkCast(Decimal(1.5), 1.toByte) + checkCast(Decimal(1.5), 1.toShort) + checkCast(Decimal(1.5), 1) + checkCast(Decimal(1.5), 1.toLong) + checkCast(Decimal(1.5), 1.5f) + checkCast(Decimal(1.5), 1.5) + checkCast(Decimal(1.5), "1.5") + } + test("casting to fixed-precision decimals") { // Overflow and rounding for casting to fixed-precision decimals: // - Values should round with HALF_UP mode by default when you lower scale @@ -269,9 +301,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(ts, LongType), 15.toLong) checkEvaluation(cast(ts, FloatType), 15.002f) checkEvaluation(cast(ts, DoubleType), 15.002) - checkEvaluation(cast(cast(tss, ShortType), TimestampType), ts) - checkEvaluation(cast(cast(tss, IntegerType), TimestampType), ts) - checkEvaluation(cast(cast(tss, LongType), TimestampType), ts) + checkEvaluation(cast(cast(tss, ShortType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts)) + checkEvaluation(cast(cast(tss, IntegerType), TimestampType), + DateTimeUtils.fromJavaTimestamp(ts)) + checkEvaluation(cast(cast(tss, LongType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts)) checkEvaluation( cast(cast(millis.toFloat / 1000, TimestampType), FloatType), millis.toFloat / 1000) @@ -283,7 +316,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { Decimal(1)) // A test for higher precision than millis - checkEvaluation(cast(cast(0.00000001, TimestampType), DoubleType), 0.00000001) + checkEvaluation(cast(cast(0.0000001, TimestampType), DoubleType), 0.0000001) checkEvaluation(cast(Double.NaN, TimestampType), null) checkEvaluation(cast(1.0 / 0.0, TimestampType), null) @@ -405,14 +438,14 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from struct") { val struct = Literal.create( - Row("123", "abc", "", null), + InternalRow("123", "abc", "", null), StructType(Seq( StructField("a", StringType, nullable = true), StructField("b", StringType, nullable = true), StructField("c", StringType, nullable = true), StructField("d", StringType, nullable = true)))) val struct_notNull = Literal.create( - Row("123", "abc", ""), + InternalRow("123", "abc", ""), StructType(Seq( StructField("a", StringType, nullable = false), StructField("b", StringType, nullable = false), @@ -425,7 +458,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("c", IntegerType, nullable = true), StructField("d", IntegerType, nullable = true)))) assert(ret.resolved === true) - checkEvaluation(ret, Row(123, null, null, null)) + checkEvaluation(ret, InternalRow(123, null, null, null)) } { val ret = cast(struct, StructType(Seq( @@ -442,7 +475,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("c", BooleanType, nullable = true), StructField("d", BooleanType, nullable = true)))) assert(ret.resolved === true) - checkEvaluation(ret, Row(true, true, false, null)) + checkEvaluation(ret, InternalRow(true, true, false, null)) } { val ret = cast(struct, StructType(Seq( @@ -459,7 +492,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("b", IntegerType, nullable = true), StructField("c", IntegerType, nullable = true)))) assert(ret.resolved === true) - checkEvaluation(ret, Row(123, null, null)) + checkEvaluation(ret, InternalRow(123, null, null)) } { val ret = cast(struct_notNull, StructType(Seq( @@ -474,7 +507,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("b", BooleanType, nullable = true), StructField("c", BooleanType, nullable = true)))) assert(ret.resolved === true) - checkEvaluation(ret, Row(true, true, false)) + checkEvaluation(ret, InternalRow(true, true, false)) } { val ret = cast(struct_notNull, StructType(Seq( @@ -482,7 +515,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("b", BooleanType, nullable = true), StructField("c", BooleanType, nullable = false)))) assert(ret.resolved === true) - checkEvaluation(ret, Row(true, true, false)) + checkEvaluation(ret, InternalRow(true, true, false)) } { @@ -500,10 +533,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("complex casting") { val complex = Literal.create( - Row( + InternalRow( Seq("123", "abc", ""), Map("a" -> "123", "b" -> "abc", "c" -> ""), - Row(0)), + InternalRow(0)), StructType(Seq( StructField("a", ArrayType(StringType, containsNull = false), nullable = true), @@ -523,10 +556,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("l", LongType, nullable = true))))))) assert(ret.resolved === true) - checkEvaluation(ret, Row( + checkEvaluation(ret, InternalRow( Seq(123, null, null), Map("a" -> true, "b" -> true, "c" -> false), - Row(0L))) + InternalRow(0L))) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index f151dd2a47f78..b80911e7257fc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -21,15 +21,36 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { + /** + * Runs through the testFunc for all integral data types. + * + * @param testFunc a test function that accepts a conversion function to convert an integer + * into another data type. + */ + private def testIntegralDataTypes(testFunc: (Int => Any) => Unit): Unit = { + testFunc(_.toByte) + testFunc(_.toShort) + testFunc(identity) + testFunc(_.toLong) + } + + test("GetArrayItem") { + testIntegralDataTypes { convert => + val array = Literal.create(Seq("a", "b"), ArrayType(StringType)) + checkEvaluation(GetArrayItem(array, Literal(convert(1))), "b") + } + } + test("CreateStruct") { - val row = Row(1, 2, 3) + val row = InternalRow(1, 2, 3) val c1 = 'a.int.at(0).as("a") val c3 = 'c.int.at(2).as("c") - checkEvaluation(CreateStruct(Seq(c1, c3)), Row(1, 3), row) + checkEvaluation(CreateStruct(Seq(c1, c3)), InternalRow(1, 3), row) } test("complex type") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index 152c4e4111244..372848ea9a596 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -19,11 +19,52 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{IntegerType, BooleanType} +import org.apache.spark.sql.types._ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + test("if") { + val testcases = Seq[(java.lang.Boolean, Integer, Integer, Integer)]( + (true, 1, 2, 1), + (false, 1, 2, 2), + (null, 1, 2, 2), + (true, null, 2, null), + (false, 1, null, null), + (null, null, 2, 2), + (null, 1, null, null) + ) + + // dataType must match T. + def testIf(convert: (Integer => Any), dataType: DataType): Unit = { + for ((predicate, trueValue, falseValue, expected) <- testcases) { + val trueValueConverted = if (trueValue == null) null else convert(trueValue) + val falseValueConverted = if (falseValue == null) null else convert(falseValue) + val expectedConverted = if (expected == null) null else convert(expected) + + checkEvaluation( + If(Literal.create(predicate, BooleanType), + Literal.create(trueValueConverted, dataType), + Literal.create(falseValueConverted, dataType)), + expectedConverted) + } + } + + testIf(_ == 1, BooleanType) + testIf(_.toShort, ShortType) + testIf(identity, IntegerType) + testIf(_.toLong, LongType) + + testIf(_.toFloat, FloatType) + testIf(_.toDouble, DoubleType) + testIf(Decimal(_), DecimalType.Unlimited) + + testIf(identity, DateType) + testIf(_.toLong, TimestampType) + + testIf(_.toString, StringType) + } + test("case when") { val row = create_row(null, false, true, "a", "b", "c") val c1 = 'a.boolean.at(0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 87a92b87962f8..7d95ef7f710af 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -23,6 +23,8 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} +import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} /** * A few helper functions for expression evaluation testing. Mixin this trait to use them. @@ -30,29 +32,43 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, Ge trait ExpressionEvalHelper { self: SparkFunSuite => - protected def create_row(values: Any*): Row = { - new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray) + protected def create_row(values: Any*): InternalRow = { + InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst)) } protected def checkEvaluation( - expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { - checkEvaluationWithoutCodegen(expression, expected, inputRow) - checkEvaluationWithGeneratedMutableProjection(expression, expected, inputRow) - checkEvaluationWithGeneratedProjection(expression, expected, inputRow) + expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) + checkEvaluationWithoutCodegen(expression, catalystValue, inputRow) + checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow) + checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow) + checkEvaluationWithOptimization(expression, catalystValue, inputRow) } - protected def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { + /** + * Check the equality between result of expression and expected value, it will handle + * Array[Byte]. + */ + protected def checkResult(result: Any, expected: Any): Boolean = { + (result, expected) match { + case (result: Array[Byte], expected: Array[Byte]) => + java.util.Arrays.equals(result, expected) + case _ => result == expected + } + } + + protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = { expression.eval(inputRow) } protected def checkEvaluationWithoutCodegen( expression: Expression, expected: Any, - inputRow: Row = EmptyRow): Unit = { + inputRow: InternalRow = EmptyRow): Unit = { val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } - if (actual != expected) { + if (!checkResult(actual, expected)) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect evaluation (codegen off): $expression, " + s"actual: $actual, " + @@ -63,7 +79,7 @@ trait ExpressionEvalHelper { protected def checkEvaluationWithGeneratedMutableProjection( expression: Expression, expected: Any, - inputRow: Row = EmptyRow): Unit = { + inputRow: InternalRow = EmptyRow): Unit = { val plan = try { GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() @@ -80,7 +96,7 @@ trait ExpressionEvalHelper { } val actual = plan(inputRow).apply(0) - if (actual != expected) { + if (!checkResult(actual, expected)) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") } @@ -89,7 +105,7 @@ trait ExpressionEvalHelper { protected def checkEvaluationWithGeneratedProjection( expression: Expression, expected: Any, - inputRow: Row = EmptyRow): Unit = { + inputRow: InternalRow = EmptyRow): Unit = { val ctx = GenerateProjection.newCodeGenContext() lazy val evaluated = expression.gen(ctx) @@ -106,7 +122,7 @@ trait ExpressionEvalHelper { } val actual = plan(inputRow) - val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected))) + val expectedRow = InternalRow(expected) if (actual.hashCode() != expectedRow.hashCode()) { fail( s""" @@ -122,10 +138,19 @@ trait ExpressionEvalHelper { } } + protected def checkEvaluationWithOptimization( + expression: Expression, + expected: Any, + inputRow: InternalRow = EmptyRow): Unit = { + val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) + val optimizedPlan = DefaultOptimizer.execute(plan) + checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow) + } + protected def checkDoubleEvaluation( expression: Expression, expected: Spread[Double], - inputRow: Row = EmptyRow): Unit = { + inputRow: InternalRow = EmptyRow): Unit = { val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index f44f55dfb92d1..d924ff7a102f6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -18,12 +18,26 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types._ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { - // TODO: Add tests for all data types. + test("null") { + checkEvaluation(Literal.create(null, BooleanType), null) + checkEvaluation(Literal.create(null, ByteType), null) + checkEvaluation(Literal.create(null, ShortType), null) + checkEvaluation(Literal.create(null, IntegerType), null) + checkEvaluation(Literal.create(null, LongType), null) + checkEvaluation(Literal.create(null, FloatType), null) + checkEvaluation(Literal.create(null, LongType), null) + checkEvaluation(Literal.create(null, StringType), null) + checkEvaluation(Literal.create(null, BinaryType), null) + checkEvaluation(Literal.create(null, DecimalType()), null) + checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null) + checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null) + checkEvaluation(Literal.create(null, StructType(Seq.empty)), null) + } test("boolean literals") { checkEvaluation(Literal(true), true) @@ -31,25 +45,52 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { } test("int literals") { - checkEvaluation(Literal(1), 1) - checkEvaluation(Literal(0L), 0L) + List(0, 1, Int.MinValue, Int.MaxValue).foreach { d => + checkEvaluation(Literal(d), d) + checkEvaluation(Literal(d.toLong), d.toLong) + checkEvaluation(Literal(d.toShort), d.toShort) + checkEvaluation(Literal(d.toByte), d.toByte) + } + checkEvaluation(Literal(Long.MinValue), Long.MinValue) + checkEvaluation(Literal(Long.MaxValue), Long.MaxValue) } test("double literals") { - List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { - d => { - checkEvaluation(Literal(d), d) - checkEvaluation(Literal(d.toFloat), d.toFloat) - } + List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { d => + checkEvaluation(Literal(d), d) + checkEvaluation(Literal(d.toFloat), d.toFloat) } + checkEvaluation(Literal(Double.MinValue), Double.MinValue) + checkEvaluation(Literal(Double.MaxValue), Double.MaxValue) + checkEvaluation(Literal(Float.MinValue), Float.MinValue) + checkEvaluation(Literal(Float.MaxValue), Float.MaxValue) + } test("string literals") { + checkEvaluation(Literal(""), "") checkEvaluation(Literal("test"), "test") - checkEvaluation(Literal.create(null, StringType), null) + checkEvaluation(Literal("\0"), "\0") } test("sum two literals") { checkEvaluation(Add(Literal(1), Literal(1)), 2) } + + test("binary literals") { + checkEvaluation(Literal.create(new Array[Byte](0), BinaryType), new Array[Byte](0)) + checkEvaluation(Literal.create(new Array[Byte](2), BinaryType), new Array[Byte](2)) + } + + test("decimal") { + List(0.0, 1.2, 1.1111, 5).foreach { d => + checkEvaluation(Literal(Decimal(d)), Decimal(d)) + checkEvaluation(Literal(Decimal(d.toInt)), Decimal(d.toInt)) + checkEvaluation(Literal(Decimal(d.toLong)), Decimal(d.toLong)) + checkEvaluation(Literal(Decimal((d * 1000L).toLong, 10, 1)), + Decimal((d * 1000L).toLong, 10, 1)) + } + } + + // TODO(davies): add tests for ArrayType, MapType and StructType } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 25ebc70d095d8..b932d4ab850c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -18,24 +18,41 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.DoubleType +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.{DataType, DoubleType, LongType} class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + /** + * Used for testing leaf math expressions. + * + * @param e expression + * @param c The constants in scala.math + * @tparam T Generic type for primitives + */ + private def testLeaf[T]( + e: () => Expression, + c: T): Unit = { + checkEvaluation(e(), c, EmptyRow) + checkEvaluation(e(), c, create_row(null)) + } + /** * Used for testing unary math expressions. * * @param c expression - * @param f The functions in scala.math + * @param f The functions in scala.math or elsewhere used to generate expected results * @param domain The set of values to run the function with * @param expectNull Whether the given values should return null or not * @tparam T Generic type for primitives + * @tparam U Generic type for the output of the given function `f` */ - private def testUnary[T]( + private def testUnary[T, U]( c: Expression => Expression, - f: T => T, + f: T => U, domain: Iterable[T] = (-20 to 20).map(_ * 0.1), - expectNull: Boolean = false): Unit = { + expectNull: Boolean = false, + evalType: DataType = DoubleType): Unit = { if (expectNull) { domain.foreach { value => checkEvaluation(c(Literal(value)), null, EmptyRow) @@ -45,7 +62,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(c(Literal(value)), f(value), EmptyRow) } } - checkEvaluation(c(Literal.create(null, DoubleType)), null, create_row(null)) + checkEvaluation(c(Literal.create(null, evalType)), null, create_row(null)) } /** @@ -74,6 +91,14 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null)) } + test("e") { + testLeaf(EulerNumber, math.E) + } + + test("pi") { + testLeaf(Pi, math.Pi) + } + test("sin") { testUnary(Sin, math.sin) } @@ -145,7 +170,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("signum") { - testUnary[Double](Signum, math.signum) + testUnary[Double, Double](Signum, math.signum) } test("log") { @@ -163,11 +188,56 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { testUnary(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNull = true) } + test("bin") { + testUnary(Bin, java.lang.Long.toBinaryString, (-20 to 20).map(_.toLong), evalType = LongType) + + val row = create_row(null, 12L, 123L, 1234L, -123L) + val l1 = 'a.long.at(0) + val l2 = 'a.long.at(1) + val l3 = 'a.long.at(2) + val l4 = 'a.long.at(3) + val l5 = 'a.long.at(4) + + checkEvaluation(Bin(l1), null, row) + checkEvaluation(Bin(l2), java.lang.Long.toBinaryString(12), row) + checkEvaluation(Bin(l3), java.lang.Long.toBinaryString(123), row) + checkEvaluation(Bin(l4), java.lang.Long.toBinaryString(1234), row) + checkEvaluation(Bin(l5), java.lang.Long.toBinaryString(-123), row) + } + + test("log2") { + def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2) + testUnary(Log2, f, (0 to 20).map(_ * 0.1)) + testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNull = true) + } + + test("sqrt") { + testUnary(Sqrt, math.sqrt, (0 to 20).map(_ * 0.1)) + testUnary(Sqrt, math.sqrt, (-5 to -1).map(_ * 1.0), expectNull = true) + + checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) + checkEvaluation(Sqrt(Literal(-1.0)), null, EmptyRow) + checkEvaluation(Sqrt(Literal(-1.5)), null, EmptyRow) + } + test("pow") { testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true) } + test("hex") { + checkEvaluation(Hex(Literal(28)), "1C") + checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4") + checkEvaluation(Hex(Literal(100800200404L)), "177828FED4") + checkEvaluation(Hex(Literal(-100800200404L)), "FFFFFFE887D7012C") + checkEvaluation(Hex(Literal("helloHex")), "68656C6C6F486578") + checkEvaluation(Hex(Literal("helloHex".getBytes())), "68656C6C6F486578") + // scalastyle:off + // Turn off scala style for non-ascii chars + checkEvaluation(Hex(Literal("三重的")), "E4B889E9878DE79A84") + // scalastyle:on + } + test("hypot") { testBinary(Hypot, math.hypot) } @@ -176,4 +246,22 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { testBinary(Atan2, math.atan2) } + test("binary log") { + val f = (c1: Double, c2: Double) => math.log(c2) / math.log(c1) + val domain = (1 to 20).map(v => (v * 0.1, v * 0.2)) + + domain.foreach { case (v1, v2) => + checkEvaluation(Logarithm(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow) + checkEvaluation(Logarithm(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow) + checkEvaluation(new Logarithm(Literal(v1)), f(math.E, v1 + 0.0), EmptyRow) + } + checkEvaluation( + Logarithm(Literal.create(null, DoubleType), Literal(1.0)), + null, + create_row(null)) + checkEvaluation( + Logarithm(Literal(1.0), Literal.create(null, DoubleType)), + null, + create_row(null)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala new file mode 100644 index 0000000000000..36e636b5da6b8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.commons.codec.digest.DigestUtils + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType} + +class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("md5") { + checkEvaluation(Md5(Literal("ABC".getBytes)), "902fbdd2b1df0c4f70b4a5d23525e932") + checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), + "6ac1e56bc78f031059be7be854522c4c") + checkEvaluation(Md5(Literal.create(null, BinaryType)), null) + } + + test("sha1") { + checkEvaluation(Sha1(Literal("ABC".getBytes)), "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8") + checkEvaluation(Sha1(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), + "5d211bad8f4ee70e16c7d343a838fc344a1ed961") + checkEvaluation(Sha1(Literal.create(null, BinaryType)), null) + checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709") + } + + test("sha2") { + checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC")) + checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)), + DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6))) + // unsupported bit length + checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null) + checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(512)), null) + checkEvaluation(Sha2(Literal("ABC".getBytes), Literal.create(null, IntegerType)), null) + checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index b6261bfba0786..72fec3b86e5e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -23,7 +23,7 @@ import scala.collection.immutable.HashSet import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types.{IntegerType, BooleanType} @@ -167,8 +167,8 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal(true) <=> Literal.create(null, BooleanType), false, row) checkEvaluation(Literal.create(null, BooleanType) <=> Literal(true), false, row) - val d1 = DateUtils.fromJavaDate(Date.valueOf("1970-01-01")) - val d2 = DateUtils.fromJavaDate(Date.valueOf("1970-01-02")) + val d1 = DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-01")) + val d2 = DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-02")) checkEvaluation(Literal(d1) < Literal(d2), true) val ts1 = new Timestamp(12) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala new file mode 100644 index 0000000000000..9be2b23a53f27 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.scalatest.Matchers._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.{DoubleType, IntegerType} + + +class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("random") { + val row = create_row(1.1, 2.0, 3.1, null) + checkDoubleEvaluation(Rand(30), (0.7363714192755834 +- 0.001), row) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala index 2e81296c4e623..5dbb1d562c1d9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -215,4 +215,13 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { evaluate("abbbbc" rlike regEx, create_row("**")) } } + + test("length for string") { + val regEx = 'a.string.at(0) + checkEvaluation(StringLength(Literal("abc")), 3, create_row("abdef")) + checkEvaluation(StringLength(regEx), 5, create_row("abdef")) + checkEvaluation(StringLength(regEx), 0, create_row("")) + checkEvaluation(StringLength(regEx), null, create_row(null)) + checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala index 88a36aa121b55..6fafc2f86684c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -20,22 +20,25 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.JavaConverters._ import scala.util.Random -import org.apache.spark.SparkFunSuite -import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator} import org.scalatest.{BeforeAndAfterEach, Matchers} +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} +import org.apache.spark.unsafe.types.UTF8String + class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach { - import UnsafeFixedWidthAggregationMap._ - private val groupKeySchema = StructType(StructField("product", StringType) :: Nil) private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil) - private def emptyAggregationBuffer: Row = new GenericRow(Array[Any](0)) + private def emptyProjection: Projection = + GenerateProjection.generate(Seq(Literal(0)), Seq(AttributeReference("price", IntegerType)())) + private def emptyAggregationBuffer: InternalRow = InternalRow(0) private var memoryManager: TaskMemoryManager = null @@ -50,21 +53,11 @@ class UnsafeFixedWidthAggregationMapSuite } } - test("supported schemas") { - assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) - assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil))) - - assert( - !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) - assert( - !supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) - } - test("empty map") { val map = new UnsafeFixedWidthAggregationMap( - emptyAggregationBuffer, - aggBufferSchema, - groupKeySchema, + emptyProjection, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggBufferSchema), memoryManager, 1024, // initial capacity false // disable perf metrics @@ -75,14 +68,14 @@ class UnsafeFixedWidthAggregationMapSuite test("updating values for a single key") { val map = new UnsafeFixedWidthAggregationMap( - emptyAggregationBuffer, - aggBufferSchema, - groupKeySchema, + emptyProjection, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggBufferSchema), memoryManager, 1024, // initial capacity false // disable perf metrics ) - val groupKey = new GenericRow(Array[Any](UTF8String("cats"))) + val groupKey = InternalRow(UTF8String.fromString("cats")) // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts) map.getAggregationBuffer(groupKey) @@ -101,9 +94,9 @@ class UnsafeFixedWidthAggregationMapSuite test("inserting large random keys") { val map = new UnsafeFixedWidthAggregationMap( - emptyAggregationBuffer, - aggBufferSchema, - groupKeySchema, + emptyProjection, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggBufferSchema), memoryManager, 128, // initial capacity false // disable perf metrics @@ -111,13 +104,43 @@ class UnsafeFixedWidthAggregationMapSuite val rand = new Random(42) val groupKeys: Set[String] = Seq.fill(512)(rand.nextString(1024)).toSet groupKeys.foreach { keyString => - map.getAggregationBuffer(new GenericRow(Array[Any](UTF8String(keyString)))) + map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString))) } val seenKeys: Set[String] = map.iterator().asScala.map { entry => entry.key.getString(0) }.toSet seenKeys.size should be (groupKeys.size) seenKeys should be (groupKeys) + + map.free() + } + + test("with decimal in the key and values") { + val groupKeySchema = StructType(StructField("price", DecimalType(10, 0)) :: Nil) + val aggBufferSchema = StructType(StructField("amount", DecimalType.Unlimited) :: Nil) + val emptyProjection = GenerateProjection.generate(Seq(Literal(Decimal(0))), + Seq(AttributeReference("price", DecimalType.Unlimited)())) + val map = new UnsafeFixedWidthAggregationMap( + emptyProjection, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggBufferSchema), + memoryManager, + 1, // initial capacity + false // disable perf metrics + ) + + (0 until 100).foreach { i => + val groupKey = InternalRow(Decimal(i % 10)) + val row = map.getAggregationBuffer(groupKey) + row.update(0, Decimal(i)) + } + val seenKeys: Set[Int] = map.iterator().asScala.map { entry => + entry.key.getAs[Decimal](0).toInt + }.toSet + seenKeys.size should be (10) + seenKeys should be ((0 until 10).toSet) + + map.free() } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 61722f1ffa462..94c2f3242b122 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -17,14 +17,17 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.{Date, Timestamp} import java.util.Arrays import org.scalatest.Matchers import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.{ObjectPool, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.unsafe.types.UTF8String class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { @@ -38,40 +41,124 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { row.setInt(2, 2) val sizeRequired: Int = converter.getSizeRequirement(row) - sizeRequired should be (8 + (3 * 8)) + assert(sizeRequired === 8 + (3 * 8)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) - numBytesWritten should be (sizeRequired) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) - unsafeRow.getLong(0) should be (0) - unsafeRow.getLong(1) should be (1) - unsafeRow.getInt(2) should be (2) + assert(unsafeRow.getLong(0) === 0) + assert(unsafeRow.getLong(1) === 1) + assert(unsafeRow.getInt(2) === 2) + + unsafeRow.setLong(1, 3) + assert(unsafeRow.getLong(1) === 3) + unsafeRow.setInt(2, 4) + assert(unsafeRow.getInt(2) === 4) + } + + test("basic conversion with primitive, string and binary types") { + val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType) + val converter = new UnsafeRowConverter(fieldTypes) + + val row = new SpecificMutableRow(fieldTypes) + row.setLong(0, 0) + row.update(1, UTF8String.fromString("Hello")) + row.update(2, "World".getBytes) + + val sizeRequired: Int = converter.getSizeRequirement(row) + assert(sizeRequired === 8 + (8 * 3) + + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) + + ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length)) + val buffer: Array[Long] = new Array[Long](sizeRequired / 8) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + assert(numBytesWritten === sizeRequired) + + val unsafeRow = new UnsafeRow() + val pool = new ObjectPool(10) + unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) + assert(unsafeRow.getLong(0) === 0) + assert(unsafeRow.getString(1) === "Hello") + assert(unsafeRow.get(2) === "World".getBytes) + + unsafeRow.update(1, UTF8String.fromString("World")) + assert(unsafeRow.getString(1) === "World") + assert(pool.size === 0) + unsafeRow.update(1, UTF8String.fromString("Hello World")) + assert(unsafeRow.getString(1) === "Hello World") + assert(pool.size === 1) + + unsafeRow.update(2, "World".getBytes) + assert(unsafeRow.get(2) === "World".getBytes) + assert(pool.size === 1) + unsafeRow.update(2, "Hello World".getBytes) + assert(unsafeRow.get(2) === "Hello World".getBytes) + assert(pool.size === 2) + } + + test("basic conversion with primitive, decimal and array") { + val fieldTypes: Array[DataType] = Array(LongType, DecimalType(10, 0), ArrayType(StringType)) + val converter = new UnsafeRowConverter(fieldTypes) + + val row = new SpecificMutableRow(fieldTypes) + row.setLong(0, 0) + row.update(1, Decimal(1)) + row.update(2, Array(2)) + + val pool = new ObjectPool(10) + val sizeRequired: Int = converter.getSizeRequirement(row) + assert(sizeRequired === 8 + (8 * 3)) + val buffer: Array[Long] = new Array[Long](sizeRequired / 8) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, pool) + assert(numBytesWritten === sizeRequired) + assert(pool.size === 2) + + val unsafeRow = new UnsafeRow() + unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) + assert(unsafeRow.getLong(0) === 0) + assert(unsafeRow.get(1) === Decimal(1)) + assert(unsafeRow.get(2) === Array(2)) + + unsafeRow.update(1, Decimal(2)) + assert(unsafeRow.get(1) === Decimal(2)) + unsafeRow.update(2, Array(3, 4)) + assert(unsafeRow.get(2) === Array(3, 4)) + assert(pool.size === 2) } - test("basic conversion with primitive and string types") { - val fieldTypes: Array[DataType] = Array(LongType, StringType, StringType) + test("basic conversion with primitive, string, date and timestamp types") { + val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType) val converter = new UnsafeRowConverter(fieldTypes) val row = new SpecificMutableRow(fieldTypes) row.setLong(0, 0) row.setString(1, "Hello") - row.setString(2, "World") + row.update(2, DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-01"))) + row.update(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25"))) val sizeRequired: Int = converter.getSizeRequirement(row) - sizeRequired should be (8 + (8 * 3) + - ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8) + - ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length + 8)) + assert(sizeRequired === 8 + (8 * 4) + + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) - numBytesWritten should be (sizeRequired) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) - unsafeRow.getLong(0) should be (0) - unsafeRow.getString(1) should be ("Hello") - unsafeRow.getString(2) should be ("World") + assert(unsafeRow.getLong(0) === 0) + assert(unsafeRow.getString(1) === "Hello") + // Date is represented as Int in unsafeRow + assert(DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) === Date.valueOf("1970-01-01")) + // Timestamp is represented as Long in unsafeRow + DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be + (Timestamp.valueOf("2015-05-08 08:10:25")) + + unsafeRow.setInt(2, DateTimeUtils.fromJavaDate(Date.valueOf("2015-06-22"))) + assert(DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) === Date.valueOf("2015-06-22")) + unsafeRow.setLong(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-06-22 08:10:25"))) + DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be + (Timestamp.valueOf("2015-06-22 08:10:25")) } test("null handling") { @@ -83,10 +170,15 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { IntegerType, LongType, FloatType, - DoubleType) + DoubleType, + StringType, + BinaryType, + DecimalType.Unlimited, + ArrayType(IntegerType) + ) val converter = new UnsafeRowConverter(fieldTypes) - val rowWithAllNullColumns: Row = { + val rowWithAllNullColumns: InternalRow = { val r = new SpecificMutableRow(fieldTypes) for (i <- 0 to fieldTypes.length - 1) { r.setNullAt(i) @@ -97,8 +189,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns) val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8) val numBytesWritten = converter.writeRow( - rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET) - numBytesWritten should be (sizeRequired) + rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + assert(numBytesWritten === sizeRequired) val createdFromNull = new UnsafeRow() createdFromNull.pointTo( @@ -106,18 +198,22 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { for (i <- 0 to fieldTypes.length - 1) { assert(createdFromNull.isNullAt(i)) } - createdFromNull.getBoolean(1) should be (false) - createdFromNull.getByte(2) should be (0) - createdFromNull.getShort(3) should be (0) - createdFromNull.getInt(4) should be (0) - createdFromNull.getLong(5) should be (0) + assert(createdFromNull.getBoolean(1) === false) + assert(createdFromNull.getByte(2) === 0) + assert(createdFromNull.getShort(3) === 0) + assert(createdFromNull.getInt(4) === 0) + assert(createdFromNull.getLong(5) === 0) assert(java.lang.Float.isNaN(createdFromNull.getFloat(6))) - assert(java.lang.Double.isNaN(createdFromNull.getFloat(7))) + assert(java.lang.Double.isNaN(createdFromNull.getDouble(7))) + assert(createdFromNull.getString(8) === null) + assert(createdFromNull.get(9) === null) + assert(createdFromNull.get(10) === null) + assert(createdFromNull.get(11) === null) // If we have an UnsafeRow with columns that are initially non-null and we null out those // columns, then the serialized row representation should be identical to what we would get by // creating an entirely null row via the converter - val rowWithNoNullColumns: Row = { + val rowWithNoNullColumns: InternalRow = { val r = new SpecificMutableRow(fieldTypes) r.setNullAt(0) r.setBoolean(1, false) @@ -127,28 +223,68 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { r.setLong(5, 500) r.setFloat(6, 600) r.setDouble(7, 700) + r.update(8, UTF8String.fromString("hello")) + r.update(9, "world".getBytes) + r.update(10, Decimal(10)) + r.update(11, Array(11)) r } - val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8) + val pool = new ObjectPool(1) + val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8 + 2) converter.writeRow( - rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET) + rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, pool) val setToNullAfterCreation = new UnsafeRow() setToNullAfterCreation.pointTo( - setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) - setToNullAfterCreation.isNullAt(0) should be (rowWithNoNullColumns.isNullAt(0)) - setToNullAfterCreation.getBoolean(1) should be (rowWithNoNullColumns.getBoolean(1)) - setToNullAfterCreation.getByte(2) should be (rowWithNoNullColumns.getByte(2)) - setToNullAfterCreation.getShort(3) should be (rowWithNoNullColumns.getShort(3)) - setToNullAfterCreation.getInt(4) should be (rowWithNoNullColumns.getInt(4)) - setToNullAfterCreation.getLong(5) should be (rowWithNoNullColumns.getLong(5)) - setToNullAfterCreation.getFloat(6) should be (rowWithNoNullColumns.getFloat(6)) - setToNullAfterCreation.getDouble(7) should be (rowWithNoNullColumns.getDouble(7)) + assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) + assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1)) + assert(setToNullAfterCreation.getByte(2) === rowWithNoNullColumns.getByte(2)) + assert(setToNullAfterCreation.getShort(3) === rowWithNoNullColumns.getShort(3)) + assert(setToNullAfterCreation.getInt(4) === rowWithNoNullColumns.getInt(4)) + assert(setToNullAfterCreation.getLong(5) === rowWithNoNullColumns.getLong(5)) + assert(setToNullAfterCreation.getFloat(6) === rowWithNoNullColumns.getFloat(6)) + assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) + assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) + assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) + assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) + assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) for (i <- 0 to fieldTypes.length - 1) { + if (i >= 8) { + setToNullAfterCreation.update(i, null) + } setToNullAfterCreation.setNullAt(i) } - assert(Arrays.equals(createdFromNullBuffer, setToNullAfterCreationBuffer)) + // There are some garbage left in the var-length area + assert(Arrays.equals(createdFromNullBuffer, + java.util.Arrays.copyOf(setToNullAfterCreationBuffer, sizeRequired / 8))) + + setToNullAfterCreation.setNullAt(0) + setToNullAfterCreation.setBoolean(1, false) + setToNullAfterCreation.setByte(2, 20) + setToNullAfterCreation.setShort(3, 30) + setToNullAfterCreation.setInt(4, 400) + setToNullAfterCreation.setLong(5, 500) + setToNullAfterCreation.setFloat(6, 600) + setToNullAfterCreation.setDouble(7, 700) + setToNullAfterCreation.update(8, UTF8String.fromString("hello")) + setToNullAfterCreation.update(9, "world".getBytes) + setToNullAfterCreation.update(10, Decimal(10)) + setToNullAfterCreation.update(11, Array(11)) + + assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) + assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1)) + assert(setToNullAfterCreation.getByte(2) === rowWithNoNullColumns.getByte(2)) + assert(setToNullAfterCreation.getShort(3) === rowWithNoNullColumns.getShort(3)) + assert(setToNullAfterCreation.getInt(4) === rowWithNoNullColumns.getInt(4)) + assert(setToNullAfterCreation.getLong(5) === rowWithNoNullColumns.getLong(5)) + assert(setToNullAfterCreation.getFloat(6) === rowWithNoNullColumns.getFloat(6)) + assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) + assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) + assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) + assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) + assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala index 6841bd9890c97..54e8c6462e962 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -37,13 +37,11 @@ class ConvertToLocalRelationSuite extends PlanTest { test("Project on LocalRelation should be turned into a single LocalRelation") { val testRelation = LocalRelation( LocalRelation('a.int, 'b.int).output, - Row(1, 2) :: - Row(4, 5) :: Nil) + InternalRow(1, 2) :: InternalRow(4, 5) :: Nil) val correctAnswer = LocalRelation( LocalRelation('a1.int, 'b1.int).output, - Row(1, 3) :: - Row(4, 6) :: Nil) + InternalRow(1, 3) :: InternalRow(4, 6) :: Nil) val projectOnLocal = testRelation.select( UnresolvedAttribute("a").as("a1"), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 17dc9124749e8..ffdc673cdc455 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -36,6 +36,7 @@ class FilterPushdownSuite extends PlanTest { Batch("Filter Pushdown", Once, CombineFilters, PushPredicateThroughProject, + BooleanSimplification, PushPredicateThroughJoin, PushPredicateThroughGenerate, ColumnPruning, @@ -156,11 +157,9 @@ class FilterPushdownSuite extends PlanTest { .where('a === 1 && 'a === 2) .select('a).analyze - comparePlans(optimized, correctAnswer) } - test("joins: push to either side") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) @@ -198,6 +197,25 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("joins: push to one side after transformCondition") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = { + x.join(y) + .where(("x.a".attr === 1 && "y.d".attr === "x.b".attr) || + ("x.a".attr === 1 && "y.d".attr === "x.c".attr)) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where('a === 1) + val right = testRelation1 + val correctAnswer = + left.join(right, condition = Some("d".attr === "b".attr || "d".attr === "c".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + test("joins: rewrite filter to push to either side") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) @@ -563,17 +581,16 @@ class FilterPushdownSuite extends PlanTest { // push down invalid val originalQuery1 = { x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('b) + .sortBy(SortOrder('a, Ascending)) + .select('b) } val optimized1 = Optimize.execute(originalQuery1.analyze) val correctAnswer1 = x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('b).analyze + .sortBy(SortOrder('a, Ascending)) + .select('b).analyze comparePlans(optimized1, analysis.EliminateSubQueries(correctAnswer1)) - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala index 35f50be46b76f..ec379489a6d1e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala @@ -24,13 +24,13 @@ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ -class UnionPushdownSuite extends PlanTest { +class UnionPushDownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", Once, EliminateSubQueries) :: Batch("Union Pushdown", Once, - UnionPushdown) :: Nil + UnionPushDown) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 67db3d5e6d751..bda217935cb05 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -28,7 +28,12 @@ case class Dummy(optKey: Option[Expression]) extends Expression { override def nullable: Boolean = true override def dataType: NullType = NullType override lazy val resolved = true - override def eval(input: Row): Any = null.asInstanceOf[Any] + override def eval(input: InternalRow): Any = null.asInstanceOf[Any] +} + +case class ComplexPlan(exprs: Seq[Seq[Expression]]) + extends org.apache.spark.sql.catalyst.plans.logical.LeafNode { + override def output: Seq[Attribute] = Nil } class TreeNodeSuite extends SparkFunSuite { @@ -220,4 +225,13 @@ class TreeNodeSuite extends SparkFunSuite { assert(expected === actual) } } + + test("transformExpressions on nested expression sequence") { + val plan = ComplexPlan(Seq(Seq(Literal(1)), Seq(Literal(2)))) + val actual = plan.transformExpressions { + case Literal(value, _) => Literal(value.toString) + } + val expected = ComplexPlan(Seq(Seq(Literal("1")), Seq(Literal("2")))) + assert(expected === actual) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala new file mode 100644 index 0000000000000..03eb64f097a37 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import java.sql.Timestamp + +import org.apache.spark.SparkFunSuite + +class DateTimeUtilsSuite extends SparkFunSuite { + + test("timestamp and 100ns") { + val now = new Timestamp(System.currentTimeMillis()) + now.setNanos(100) + val ns = DateTimeUtils.fromJavaTimestamp(now) + assert(ns % 10000000L === 1) + assert(DateTimeUtils.toJavaTimestamp(ns) === now) + + List(-111111111111L, -1L, 0, 1L, 111111111111L).foreach { t => + val ts = DateTimeUtils.toJavaTimestamp(t) + assert(DateTimeUtils.fromJavaTimestamp(ts) === t) + assert(DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromJavaTimestamp(ts)) === ts) + } + } + + test("100ns and julian day") { + val (d, ns) = DateTimeUtils.toJulianDay(0) + assert(d === DateTimeUtils.JULIAN_DAY_OF_EPOCH) + assert(ns === DateTimeUtils.SECONDS_PER_DAY / 2 * DateTimeUtils.NANOS_PER_SECOND) + assert(DateTimeUtils.fromJulianDay(d, ns) == 0L) + + val t = new Timestamp(61394778610000L) // (2015, 6, 11, 10, 10, 10, 100) + val (d1, ns1) = DateTimeUtils.toJulianDay(DateTimeUtils.fromJavaTimestamp(t)) + val t2 = DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromJulianDay(d1, ns1)) + assert(t.equals(t2)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala new file mode 100644 index 0000000000000..94764df4b9cdb --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.scalatest.Matchers + +import org.apache.spark.SparkFunSuite + +class ObjectPoolSuite extends SparkFunSuite with Matchers { + + test("pool") { + val pool = new ObjectPool(1) + assert(pool.put(1) === 0) + assert(pool.put("hello") === 1) + assert(pool.put(false) === 2) + + assert(pool.get(0) === 1) + assert(pool.get(1) === "hello") + assert(pool.get(2) === false) + assert(pool.size() === 3) + + pool.replace(1, "world") + assert(pool.get(1) === "world") + assert(pool.size() === 3) + } + + test("unique pool") { + val pool = new UniqueObjectPool(1) + assert(pool.put(1) === 0) + assert(pool.put("hello") === 1) + assert(pool.put(1) === 0) + assert(pool.put("hello") === 1) + + assert(pool.get(0) === 1) + assert(pool.get(1) === "hello") + assert(pool.size() === 2) + + intercept[UnsupportedOperationException] { + pool.replace(1, "world") + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 261c4fcad24aa..14e7b4a9561b6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -33,6 +33,37 @@ class DataTypeSuite extends SparkFunSuite { assert(MapType(StringType, IntegerType, true) === map) } + test("construct with add") { + val struct = (new StructType) + .add("a", IntegerType, true) + .add("b", LongType, false) + .add("c", StringType, true) + + assert(StructField("b", LongType, false) === struct("b")) + } + + test("construct with add from StructField") { + // Test creation from StructField type + val struct = (new StructType) + .add(StructField("a", IntegerType, true)) + .add(StructField("b", LongType, false)) + .add(StructField("c", StringType, true)) + + assert(StructField("b", LongType, false) === struct("b")) + } + + test("construct with String DataType") { + // Test creation with DataType as String + val struct = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + + assert(StructField("a", IntegerType, true) === struct("a")) + assert(StructField("b", LongType, false) === struct("b")) + assert(StructField("c", StringType, true) === struct("c")) + } + test("extract fields from a StructType") { val struct = StructType( StructField("a", IntegerType, true) :: @@ -190,7 +221,7 @@ class DataTypeSuite extends SparkFunSuite { checkDefaultSize(DecimalType(10, 5), 4096) checkDefaultSize(DecimalType.Unlimited, 4096) checkDefaultSize(DateType, 4) - checkDefaultSize(TimestampType, 12) + checkDefaultSize(TimestampType, 8) checkDefaultSize(StringType, 4096) checkDefaultSize(BinaryType, 4096) checkDefaultSize(ArrayType(DoubleType, true), 800) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala deleted file mode 100644 index 81d7ab010f394..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.types - -import org.apache.spark.SparkFunSuite - -// scalastyle:off -class UTF8StringSuite extends SparkFunSuite { - test("basic") { - def check(str: String, len: Int) { - - assert(UTF8String(str).length == len) - assert(UTF8String(str.getBytes("utf8")).length() == len) - - assert(UTF8String(str) == str) - assert(UTF8String(str.getBytes("utf8")) == str) - assert(UTF8String(str).toString == str) - assert(UTF8String(str.getBytes("utf8")).toString == str) - assert(UTF8String(str.getBytes("utf8")) == UTF8String(str)) - - assert(UTF8String(str).hashCode() == UTF8String(str.getBytes("utf8")).hashCode()) - } - - check("hello", 5) - check("世 界", 3) - } - - test("contains") { - assert(UTF8String("hello").contains(UTF8String("ello"))) - assert(!UTF8String("hello").contains(UTF8String("vello"))) - assert(UTF8String("大千世界").contains(UTF8String("千世"))) - assert(!UTF8String("大千世界").contains(UTF8String("世千"))) - } - - test("prefix") { - assert(UTF8String("hello").startsWith(UTF8String("hell"))) - assert(!UTF8String("hello").startsWith(UTF8String("ell"))) - assert(UTF8String("大千世界").startsWith(UTF8String("大千"))) - assert(!UTF8String("大千世界").startsWith(UTF8String("千"))) - } - - test("suffix") { - assert(UTF8String("hello").endsWith(UTF8String("ello"))) - assert(!UTF8String("hello").endsWith(UTF8String("ellov"))) - assert(UTF8String("大千世界").endsWith(UTF8String("世界"))) - assert(!UTF8String("大千世界").endsWith(UTF8String("世"))) - } - - test("slice") { - assert(UTF8String("hello").slice(1, 3) == UTF8String("el")) - assert(UTF8String("大千世界").slice(0, 1) == UTF8String("大")) - assert(UTF8String("大千世界").slice(1, 3) == UTF8String("千世")) - assert(UTF8String("大千世界").slice(3, 5) == UTF8String("界")) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala index 28b373e258311..5f312964e5bf7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala @@ -156,4 +156,20 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { assert(Decimal(-100) % Decimal(3) === Decimal(-1)) assert(Decimal(100) % Decimal(0) === null) } + + test("set/setOrNull") { + assert(new Decimal().set(10L, 10, 0).toUnscaledLong === 10L) + assert(new Decimal().set(100L, 10, 0).toUnscaledLong === 100L) + assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue) + } + + test("accurate precision after multiplication") { + val decimal = (Decimal(Long.MaxValue, 38, 0) * Decimal(Long.MaxValue, 38, 0)).toJavaBigDecimal + assert(decimal.unscaledValue.toString === "85070591730234615847396907784232501249") + } + + test("fix non-terminating decimal expansion problem") { + val decimal = Decimal(1.0, 10, 3) / Decimal(3.0, 10, 3) + assert(decimal.toString === "0.333") + } } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index ed75475a87067..8fc16928adbd9 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -73,11 +73,6 @@ jackson-databind ${fasterxml.jackson.version} - - org.jodd - jodd-core - ${jodd.version} - junit junit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index d3efa83380d04..f201c8ea8a110 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -21,7 +21,6 @@ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental import org.apache.spark.Logging -import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.analysis._ @@ -621,7 +620,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @since 1.3.0 */ @scala.annotation.varargs - def in(list: Column*): Column = In(expr, list.map(_.expr)) + def in(list: Any*): Column = In(expr, list.map(lit(_).expr)) /** * SQL like expression. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 4a224153e1a37..8fe1f7e34cb5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql import java.io.CharArrayWriter import java.util.Properties -import scala.collection.JavaConversions._ import scala.language.implicitConversions import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -33,7 +32,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, UnresolvedAttribute, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} @@ -169,23 +168,35 @@ class DataFrame private[sql]( /** * Internal API for Python - * @param numRows Number of rows to show + * @param _numRows Number of rows to show + * @param truncate Whether truncate long strings and align cells right */ - private[sql] def showString(numRows: Int): String = { + private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { + val numRows = _numRows.max(0) val sb = new StringBuilder - val data = take(numRows) + val takeResult = take(numRows + 1) + val hasMoreData = takeResult.length > numRows + val data = takeResult.take(numRows) val numCols = schema.fieldNames.length + // For array values, replace Seq and Array with square brackets // For cells that are beyond 20 characters, replace it with the first 17 and "..." val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row => row.toSeq.map { cell => - val str = if (cell == null) "null" else cell.toString - if (str.length > 20) str.substring(0, 17) + "..." else str + val str = cell match { + case null => "null" + case array: Array[_] => array.mkString("[", ", ", "]") + case seq: Seq[_] => seq.mkString("[", ", ", "]") + case _ => cell.toString + } + if (truncate && str.length > 20) str.substring(0, 17) + "..." else str }: Seq[String] } + // Initialise the width of each column to a minimum value of '3' + val colWidths = Array.fill(numCols)(3) + // Compute the width of each column - val colWidths = Array.fill(numCols)(0) for (row <- rows) { for ((cell, i) <- row.zipWithIndex) { colWidths(i) = math.max(colWidths(i), cell.length) @@ -197,7 +208,11 @@ class DataFrame private[sql]( // column names rows.head.zipWithIndex.map { case (cell, i) => - StringUtils.leftPad(cell.toString, colWidths(i)) + if (truncate) { + StringUtils.leftPad(cell, colWidths(i)) + } else { + StringUtils.rightPad(cell, colWidths(i)) + } }.addString(sb, "|", "|", "|\n") sb.append(sep) @@ -205,11 +220,22 @@ class DataFrame private[sql]( // data rows.tail.map { _.zipWithIndex.map { case (cell, i) => - StringUtils.leftPad(cell.toString, colWidths(i)) + if (truncate) { + StringUtils.leftPad(cell.toString, colWidths(i)) + } else { + StringUtils.rightPad(cell.toString, colWidths(i)) + } }.addString(sb, "|", "|", "|\n") } sb.append(sep) + + // For Data that has more than "numRows" records + if (hasMoreData) { + val rowsString = if (numRows == 1) "row" else "rows" + sb.append(s"only showing top $numRows ${rowsString}\n") + } + sb.toString() } @@ -314,7 +340,8 @@ class DataFrame private[sql]( def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation] /** - * Displays the [[DataFrame]] in a tabular form. For example: + * Displays the [[DataFrame]] in a tabular form. Strings more than 20 characters will be + * truncated, and all cells will be aligned right. For example: * {{{ * year month AVG('Adj Close) MAX('Adj Close) * 1980 12 0.503218 0.595103 @@ -328,15 +355,46 @@ class DataFrame private[sql]( * @group action * @since 1.3.0 */ - def show(numRows: Int): Unit = println(showString(numRows)) + def show(numRows: Int): Unit = show(numRows, true) /** - * Displays the top 20 rows of [[DataFrame]] in a tabular form. + * Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters + * will be truncated, and all cells will be aligned right. * @group action * @since 1.3.0 */ def show(): Unit = show(20) + /** + * Displays the top 20 rows of [[DataFrame]] in a tabular form. + * + * @param truncate Whether truncate long strings. If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * @group action + * @since 1.5.0 + */ + def show(truncate: Boolean): Unit = show(20, truncate) + + /** + * Displays the [[DataFrame]] in a tabular form. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * @param numRows Number of rows to show + * @param truncate Whether truncate long strings. If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * @group action + * @since 1.5.0 + */ + def show(numRows: Int, truncate: Boolean): Unit = println(showString(numRows, truncate)) + /** * Returns a [[DataFrameNaFunctions]] for working with missing data. * {{{ @@ -395,22 +453,50 @@ class DataFrame private[sql]( * @since 1.4.0 */ def join(right: DataFrame, usingColumn: String): DataFrame = { + join(right, Seq(usingColumn)) + } + + /** + * Inner equi-join with another [[DataFrame]] using the given columns. + * + * Different from other join functions, the join columns will only appear once in the output, + * i.e. similar to SQL's `JOIN USING` syntax. + * + * {{{ + * // Joining df1 and df2 using the columns "user_id" and "user_name" + * df1.join(df2, Seq("user_id", "user_name")) + * }}} + * + * Note that if you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. + * + * @param right Right side of the join operation. + * @param usingColumns Names of the columns to join on. This columns must exist on both sides. + * @group dfops + * @since 1.4.0 + */ + def join(right: DataFrame, usingColumns: Seq[String]): DataFrame = { // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sqlContext.executePlan( Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join] - // Project only one of the join column. - val joinedCol = joined.right.resolve(usingColumn) + // Project only one of the join columns. + val joinedCols = usingColumns.map(col => joined.right.resolve(col)) + val condition = usingColumns.map { col => + catalyst.expressions.EqualTo(joined.left.resolve(col), joined.right.resolve(col)) + }.reduceLeftOption[catalyst.expressions.BinaryExpression] { (cond, eqTo) => + catalyst.expressions.And(cond, eqTo) + } + Project( - joined.output.filterNot(_ == joinedCol), + joined.output.filterNot(joinedCols.contains(_)), Join( joined.left, joined.right, joinType = Inner, - Some(catalyst.expressions.EqualTo( - joined.left.resolve(usingColumn), - joined.right.resolve(usingColumn)))) + condition) ) } @@ -574,7 +660,7 @@ class DataFrame private[sql]( def as(alias: Symbol): DataFrame = as(alias.name) /** - * Selects a set of expressions. + * Selects a set of column based expressions. * {{{ * df.select($"colA", $"colB" + 1) * }}} @@ -584,6 +670,10 @@ class DataFrame private[sql]( @scala.annotation.varargs def select(cols: Column*): DataFrame = { val namedExpressions = cols.map { + // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we + // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to + // make it a NamedExpression. + case Column(u: UnresolvedAttribute) => UnresolvedAlias(u) case Column(expr: NamedExpression) => expr // Leave an unaliased explode with an empty list of names since the analzyer will generate the // correct defaults after the nested expression's type has been resolved. @@ -633,7 +723,6 @@ class DataFrame private[sql]( * // The following are equivalent: * peopleDf.filter($"age" > 15) * peopleDf.where($"age" > 15) - * peopleDf($"age" > 15) * }}} * @group dfops * @since 1.3.0 @@ -658,13 +747,24 @@ class DataFrame private[sql]( * // The following are equivalent: * peopleDf.filter($"age" > 15) * peopleDf.where($"age" > 15) - * peopleDf($"age" > 15) * }}} * @group dfops * @since 1.3.0 */ def where(condition: Column): DataFrame = filter(condition) + /** + * Filters rows using the given SQL expression. + * {{{ + * peopleDf.where("age > 15") + * }}} + * @group dfops + * @since 1.5.0 + */ + def where(conditionExpr: String): DataFrame = { + filter(Column(new SqlParser().parseExpression(conditionExpr))) + } + /** * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. * See [[GroupedData]] for all the available aggregate functions. @@ -984,9 +1084,10 @@ class DataFrame private[sql]( val elementTypes = schema.toAttributes.map { attr => (attr.dataType, attr.nullable) } val names = schema.toAttributes.map(_.name) + val convert = CatalystTypeConverters.createToCatalystConverter(schema) val rowFunction = - f.andThen(_.map(CatalystTypeConverters.convertToCatalyst(_, schema).asInstanceOf[Row])) + f.andThen(_.map(convert(_).asInstanceOf[InternalRow])) val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr)) Generate(generator, join = true, outer = false, @@ -999,7 +1100,7 @@ class DataFrame private[sql]( * columns of the input row are implicitly joined with each value that is output by the function. * * {{{ - * df.explode("words", "word")(words: String => words.split(" ")) + * df.explode("words", "word"){words: String => words.split(" ")} * }}} * @group dfops * @since 1.3.0 @@ -1012,8 +1113,9 @@ class DataFrame private[sql]( val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable) } val names = attributes.map(_.name) - def rowFunction(row: Row): TraversableOnce[Row] = { - f(row(0).asInstanceOf[A]).map(o => Row(CatalystTypeConverters.convertToCatalyst(o, dataType))) + def rowFunction(row: Row): TraversableOnce[InternalRow] = { + val convert = CatalystTypeConverters.createToCatalystConverter(dataType) + f(row(0).asInstanceOf[A]).map(o => InternalRow(convert(o))) } val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil) @@ -1175,7 +1277,7 @@ class DataFrame private[sql]( val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList - val ret: Seq[Row] = if (outputCols.nonEmpty) { + val ret: Seq[InternalRow] = if (outputCols.nonEmpty) { val aggExprs = statistics.flatMap { case (_, colToAgg) => outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) } @@ -1184,11 +1286,12 @@ class DataFrame private[sql]( // Pivot the data so each summary is one row row.grouped(outputCols.size).toSeq.zip(statistics).map { - case (aggregation, (statistic, _)) => Row(statistic :: aggregation.toList: _*) + case (aggregation, (statistic, _)) => + InternalRow(statistic :: aggregation.toList: _*) } } else { // If there are no output columns, just output a single column that contains the stats. - statistics.map { case (name, _) => Row(name) } + statistics.map { case (name, _) => InternalRow(name) } } // All columns are string type @@ -1366,12 +1469,14 @@ class DataFrame private[sql]( lazy val rdd: RDD[Row] = { // use a local variable to make sure the map closure doesn't capture the whole DataFrame val schema = this.schema - queryExecution.executedPlan.execute().mapPartitions { rows => + internalRowRdd.mapPartitions { rows => val converter = CatalystTypeConverters.createToScalaConverter(schema) rows.map(converter(_).asInstanceOf[Row]) } } + private[sql] def internalRowRdd = queryExecution.executedPlan.execute() + /** * Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s. * @group rdd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b44d4c86ac5d3..1828ed1aab50b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -245,7 +245,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) { JsonRDD.nullTypeToStringType( JsonRDD.inferSchema(jsonRDD, 1.0, columnNameOfCorruptJsonRecord))) val rowRDD = JsonRDD.jsonStringToRow(jsonRDD, appliedSchema, columnNameOfCorruptJsonRecord) - sqlContext.createDataFrame(rowRDD, appliedSchema, needsConversion = false) + sqlContext.internalCreateDataFrame(rowRDD, appliedSchema) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 45b3e1bc627d5..99d557b03a033 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConversions._ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate} import org.apache.spark.sql.types.NumericType @@ -70,27 +70,31 @@ class GroupedData protected[sql]( groupingExprs: Seq[Expression], private val groupType: GroupedData.GroupType) { - private[this] def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { + private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { - val retainedExprs = groupingExprs.map { - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - } - retainedExprs ++ aggExprs - } else { - aggExprs - } + groupingExprs ++ aggExprs + } else { + aggExprs + } + val aliasedAgg = aggregates.map { + // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we + // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to + // make it a NamedExpression. + case u: UnresolvedAttribute => UnresolvedAlias(u) + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.prettyString)() + } groupType match { case GroupedData.GroupByType => DataFrame( - df.sqlContext, Aggregate(groupingExprs, aggregates, df.logicalPlan)) + df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) case GroupedData.RollupType => DataFrame( - df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregates)) + df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg)) case GroupedData.CubeType => DataFrame( - df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregates)) + df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg)) } } @@ -112,10 +116,7 @@ class GroupedData protected[sql]( namedExpr } } - toDF(columnExprs.map { c => - val a = f(c) - Alias(a, a.prettyString)() - }) + toDF(columnExprs.map(f)) } private[this] def strToExpr(expr: String): (Expression => Expression) = { @@ -169,8 +170,7 @@ class GroupedData protected[sql]( */ def agg(exprs: Map[String, String]): DataFrame = { toDF(exprs.map { case (colName, expr) => - val a = strToExpr(expr)(df(colName).expr) - Alias(a, a.prettyString)() + strToExpr(expr)(df(colName).expr) }.toSeq) } @@ -224,10 +224,7 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def agg(expr: Column, exprs: Column*): DataFrame = { - toDF((expr +: exprs).map(_.expr).map { - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - }) + toDF((expr +: exprs).map(_.expr)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 77c6af27d1007..9a10a23937fbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -22,73 +22,360 @@ import java.util.Properties import scala.collection.immutable import scala.collection.JavaConversions._ +import org.apache.parquet.hadoop.ParquetOutputCommitter + import org.apache.spark.sql.catalyst.CatalystConf private[spark] object SQLConf { - val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed" - val COLUMN_BATCH_SIZE = "spark.sql.inMemoryColumnarStorage.batchSize" - val IN_MEMORY_PARTITION_PRUNING = "spark.sql.inMemoryColumnarStorage.partitionPruning" - val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold" - val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" - val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" - val CODEGEN_ENABLED = "spark.sql.codegen" - val UNSAFE_ENABLED = "spark.sql.unsafe.enabled" - val DIALECT = "spark.sql.dialect" - val CASE_SENSITIVE = "spark.sql.caseSensitive" - - val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" - val PARQUET_INT96_AS_TIMESTAMP = "spark.sql.parquet.int96AsTimestamp" - val PARQUET_CACHE_METADATA = "spark.sql.parquet.cacheMetadata" - val PARQUET_COMPRESSION = "spark.sql.parquet.compression.codec" - val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.parquet.filterPushdown" - val PARQUET_USE_DATA_SOURCE_API = "spark.sql.parquet.useDataSourceApi" - - val ORC_FILTER_PUSHDOWN_ENABLED = "spark.sql.orc.filterPushdown" - - val HIVE_VERIFY_PARTITIONPATH = "spark.sql.hive.verifyPartitionPath" - - val COLUMN_NAME_OF_CORRUPT_RECORD = "spark.sql.columnNameOfCorruptRecord" - val BROADCAST_TIMEOUT = "spark.sql.broadcastTimeout" + + private val sqlConfEntries = java.util.Collections.synchronizedMap( + new java.util.HashMap[String, SQLConfEntry[_]]()) + + /** + * An entry contains all meta information for a configuration. + * + * @param key the key for the configuration + * @param defaultValue the default value for the configuration + * @param valueConverter how to convert a string to the value. It should throw an exception if the + * string does not have the required format. + * @param stringConverter how to convert a value to a string that the user can use it as a valid + * string value. It's usually `toString`. But sometimes, a custom converter + * is necessary. E.g., if T is List[String], `a, b, c` is better than + * `List(a, b, c)`. + * @param doc the document for the configuration + * @param isPublic if this configuration is public to the user. If it's `false`, this + * configuration is only used internally and we should not expose it to the user. + * @tparam T the value type + */ + private[sql] class SQLConfEntry[T] private( + val key: String, + val defaultValue: Option[T], + val valueConverter: String => T, + val stringConverter: T => String, + val doc: String, + val isPublic: Boolean) { + + def defaultValueString: String = defaultValue.map(stringConverter).getOrElse("") + + override def toString: String = { + s"SQLConfEntry(key = $key, defaultValue=$defaultValueString, doc=$doc, isPublic = $isPublic)" + } + } + + private[sql] object SQLConfEntry { + + private def apply[T]( + key: String, + defaultValue: Option[T], + valueConverter: String => T, + stringConverter: T => String, + doc: String, + isPublic: Boolean): SQLConfEntry[T] = + sqlConfEntries.synchronized { + if (sqlConfEntries.containsKey(key)) { + throw new IllegalArgumentException(s"Duplicate SQLConfEntry. $key has been registered") + } + val entry = + new SQLConfEntry[T](key, defaultValue, valueConverter, stringConverter, doc, isPublic) + sqlConfEntries.put(key, entry) + entry + } + + def intConf( + key: String, + defaultValue: Option[Int] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[Int] = + SQLConfEntry(key, defaultValue, { v => + try { + v.toInt + } catch { + case _: NumberFormatException => + throw new IllegalArgumentException(s"$key should be int, but was $v") + } + }, _.toString, doc, isPublic) + + def longConf( + key: String, + defaultValue: Option[Long] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[Long] = + SQLConfEntry(key, defaultValue, { v => + try { + v.toLong + } catch { + case _: NumberFormatException => + throw new IllegalArgumentException(s"$key should be long, but was $v") + } + }, _.toString, doc, isPublic) + + def doubleConf( + key: String, + defaultValue: Option[Double] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[Double] = + SQLConfEntry(key, defaultValue, { v => + try { + v.toDouble + } catch { + case _: NumberFormatException => + throw new IllegalArgumentException(s"$key should be double, but was $v") + } + }, _.toString, doc, isPublic) + + def booleanConf( + key: String, + defaultValue: Option[Boolean] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[Boolean] = + SQLConfEntry(key, defaultValue, { v => + try { + v.toBoolean + } catch { + case _: IllegalArgumentException => + throw new IllegalArgumentException(s"$key should be boolean, but was $v") + } + }, _.toString, doc, isPublic) + + def stringConf( + key: String, + defaultValue: Option[String] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[String] = + SQLConfEntry(key, defaultValue, v => v, v => v, doc, isPublic) + + def enumConf[T]( + key: String, + valueConverter: String => T, + validValues: Set[T], + defaultValue: Option[T] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[T] = + SQLConfEntry(key, defaultValue, v => { + val _v = valueConverter(v) + if (!validValues.contains(_v)) { + throw new IllegalArgumentException( + s"The value of $key should be one of ${validValues.mkString(", ")}, but was $v") + } + _v + }, _.toString, doc, isPublic) + + def seqConf[T]( + key: String, + valueConverter: String => T, + defaultValue: Option[Seq[T]] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[Seq[T]] = { + SQLConfEntry( + key, defaultValue, _.split(",").map(valueConverter), _.mkString(","), doc, isPublic) + } + + def stringSeqConf( + key: String, + defaultValue: Option[Seq[String]] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[Seq[String]] = { + seqConf(key, s => s, defaultValue, doc, isPublic) + } + } + + import SQLConfEntry._ + + val COMPRESS_CACHED = booleanConf("spark.sql.inMemoryColumnarStorage.compressed", + defaultValue = Some(true), + doc = "When set to true Spark SQL will automatically select a compression codec for each " + + "column based on statistics of the data.") + + val COLUMN_BATCH_SIZE = intConf("spark.sql.inMemoryColumnarStorage.batchSize", + defaultValue = Some(10000), + doc = "Controls the size of batches for columnar caching. Larger batch sizes can improve " + + "memory utilization and compression, but risk OOMs when caching data.") + + val IN_MEMORY_PARTITION_PRUNING = + booleanConf("spark.sql.inMemoryColumnarStorage.partitionPruning", + defaultValue = Some(false), + doc = "") + + val AUTO_BROADCASTJOIN_THRESHOLD = intConf("spark.sql.autoBroadcastJoinThreshold", + defaultValue = Some(10 * 1024 * 1024), + doc = "Configures the maximum size in bytes for a table that will be broadcast to all worker " + + "nodes when performing a join. By setting this value to -1 broadcasting can be disabled. " + + "Note that currently statistics are only supported for Hive Metastore tables where the " + + "commandANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run.") + + val DEFAULT_SIZE_IN_BYTES = longConf("spark.sql.defaultSizeInBytes", isPublic = false) + + val SHUFFLE_PARTITIONS = intConf("spark.sql.shuffle.partitions", + defaultValue = Some(200), + doc = "Configures the number of partitions to use when shuffling data for joins or " + + "aggregations.") + + val CODEGEN_ENABLED = booleanConf("spark.sql.codegen", + defaultValue = Some(true), + doc = "When true, code will be dynamically generated at runtime for expression evaluation in" + + " a specific query. For some queries with complicated expression this option can lead to " + + "significant speed-ups. However, for simple queries this can actually slow down query " + + "execution.") + + val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled", + defaultValue = Some(false), + doc = "") + + val DIALECT = stringConf("spark.sql.dialect", defaultValue = Some("sql"), doc = "") + + val CASE_SENSITIVE = booleanConf("spark.sql.caseSensitive", + defaultValue = Some(true), + doc = "") + + val PARQUET_BINARY_AS_STRING = booleanConf("spark.sql.parquet.binaryAsString", + defaultValue = Some(false), + doc = "Some other Parquet-producing systems, in particular Impala and older versions of " + + "Spark SQL, do not differentiate between binary data and strings when writing out the " + + "Parquet schema. This flag tells Spark SQL to interpret binary data as a string to provide " + + "compatibility with these systems.") + + val PARQUET_INT96_AS_TIMESTAMP = booleanConf("spark.sql.parquet.int96AsTimestamp", + defaultValue = Some(true), + doc = "Some Parquet-producing systems, in particular Impala, store Timestamp into INT96. " + + "Spark would also store Timestamp as INT96 because we need to avoid precision lost of the " + + "nanoseconds field. This flag tells Spark SQL to interpret INT96 data as a timestamp to " + + "provide compatibility with these systems.") + + val PARQUET_CACHE_METADATA = booleanConf("spark.sql.parquet.cacheMetadata", + defaultValue = Some(true), + doc = "Turns on caching of Parquet schema metadata. Can speed up querying of static data.") + + val PARQUET_COMPRESSION = enumConf("spark.sql.parquet.compression.codec", + valueConverter = v => v.toLowerCase, + validValues = Set("uncompressed", "snappy", "gzip", "lzo"), + defaultValue = Some("gzip"), + doc = "Sets the compression codec use when writing Parquet files. Acceptable values include: " + + "uncompressed, snappy, gzip, lzo.") + + val PARQUET_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.parquet.filterPushdown", + defaultValue = Some(false), + doc = "Turn on Parquet filter pushdown optimization. This feature is turned off by default " + + "because of a known bug in Parquet 1.6.0rc3 " + + "(PARQUET-136, https://issues.apache.org/jira/browse/PARQUET-136). However, " + + "if your table doesn't contain any nullable string or binary columns, it's still safe to " + + "turn this feature on.") + + val PARQUET_USE_DATA_SOURCE_API = booleanConf("spark.sql.parquet.useDataSourceApi", + defaultValue = Some(true), + doc = "") + + val PARQUET_FOLLOW_PARQUET_FORMAT_SPEC = booleanConf( + key = "spark.sql.parquet.followParquetFormatSpec", + defaultValue = Some(false), + doc = "Wether to stick to Parquet format specification when converting Parquet schema to " + + "Spark SQL schema and vice versa. Sticks to the specification if set to true; falls back " + + "to compatible mode if set to false.", + isPublic = false) + + val PARQUET_OUTPUT_COMMITTER_CLASS = stringConf( + key = "spark.sql.parquet.output.committer.class", + defaultValue = Some(classOf[ParquetOutputCommitter].getName), + doc = "The output committer class used by Parquet. The specified class needs to be a " + + "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + + "of org.apache.parquet.hadoop.ParquetOutputCommitter. NOTE: 1. Instead of SQLConf, this " + + "option must be set in Hadoop Configuration. 2. This option overrides " + + "\"spark.sql.sources.outputCommitterClass\"." + ) + + val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown", + defaultValue = Some(false), + doc = "") + + val HIVE_VERIFY_PARTITION_PATH = booleanConf("spark.sql.hive.verifyPartitionPath", + defaultValue = Some(true), + doc = "") + + val COLUMN_NAME_OF_CORRUPT_RECORD = stringConf("spark.sql.columnNameOfCorruptRecord", + defaultValue = Some("_corrupt_record"), + doc = "") + + val BROADCAST_TIMEOUT = intConf("spark.sql.broadcastTimeout", + defaultValue = Some(5 * 60), + doc = "") // Options that control which operators can be chosen by the query planner. These should be // considered hints and may be ignored by future versions of Spark SQL. - val EXTERNAL_SORT = "spark.sql.planner.externalSort" - val SORTMERGE_JOIN = "spark.sql.planner.sortMergeJoin" + val EXTERNAL_SORT = booleanConf("spark.sql.planner.externalSort", + defaultValue = Some(true), + doc = "When true, performs sorts spilling to disk as needed otherwise sort each partition in" + + " memory.") + + val SORTMERGE_JOIN = booleanConf("spark.sql.planner.sortMergeJoin", + defaultValue = Some(false), + doc = "") // This is only used for the thriftserver - val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" - val THRIFTSERVER_UI_STATEMENT_LIMIT = "spark.sql.thriftserver.ui.retainedStatements" - val THRIFTSERVER_UI_SESSION_LIMIT = "spark.sql.thriftserver.ui.retainedSessions" + val THRIFTSERVER_POOL = stringConf("spark.sql.thriftserver.scheduler.pool", + doc = "Set a Fair Scheduler pool for a JDBC client session") + + val THRIFTSERVER_UI_STATEMENT_LIMIT = intConf("spark.sql.thriftserver.ui.retainedStatements", + defaultValue = Some(200), + doc = "") + + val THRIFTSERVER_UI_SESSION_LIMIT = intConf("spark.sql.thriftserver.ui.retainedSessions", + defaultValue = Some(200), + doc = "") // This is used to set the default data source - val DEFAULT_DATA_SOURCE_NAME = "spark.sql.sources.default" + val DEFAULT_DATA_SOURCE_NAME = stringConf("spark.sql.sources.default", + defaultValue = Some("org.apache.spark.sql.parquet"), + doc = "") + // This is used to control the when we will split a schema's JSON string to multiple pieces // in order to fit the JSON string in metastore's table property (by default, the value has // a length restriction of 4000 characters). We will split the JSON string of a schema // to its length exceeds the threshold. - val SCHEMA_STRING_LENGTH_THRESHOLD = "spark.sql.sources.schemaStringLengthThreshold" + val SCHEMA_STRING_LENGTH_THRESHOLD = intConf("spark.sql.sources.schemaStringLengthThreshold", + defaultValue = Some(4000), + doc = "") // Whether to perform partition discovery when loading external data sources. Default to true. - val PARTITION_DISCOVERY_ENABLED = "spark.sql.sources.partitionDiscovery.enabled" + val PARTITION_DISCOVERY_ENABLED = booleanConf("spark.sql.sources.partitionDiscovery.enabled", + defaultValue = Some(true), + doc = "") - // The output committer class used by FSBasedRelation. The specified class needs to be a + // Whether to perform partition column type inference. Default to true. + val PARTITION_COLUMN_TYPE_INFERENCE = + booleanConf("spark.sql.sources.partitionColumnTypeInference.enabled", + defaultValue = Some(true), + doc = "") + + // The output committer class used by HadoopFsRelation. The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. - val OUTPUT_COMMITTER_CLASS = "spark.sql.sources.outputCommitterClass" + // + // NOTE: + // + // 1. Instead of SQLConf, this option *must be set in Hadoop Configuration*. + // 2. This option can be overriden by "spark.sql.parquet.output.committer.class". + val OUTPUT_COMMITTER_CLASS = + stringConf("spark.sql.sources.outputCommitterClass", isPublic = false) // Whether to perform eager analysis when constructing a dataframe. // Set to false when debugging requires the ability to look at invalid query plans. - val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis" + val DATAFRAME_EAGER_ANALYSIS = booleanConf("spark.sql.eagerAnalysis", + defaultValue = Some(true), + doc = "") // Whether to automatically resolve ambiguity in join conditions for self-joins. // See SPARK-6231. - val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = "spark.sql.selfJoinAutoResolveAmbiguity" + val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = + booleanConf("spark.sql.selfJoinAutoResolveAmbiguity", defaultValue = Some(true), doc = "") // Whether to retain group by columns or not in GroupedData.agg. - val DATAFRAME_RETAIN_GROUP_COLUMNS = "spark.sql.retainGroupColumns" + val DATAFRAME_RETAIN_GROUP_COLUMNS = booleanConf("spark.sql.retainGroupColumns", + defaultValue = Some(true), + doc = "") - val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2" + val USE_SQL_SERIALIZER2 = booleanConf("spark.sql.useSerializer2", + defaultValue = Some(true), doc = "") - val USE_JACKSON_STREAMING_API = "spark.sql.json.useJacksonStreamingAPI" + val USE_JACKSON_STREAMING_API = booleanConf("spark.sql.json.useJacksonStreamingAPI", + defaultValue = Some(true), doc = "") object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" @@ -127,60 +414,54 @@ private[sql] class SQLConf extends Serializable with CatalystConf { * Note that the choice of dialect does not affect things like what tables are available or * how query execution is performed. */ - private[spark] def dialect: String = getConf(DIALECT, "sql") + private[spark] def dialect: String = getConf(DIALECT) /** When true tables cached using the in-memory columnar caching will be compressed. */ - private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED, "true").toBoolean + private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED) /** The compression codec for writing to a Parquetfile */ - private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION, "gzip") + private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) + + private[spark] def parquetCacheMetadata: Boolean = getConf(PARQUET_CACHE_METADATA) /** The number of rows that will be */ - private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE, "10000").toInt + private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE) /** Number of partitions to use for shuffle operators. */ - private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS, "200").toInt + private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) /** When true predicates will be passed to the parquet record reader when possible. */ - private[spark] def parquetFilterPushDown = - getConf(PARQUET_FILTER_PUSHDOWN_ENABLED, "false").toBoolean + private[spark] def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) /** When true uses Parquet implementation based on data source API */ - private[spark] def parquetUseDataSourceApi = - getConf(PARQUET_USE_DATA_SOURCE_API, "true").toBoolean + private[spark] def parquetUseDataSourceApi: Boolean = getConf(PARQUET_USE_DATA_SOURCE_API) - private[spark] def orcFilterPushDown = - getConf(ORC_FILTER_PUSHDOWN_ENABLED, "false").toBoolean + private[spark] def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) /** When true uses verifyPartitionPath to prune the path which is not exists. */ - private[spark] def verifyPartitionPath = - getConf(HIVE_VERIFY_PARTITIONPATH, "true").toBoolean + private[spark] def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) /** When true the planner will use the external sort, which may spill to disk. */ - private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, "false").toBoolean + private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT) /** * Sort merge join would sort the two side of join first, and then iterate both sides together * only once to get all matches. Using sort merge join can save a lot of memory usage compared * to HashJoin. */ - private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN, "false").toBoolean + private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) /** - * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode + * When set to true, Spark SQL will use the Janino at runtime to generate custom bytecode * that evaluates expressions found in queries. In general this custom code runs much faster - * than interpreted evaluation, but there are significant start-up costs due to compilation. - * As a result codegen is only beneficial when queries run for a long time, or when the same - * expressions are used multiple times. - * - * Defaults to false as this feature is currently experimental. + * than interpreted evaluation, but there are some start-up costs (5-10ms) due to compilation. */ - private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean + private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED) /** * caseSensitive analysis true by default */ - def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, "true").toBoolean + def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) /** * When set to true, Spark SQL will use managed memory for certain operations. This option only @@ -188,15 +469,14 @@ private[sql] class SQLConf extends Serializable with CatalystConf { * * Defaults to false as this feature is currently experimental. */ - private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, "false").toBoolean + private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED) - private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean + private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2) /** * Selects between the new (true) and old (false) JSON handlers, to be removed in Spark 1.5.0 */ - private[spark] def useJacksonStreamingAPI: Boolean = - getConf(USE_JACKSON_STREAMING_API, "true").toBoolean + private[spark] def useJacksonStreamingAPI: Boolean = getConf(USE_JACKSON_STREAMING_API) /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to @@ -205,8 +485,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf { * * Hive setting: hive.auto.convert.join.noconditionaltask.size, whose default value is 10000. */ - private[spark] def autoBroadcastJoinThreshold: Int = - getConf(AUTO_BROADCASTJOIN_THRESHOLD, (10 * 1024 * 1024).toString).toInt + private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) /** * The default size in bytes to assign to a logical operator's estimation statistics. By default, @@ -215,79 +494,122 @@ private[sql] class SQLConf extends Serializable with CatalystConf { * in joins. */ private[spark] def defaultSizeInBytes: Long = - getConf(DEFAULT_SIZE_IN_BYTES, (autoBroadcastJoinThreshold + 1).toString).toLong + getConf(DEFAULT_SIZE_IN_BYTES, autoBroadcastJoinThreshold + 1L) /** * When set to true, we always treat byte arrays in Parquet files as strings. */ - private[spark] def isParquetBinaryAsString: Boolean = - getConf(PARQUET_BINARY_AS_STRING, "false").toBoolean + private[spark] def isParquetBinaryAsString: Boolean = getConf(PARQUET_BINARY_AS_STRING) /** * When set to true, we always treat INT96Values in Parquet files as timestamp. */ - private[spark] def isParquetINT96AsTimestamp: Boolean = - getConf(PARQUET_INT96_AS_TIMESTAMP, "true").toBoolean + private[spark] def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP) + + /** + * When set to true, sticks to Parquet format spec when converting Parquet schema to Spark SQL + * schema and vice versa. Otherwise, falls back to compatible mode. + */ + private[spark] def followParquetFormatSpec: Boolean = getConf(PARQUET_FOLLOW_PARQUET_FORMAT_SPEC) /** * When set to true, partition pruning for in-memory columnar tables is enabled. */ - private[spark] def inMemoryPartitionPruning: Boolean = - getConf(IN_MEMORY_PARTITION_PRUNING, "false").toBoolean + private[spark] def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING) - private[spark] def columnNameOfCorruptRecord: String = - getConf(COLUMN_NAME_OF_CORRUPT_RECORD, "_corrupt_record") + private[spark] def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD) /** * Timeout in seconds for the broadcast wait time in hash join */ - private[spark] def broadcastTimeout: Int = - getConf(BROADCAST_TIMEOUT, (5 * 60).toString).toInt + private[spark] def broadcastTimeout: Int = getConf(BROADCAST_TIMEOUT) - private[spark] def defaultDataSourceName: String = - getConf(DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.parquet") + private[spark] def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME) - private[spark] def partitionDiscoveryEnabled() = - getConf(SQLConf.PARTITION_DISCOVERY_ENABLED, "true").toBoolean + private[spark] def partitionDiscoveryEnabled(): Boolean = + getConf(SQLConf.PARTITION_DISCOVERY_ENABLED) + + private[spark] def partitionColumnTypeInferenceEnabled(): Boolean = + getConf(SQLConf.PARTITION_COLUMN_TYPE_INFERENCE) // Do not use a value larger than 4000 as the default value of this property. // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information. - private[spark] def schemaStringLengthThreshold: Int = - getConf(SCHEMA_STRING_LENGTH_THRESHOLD, "4000").toInt + private[spark] def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD) - private[spark] def dataFrameEagerAnalysis: Boolean = - getConf(DATAFRAME_EAGER_ANALYSIS, "true").toBoolean + private[spark] def dataFrameEagerAnalysis: Boolean = getConf(DATAFRAME_EAGER_ANALYSIS) private[spark] def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = - getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY, "true").toBoolean + getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) - private[spark] def dataFrameRetainGroupColumns: Boolean = - getConf(DATAFRAME_RETAIN_GROUP_COLUMNS, "true").toBoolean + private[spark] def dataFrameRetainGroupColumns: Boolean = getConf(DATAFRAME_RETAIN_GROUP_COLUMNS) /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ def setConf(props: Properties): Unit = settings.synchronized { - props.foreach { case (k, v) => settings.put(k, v) } + props.foreach { case (k, v) => setConfString(k, v) } } - /** Set the given Spark SQL configuration property. */ - def setConf(key: String, value: String): Unit = { + /** Set the given Spark SQL configuration property using a `string` value. */ + def setConfString(key: String, value: String): Unit = { require(key != null, "key cannot be null") require(value != null, s"value cannot be null for key: $key") + val entry = sqlConfEntries.get(key) + if (entry != null) { + // Only verify configs in the SQLConf object + entry.valueConverter(value) + } settings.put(key, value) } + /** Set the given Spark SQL configuration property. */ + def setConf[T](entry: SQLConfEntry[T], value: T): Unit = { + require(entry != null, "entry cannot be null") + require(value != null, s"value cannot be null for key: ${entry.key}") + require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") + settings.put(entry.key, entry.stringConverter(value)) + } + /** Return the value of Spark SQL configuration property for the given key. */ - def getConf(key: String): String = { - Option(settings.get(key)).getOrElse(throw new NoSuchElementException(key)) + def getConfString(key: String): String = { + Option(settings.get(key)). + orElse { + // Try to use the default value + Option(sqlConfEntries.get(key)).map(_.defaultValueString) + }. + getOrElse(throw new NoSuchElementException(key)) + } + + /** + * Return the value of Spark SQL configuration property for the given key. If the key is not set + * yet, return `defaultValue`. This is useful when `defaultValue` in SQLConfEntry is not the + * desired one. + */ + def getConf[T](entry: SQLConfEntry[T], defaultValue: T): T = { + require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") + Option(settings.get(entry.key)).map(entry.valueConverter).getOrElse(defaultValue) } /** * Return the value of Spark SQL configuration property for the given key. If the key is not set - * yet, return `defaultValue`. + * yet, return `defaultValue` in [[SQLConfEntry]]. */ - def getConf(key: String, defaultValue: String): String = { + def getConf[T](entry: SQLConfEntry[T]): T = { + require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") + Option(settings.get(entry.key)).map(entry.valueConverter).orElse(entry.defaultValue). + getOrElse(throw new NoSuchElementException(entry.key)) + } + + /** + * Return the `string` value of Spark SQL configuration property for the given key. If the key is + * not set yet, return `defaultValue`. + */ + def getConfString(key: String, defaultValue: String): String = { + val entry = sqlConfEntries.get(key) + if (entry != null && defaultValue != "") { + // Only verify configs in the SQLConf object + entry.valueConverter(defaultValue) + } Option(settings.get(key)).getOrElse(defaultValue) } @@ -297,11 +619,25 @@ private[sql] class SQLConf extends Serializable with CatalystConf { */ def getAllConfs: immutable.Map[String, String] = settings.synchronized { settings.toMap } - private[spark] def unsetConf(key: String) { + /** + * Return all the configuration definitions that have been defined in [[SQLConf]]. Each + * definition contains key, defaultValue and doc. + */ + def getAllDefinedConfs: Seq[(String, String, String)] = sqlConfEntries.synchronized { + sqlConfEntries.values.filter(_.isPublic).map { entry => + (entry.key, entry.defaultValueString, entry.doc) + }.toSeq + } + + private[spark] def unsetConf(key: String): Unit = { settings -= key } - private[spark] def clear() { + private[spark] def unsetConf(entry: SQLConfEntry[_]): Unit = { + settings -= entry.key + } + + private[spark] def clear(): Unit = { settings.clear() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index ddb54025baa24..fc14a77538ef1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -31,17 +31,18 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.errors.DialectException +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.ParserDialect +import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} import org.apache.spark.sql.execution.{Filter, _} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils /** @@ -80,13 +81,16 @@ class SQLContext(@transient val sparkContext: SparkContext) */ def setConf(props: Properties): Unit = conf.setConf(props) + /** Set the given Spark SQL configuration property. */ + private[sql] def setConf[T](entry: SQLConfEntry[T], value: T): Unit = conf.setConf(entry, value) + /** * Set the given Spark SQL configuration property. * * @group config * @since 1.0.0 */ - def setConf(key: String, value: String): Unit = conf.setConf(key, value) + def setConf(key: String, value: String): Unit = conf.setConfString(key, value) /** * Return the value of Spark SQL configuration property for the given key. @@ -94,7 +98,22 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group config * @since 1.0.0 */ - def getConf(key: String): String = conf.getConf(key) + def getConf(key: String): String = conf.getConfString(key) + + /** + * Return the value of Spark SQL configuration property for the given key. If the key is not set + * yet, return `defaultValue` in [[SQLConfEntry]]. + */ + private[sql] def getConf[T](entry: SQLConfEntry[T]): T = conf.getConf(entry) + + /** + * Return the value of Spark SQL configuration property for the given key. If the key is not set + * yet, return `defaultValue`. This is useful when `defaultValue` in SQLConfEntry is not the + * desired one. + */ + private[sql] def getConf[T](entry: SQLConfEntry[T], defaultValue: T): T = { + conf.getConf(entry, defaultValue) + } /** * Return the value of Spark SQL configuration property for the given key. If the key is not set @@ -103,7 +122,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group config * @since 1.0.0 */ - def getConf(key: String, defaultValue: String): String = conf.getConf(key, defaultValue) + def getConf(key: String, defaultValue: String): String = conf.getConfString(key, defaultValue) /** * Return all the configuration properties that have been set (i.e. not the default). @@ -120,13 +139,14 @@ class SQLContext(@transient val sparkContext: SparkContext) // TODO how to handle the temp function per user session? @transient - protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry(conf) + protected[sql] lazy val functionRegistry: FunctionRegistry = + new OverrideFunctionRegistry(FunctionRegistry.builtin) @transient protected[sql] lazy val analyzer: Analyzer = new Analyzer(catalog, functionRegistry, conf) { override val extendedResolutionRules = - ExtractPythonUdfs :: + ExtractPythonUDFs :: sources.PreInsertCastAndRename :: Nil @@ -237,7 +257,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * * The following example registers a Scala closure as UDF: * {{{ - * sqlContext.udf.register("myUdf", (arg1: Int, arg2: String) => arg2 + arg1) + * sqlContext.udf.register("myUDF", (arg1: Int, arg2: String) => arg2 + arg1) * }}} * * The following example registers a UDF in Java: @@ -358,10 +378,11 @@ class SQLContext(@transient val sparkContext: SparkContext) val row = new SpecificMutableRow(dataType :: Nil) iter.map { v => row.setInt(0, v) - row: Row + row: InternalRow } } - DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + DataFrameHolder( + self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) } /** @@ -374,10 +395,11 @@ class SQLContext(@transient val sparkContext: SparkContext) val row = new SpecificMutableRow(dataType :: Nil) iter.map { v => row.setLong(0, v) - row: Row + row: InternalRow } } - DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + DataFrameHolder( + self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) } /** @@ -389,11 +411,12 @@ class SQLContext(@transient val sparkContext: SparkContext) val rows = data.mapPartitions { iter => val row = new SpecificMutableRow(dataType :: Nil) iter.map { v => - row.setString(0, v) - row: Row + row.update(0, UTF8String.fromString(v)) + row: InternalRow } } - DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + DataFrameHolder( + self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) } } @@ -485,14 +508,26 @@ class SQLContext(@transient val sparkContext: SparkContext) // schema differs from the existing schema on any field data type. val catalystRows = if (needsConversion) { val converter = CatalystTypeConverters.createToCatalystConverter(schema) - rowRDD.map(converter(_).asInstanceOf[Row]) + rowRDD.map(converter(_).asInstanceOf[InternalRow]) } else { - rowRDD + rowRDD.map{r: Row => InternalRow.fromSeq(r.toSeq)} } val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self) DataFrame(this, logicalPlan) } + /** + * Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be + * converted to Catalyst rows. + */ + private[sql] + def internalCreateDataFrame(catalystRows: RDD[InternalRow], schema: StructType) = { + // TODO: use MutableProjection when rowRDD is another DataFrame and the applied + // schema differs from the existing schema on any field data type. + val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self) + DataFrame(this, logicalPlan) + } + /** * :: DeveloperApi :: * Creates a [[DataFrame]] from an [[JavaRDD]] containing [[Row]]s using the given schema. @@ -524,13 +559,13 @@ class SQLContext(@transient val sparkContext: SparkContext) Class.forName(className, true, Utils.getContextOrSparkClassLoader)) val extractors = localBeanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod) - + val methodsToConverts = extractors.zip(attributeSeq).map { case (e, attr) => + (e, CatalystTypeConverters.createToCatalystConverter(attr.dataType)) + } iter.map { row => - new GenericRow( - extractors.zip(attributeSeq).map { case (e, attr) => - CatalystTypeConverters.convertToCatalyst(e.invoke(row), attr.dataType) - }.toArray[Any] - ) : Row + new GenericInternalRow( + methodsToConverts.map { case (e, convert) => convert(e.invoke(row)) }.toArray[Any] + ): InternalRow } } DataFrame(this, LogicalRDD(attributeSeq, rowRdd)(this)) @@ -827,7 +862,7 @@ class SQLContext(@transient val sparkContext: SparkContext) experimental.extraStrategies ++ ( DataSourceStrategy :: DDLStrategy :: - TakeOrdered :: + TakeOrderedAndProject :: HashAggregation :: LeftSemiJoin :: HashJoin :: @@ -885,7 +920,7 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] val planner = new SparkPlanner @transient - protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[Row], 1) + protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[InternalRow], 1) /** * Prepares a planned SparkPlan for execution by inserting shuffle operations as needed. @@ -952,7 +987,7 @@ class SQLContext(@transient val sparkContext: SparkContext) lazy val executedPlan: SparkPlan = prepareForExecution.execute(sparkPlan) /** Internal version of the RDD. Avoids copies and has no schema */ - lazy val toRdd: RDD[Row] = executedPlan.execute() + lazy val toRdd: RDD[InternalRow] = executedPlan.execute() protected def stringOrError[A](f: => A): String = try f.toString catch { case e: Throwable => e.toString } @@ -1034,7 +1069,7 @@ class SQLContext(@transient val sparkContext: SparkContext) } val rowRdd = convertedRdd.mapPartitions { iter => - iter.map { m => new GenericRow(m): Row} + iter.map { m => new GenericInternalRow(m): InternalRow} } DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala index 305b306a79871..e59fa6e162900 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala @@ -44,8 +44,8 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr private val pair: Parser[LogicalPlan] = (key ~ ("=".r ~> value).?).? ^^ { - case None => SetCommand(None, output) - case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim)), output) + case None => SetCommand(None) + case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim))) } def apply(input: String): LogicalPlan = parseAll(pair, input) match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 3cc5c2441d8a5..03dc37aa73f0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -26,7 +26,7 @@ import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf} +import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.PythonUDF import org.apache.spark.sql.types.DataType @@ -95,7 +95,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) }""") @@ -114,7 +114,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { |def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType) = { | functionRegistry.registerFunction( | name, - | (e: Seq[Expression]) => ScalaUdf(f$anyCast.call($anyParams), returnType, e)) + | (e: Seq[Expression]) => ScalaUDF(f$anyCast.call($anyParams), returnType, e)) |}""".stripMargin) } */ @@ -126,7 +126,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -138,7 +138,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -150,7 +150,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -162,7 +162,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -174,7 +174,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -186,7 +186,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -198,7 +198,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -210,7 +210,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -222,7 +222,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -234,7 +234,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -246,7 +246,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -258,7 +258,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -270,7 +270,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -282,7 +282,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -294,7 +294,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -306,7 +306,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -318,7 +318,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -330,7 +330,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -342,7 +342,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -354,7 +354,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -366,7 +366,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -378,7 +378,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -390,7 +390,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -405,7 +405,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF1[_, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF1[Any, Any]].call(_: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF1[Any, Any]].call(_: Any), returnType, e)) } /** @@ -415,7 +415,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF2[_, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any), returnType, e)) } /** @@ -425,7 +425,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF3[_, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any), returnType, e)) } /** @@ -435,7 +435,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -445,7 +445,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -455,7 +455,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -465,7 +465,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -475,7 +475,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -485,7 +485,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -495,7 +495,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -505,7 +505,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -515,7 +515,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -525,7 +525,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -535,7 +535,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -545,7 +545,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -555,7 +555,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -565,7 +565,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -575,7 +575,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -585,7 +585,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -595,7 +595,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -605,7 +605,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -615,7 +615,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } // scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala index a02e202d2eebc..831eb7eb0fae9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala @@ -23,7 +23,7 @@ import org.apache.spark.Accumulator import org.apache.spark.annotation.Experimental import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.catalyst.expressions.ScalaUdf +import org.apache.spark.sql.catalyst.expressions.ScalaUDF import org.apache.spark.sql.execution.PythonUDF import org.apache.spark.sql.types.DataType @@ -44,7 +44,7 @@ import org.apache.spark.sql.types.DataType case class UserDefinedFunction protected[sql] (f: AnyRef, dataType: DataType) { def apply(exprs: Column*): Column = { - Column(ScalaUdf(f, dataType, exprs.map(_.expr))) + Column(ScalaUDF(f, dataType, exprs.map(_.expr))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index 64449b2659b4b..931469bed634a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -71,44 +71,44 @@ private[sql] abstract class NativeColumnAccessor[T <: AtomicType]( private[sql] class BooleanColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, BOOLEAN) -private[sql] class IntColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, INT) +private[sql] class ByteColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, BYTE) private[sql] class ShortColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, SHORT) +private[sql] class IntColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, INT) + private[sql] class LongColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, LONG) -private[sql] class ByteColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, BYTE) - -private[sql] class DoubleColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, DOUBLE) - private[sql] class FloatColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, FLOAT) -private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int) - extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale)) +private[sql] class DoubleColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, DOUBLE) private[sql] class StringColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, STRING) -private[sql] class DateColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, DATE) - -private[sql] class TimestampColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, TIMESTAMP) - private[sql] class BinaryColumnAccessor(buffer: ByteBuffer) extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer, BINARY) with NullableColumnAccessor +private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int) + extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale)) + private[sql] class GenericColumnAccessor(buffer: ByteBuffer) extends BasicColumnAccessor[DataType, Array[Byte]](buffer, GENERIC) with NullableColumnAccessor +private[sql] class DateColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, DATE) + +private[sql] class TimestampColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, TIMESTAMP) + private[sql] object ColumnAccessor { def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = { val dup = buffer.duplicate().order(ByteOrder.nativeOrder) @@ -118,17 +118,17 @@ private[sql] object ColumnAccessor { dup.getInt() dataType match { + case BooleanType => new BooleanColumnAccessor(dup) + case ByteType => new ByteColumnAccessor(dup) + case ShortType => new ShortColumnAccessor(dup) case IntegerType => new IntColumnAccessor(dup) + case DateType => new DateColumnAccessor(dup) case LongType => new LongColumnAccessor(dup) + case TimestampType => new TimestampColumnAccessor(dup) case FloatType => new FloatColumnAccessor(dup) case DoubleType => new DoubleColumnAccessor(dup) - case BooleanType => new BooleanColumnAccessor(dup) - case ByteType => new ByteColumnAccessor(dup) - case ShortType => new ShortColumnAccessor(dup) case StringType => new StringColumnAccessor(dup) case BinaryType => new BinaryColumnAccessor(dup) - case DateType => new DateColumnAccessor(dup) - case TimestampType => new TimestampColumnAccessor(dup) case DecimalType.Fixed(precision, scale) if precision < 19 => new FixedDecimalColumnAccessor(dup, precision, scale) case _ => new GenericColumnAccessor(dup) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index aa10af400c815..087c52239713d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.columnar import java.nio.{ByteBuffer, ByteOrder} -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.columnar.ColumnBuilder._ import org.apache.spark.sql.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder} import org.apache.spark.sql.types._ @@ -33,7 +33,7 @@ private[sql] trait ColumnBuilder { /** * Appends `row(ordinal)` to the column builder. */ - def appendFrom(row: Row, ordinal: Int) + def appendFrom(row: InternalRow, ordinal: Int) /** * Column statistics information @@ -68,7 +68,7 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( buffer.order(ByteOrder.nativeOrder()).putInt(columnType.typeId) } - override def appendFrom(row: Row, ordinal: Int): Unit = { + override def appendFrom(row: InternalRow, ordinal: Int): Unit = { buffer = ensureFreeSpace(buffer, columnType.actualSize(row, ordinal)) columnType.append(row, ordinal, buffer) } @@ -94,17 +94,21 @@ private[sql] abstract class NativeColumnBuilder[T <: AtomicType]( private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) -private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) +private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) private[sql] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT) +private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) + private[sql] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG) -private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) +private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE) -private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) +private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) + +private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) private[sql] class FixedDecimalColumnBuilder( precision: Int, @@ -113,19 +117,15 @@ private[sql] class FixedDecimalColumnBuilder( new FixedDecimalColumnStats, FIXED_DECIMAL(precision, scale)) -private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) +// TODO (lian) Add support for array, struct and map +private[sql] class GenericColumnBuilder + extends ComplexColumnBuilder(new GenericColumnStats, GENERIC) private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnStats, DATE) private[sql] class TimestampColumnBuilder extends NativeColumnBuilder(new TimestampColumnStats, TIMESTAMP) -private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) - -// TODO (lian) Add support for array, struct and map -private[sql] class GenericColumnBuilder - extends ComplexColumnBuilder(new GenericColumnStats, GENERIC) - private[sql] object ColumnBuilder { val DEFAULT_INITIAL_BUFFER_SIZE = 1024 * 1024 @@ -151,17 +151,17 @@ private[sql] object ColumnBuilder { columnName: String = "", useCompression: Boolean = false): ColumnBuilder = { val builder: ColumnBuilder = dataType match { + case BooleanType => new BooleanColumnBuilder + case ByteType => new ByteColumnBuilder + case ShortType => new ShortColumnBuilder case IntegerType => new IntColumnBuilder + case DateType => new DateColumnBuilder case LongType => new LongColumnBuilder + case TimestampType => new TimestampColumnBuilder case FloatType => new FloatColumnBuilder case DoubleType => new DoubleColumnBuilder - case BooleanType => new BooleanColumnBuilder - case ByteType => new ByteColumnBuilder - case ShortType => new ShortColumnBuilder case StringType => new StringColumnBuilder case BinaryType => new BinaryColumnBuilder - case DateType => new DateColumnBuilder - case TimestampType => new TimestampColumnBuilder case DecimalType.Fixed(precision, scale) if precision < 19 => new FixedDecimalColumnBuilder(precision, scale) case _ => new GenericColumnBuilder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index b0f983c180673..00374d1fa3ef1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.columnar -import java.sql.Timestamp - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable { val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = true)() @@ -54,7 +53,7 @@ private[sql] sealed trait ColumnStats extends Serializable { /** * Gathers statistics information from `row(ordinal)`. */ - def gatherStats(row: Row, ordinal: Int): Unit = { + def gatherStats(row: InternalRow, ordinal: Int): Unit = { if (row.isNullAt(ordinal)) { nullCount += 1 // 4 bytes for null position @@ -67,23 +66,23 @@ private[sql] sealed trait ColumnStats extends Serializable { * Column statistics represented as a single row, currently including closed lower bound, closed * upper bound and null count. */ - def collectedStatistics: Row + def collectedStatistics: InternalRow } /** * A no-op ColumnStats only used for testing purposes. */ private[sql] class NoopColumnStats extends ColumnStats { - override def gatherStats(row: Row, ordinal: Int): Unit = super.gatherStats(row, ordinal) + override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal) - override def collectedStatistics: Row = Row(null, null, nullCount, count, 0L) + override def collectedStatistics: InternalRow = InternalRow(null, null, nullCount, count, 0L) } private[sql] class BooleanColumnStats extends ColumnStats { protected var upper = false protected var lower = true - override def gatherStats(row: Row, ordinal: Int): Unit = { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getBoolean(ordinal) @@ -93,14 +92,15 @@ private[sql] class BooleanColumnStats extends ColumnStats { } } - override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class ByteColumnStats extends ColumnStats { protected var upper = Byte.MinValue protected var lower = Byte.MaxValue - override def gatherStats(row: Row, ordinal: Int): Unit = { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getByte(ordinal) @@ -110,14 +110,15 @@ private[sql] class ByteColumnStats extends ColumnStats { } } - override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class ShortColumnStats extends ColumnStats { protected var upper = Short.MinValue protected var lower = Short.MaxValue - override def gatherStats(row: Row, ordinal: Int): Unit = { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getShort(ordinal) @@ -127,48 +128,51 @@ private[sql] class ShortColumnStats extends ColumnStats { } } - override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class LongColumnStats extends ColumnStats { - protected var upper = Long.MinValue - protected var lower = Long.MaxValue +private[sql] class IntColumnStats extends ColumnStats { + protected var upper = Int.MinValue + protected var lower = Int.MaxValue - override def gatherStats(row: Row, ordinal: Int): Unit = { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getLong(ordinal) + val value = row.getInt(ordinal) if (value > upper) upper = value if (value < lower) lower = value - sizeInBytes += LONG.defaultSize + sizeInBytes += INT.defaultSize } } - override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class DoubleColumnStats extends ColumnStats { - protected var upper = Double.MinValue - protected var lower = Double.MaxValue +private[sql] class LongColumnStats extends ColumnStats { + protected var upper = Long.MinValue + protected var lower = Long.MaxValue - override def gatherStats(row: Row, ordinal: Int): Unit = { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getDouble(ordinal) + val value = row.getLong(ordinal) if (value > upper) upper = value if (value < lower) lower = value - sizeInBytes += DOUBLE.defaultSize + sizeInBytes += LONG.defaultSize } } - override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class FloatColumnStats extends ColumnStats { protected var upper = Float.MinValue protected var lower = Float.MaxValue - override def gatherStats(row: Row, ordinal: Int): Unit = { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getFloat(ordinal) @@ -178,48 +182,33 @@ private[sql] class FloatColumnStats extends ColumnStats { } } - override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) -} - -private[sql] class FixedDecimalColumnStats extends ColumnStats { - protected var upper: Decimal = null - protected var lower: Decimal = null - - override def gatherStats(row: Row, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) - if (!row.isNullAt(ordinal)) { - val value = row(ordinal).asInstanceOf[Decimal] - if (upper == null || value.compareTo(upper) > 0) upper = value - if (lower == null || value.compareTo(lower) < 0) lower = value - sizeInBytes += FIXED_DECIMAL.defaultSize - } - } - - override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class IntColumnStats extends ColumnStats { - protected var upper = Int.MinValue - protected var lower = Int.MaxValue +private[sql] class DoubleColumnStats extends ColumnStats { + protected var upper = Double.MinValue + protected var lower = Double.MaxValue - override def gatherStats(row: Row, ordinal: Int): Unit = { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getInt(ordinal) + val value = row.getDouble(ordinal) if (value > upper) upper = value if (value < lower) lower = value - sizeInBytes += INT.defaultSize + sizeInBytes += DOUBLE.defaultSize } } - override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class StringColumnStats extends ColumnStats { protected var upper: UTF8String = null protected var lower: UTF8String = null - override def gatherStats(row: Row, ordinal: Int): Unit = { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row(ordinal).asInstanceOf[UTF8String] @@ -229,46 +218,52 @@ private[sql] class StringColumnStats extends ColumnStats { } } - override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class DateColumnStats extends IntColumnStats - -private[sql] class TimestampColumnStats extends ColumnStats { - protected var upper: Timestamp = null - protected var lower: Timestamp = null - - override def gatherStats(row: Row, ordinal: Int): Unit = { +private[sql] class BinaryColumnStats extends ColumnStats { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row(ordinal).asInstanceOf[Timestamp] - if (upper == null || value.compareTo(upper) > 0) upper = value - if (lower == null || value.compareTo(lower) < 0) lower = value - sizeInBytes += TIMESTAMP.defaultSize + sizeInBytes += BINARY.actualSize(row, ordinal) } } - override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: InternalRow = + InternalRow(null, null, nullCount, count, sizeInBytes) } -private[sql] class BinaryColumnStats extends ColumnStats { - override def gatherStats(row: Row, ordinal: Int): Unit = { +private[sql] class FixedDecimalColumnStats extends ColumnStats { + protected var upper: Decimal = null + protected var lower: Decimal = null + + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - sizeInBytes += BINARY.actualSize(row, ordinal) + val value = row(ordinal).asInstanceOf[Decimal] + if (upper == null || value.compareTo(upper) > 0) upper = value + if (lower == null || value.compareTo(lower) < 0) lower = value + sizeInBytes += FIXED_DECIMAL.defaultSize } } - override def collectedStatistics: Row = Row(null, null, nullCount, count, sizeInBytes) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class GenericColumnStats extends ColumnStats { - override def gatherStats(row: Row, ordinal: Int): Unit = { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { sizeInBytes += GENERIC.actualSize(row, ordinal) } } - override def collectedStatistics: Row = Row(null, null, nullCount, count, sizeInBytes) + override def collectedStatistics: InternalRow = + InternalRow(null, null, nullCount, count, sizeInBytes) } + +private[sql] class DateColumnStats extends IntColumnStats + +private[sql] class TimestampColumnStats extends LongColumnStats diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 20be5ca9d0046..fc72360c88fe1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -18,14 +18,14 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.Timestamp import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * An abstract class that represents type of a column. Used to append/extract Java objects into/from @@ -63,7 +63,7 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( * Appends `row(ordinal)` of type T into the given ByteBuffer. Subclasses should override this * method to avoid boxing/unboxing costs whenever possible. */ - def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { append(getField(row, ordinal), buffer) } @@ -71,13 +71,13 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( * Returns the size of the value `row(ordinal)`. This is used to calculate the size of variable * length types such as byte arrays and strings. */ - def actualSize(row: Row, ordinal: Int): Int = defaultSize + def actualSize(row: InternalRow, ordinal: Int): Int = defaultSize /** * Returns `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs * whenever possible. */ - def getField(row: Row, ordinal: Int): JvmType + def getField(row: InternalRow, ordinal: Int): JvmType /** * Sets `row(ordinal)` to `field`. Subclasses should override this method to avoid boxing/unboxing @@ -89,7 +89,7 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( * Copies `from(fromOrdinal)` to `to(toOrdinal)`. Subclasses should override this method to avoid * boxing/unboxing costs whenever possible. */ - def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to(toOrdinal) = from(fromOrdinal) } @@ -118,7 +118,7 @@ private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { buffer.putInt(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.putInt(row.getInt(ordinal)) } @@ -134,9 +134,9 @@ private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { row.setInt(ordinal, value) } - override def getField(row: Row, ordinal: Int): Int = row.getInt(ordinal) + override def getField(row: InternalRow, ordinal: Int): Int = row.getInt(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setInt(toOrdinal, from.getInt(fromOrdinal)) } } @@ -146,7 +146,7 @@ private[sql] object LONG extends NativeColumnType(LongType, 1, 8) { buffer.putLong(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.putLong(row.getLong(ordinal)) } @@ -162,9 +162,9 @@ private[sql] object LONG extends NativeColumnType(LongType, 1, 8) { row.setLong(ordinal, value) } - override def getField(row: Row, ordinal: Int): Long = row.getLong(ordinal) + override def getField(row: InternalRow, ordinal: Int): Long = row.getLong(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setLong(toOrdinal, from.getLong(fromOrdinal)) } } @@ -174,7 +174,7 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) { buffer.putFloat(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.putFloat(row.getFloat(ordinal)) } @@ -190,9 +190,9 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) { row.setFloat(ordinal, value) } - override def getField(row: Row, ordinal: Int): Float = row.getFloat(ordinal) + override def getField(row: InternalRow, ordinal: Int): Float = row.getFloat(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setFloat(toOrdinal, from.getFloat(fromOrdinal)) } } @@ -202,7 +202,7 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) { buffer.putDouble(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.putDouble(row.getDouble(ordinal)) } @@ -218,9 +218,9 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) { row.setDouble(ordinal, value) } - override def getField(row: Row, ordinal: Int): Double = row.getDouble(ordinal) + override def getField(row: InternalRow, ordinal: Int): Double = row.getDouble(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setDouble(toOrdinal, from.getDouble(fromOrdinal)) } } @@ -230,7 +230,7 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) { buffer.put(if (v) 1: Byte else 0: Byte) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.put(if (row.getBoolean(ordinal)) 1: Byte else 0: Byte) } @@ -244,9 +244,9 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) { row.setBoolean(ordinal, value) } - override def getField(row: Row, ordinal: Int): Boolean = row.getBoolean(ordinal) + override def getField(row: InternalRow, ordinal: Int): Boolean = row.getBoolean(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setBoolean(toOrdinal, from.getBoolean(fromOrdinal)) } } @@ -256,7 +256,7 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) { buffer.put(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.put(row.getByte(ordinal)) } @@ -272,9 +272,9 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) { row.setByte(ordinal, value) } - override def getField(row: Row, ordinal: Int): Byte = row.getByte(ordinal) + override def getField(row: InternalRow, ordinal: Int): Byte = row.getByte(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setByte(toOrdinal, from.getByte(fromOrdinal)) } } @@ -284,7 +284,7 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { buffer.putShort(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.putShort(row.getShort(ordinal)) } @@ -300,15 +300,15 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { row.setShort(ordinal, value) } - override def getField(row: Row, ordinal: Int): Short = row.getShort(ordinal) + override def getField(row: InternalRow, ordinal: Int): Short = row.getShort(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setShort(toOrdinal, from.getShort(fromOrdinal)) } } private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { - override def actualSize(row: Row, ordinal: Int): Int = { + override def actualSize(row: InternalRow, ordinal: Int): Int = { row.getString(ordinal).getBytes("utf-8").length + 4 } @@ -321,18 +321,18 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { val length = buffer.getInt() val stringBytes = new Array[Byte](length) buffer.get(stringBytes, 0, length) - UTF8String(stringBytes) + UTF8String.fromBytes(stringBytes) } override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = { row.update(ordinal, value) } - override def getField(row: Row, ordinal: Int): UTF8String = { + override def getField(row: InternalRow, ordinal: Int): UTF8String = { row(ordinal).asInstanceOf[UTF8String] } - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.update(toOrdinal, from(fromOrdinal)) } } @@ -346,7 +346,7 @@ private[sql] object DATE extends NativeColumnType(DateType, 8, 4) { buffer.putInt(v) } - override def getField(row: Row, ordinal: Int): Int = { + override def getField(row: InternalRow, ordinal: Int): Int = { row(ordinal).asInstanceOf[Int] } @@ -355,22 +355,20 @@ private[sql] object DATE extends NativeColumnType(DateType, 8, 4) { } } -private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) { - override def extract(buffer: ByteBuffer): Timestamp = { - val timestamp = new Timestamp(buffer.getLong()) - timestamp.setNanos(buffer.getInt()) - timestamp +private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 8) { + override def extract(buffer: ByteBuffer): Long = { + buffer.getLong } - override def append(v: Timestamp, buffer: ByteBuffer): Unit = { - buffer.putLong(v.getTime).putInt(v.getNanos) + override def append(v: Long, buffer: ByteBuffer): Unit = { + buffer.putLong(v) } - override def getField(row: Row, ordinal: Int): Timestamp = { - row(ordinal).asInstanceOf[Timestamp] + override def getField(row: InternalRow, ordinal: Int): Long = { + row(ordinal).asInstanceOf[Long] } - override def setField(row: MutableRow, ordinal: Int, value: Timestamp): Unit = { + override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = { row(ordinal) = value } } @@ -389,7 +387,7 @@ private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int) buffer.putLong(v.toUnscaledLong) } - override def getField(row: Row, ordinal: Int): Decimal = { + override def getField(row: InternalRow, ordinal: Int): Decimal = { row(ordinal).asInstanceOf[Decimal] } @@ -407,7 +405,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( defaultSize: Int) extends ColumnType[T, Array[Byte]](typeId, defaultSize) { - override def actualSize(row: Row, ordinal: Int): Int = { + override def actualSize(row: InternalRow, ordinal: Int): Int = { getField(row, ordinal).length + 4 } @@ -428,7 +426,7 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](11, 16) row(ordinal) = value } - override def getField(row: Row, ordinal: Int): Array[Byte] = { + override def getField(row: InternalRow, ordinal: Int): Array[Byte] = { row(ordinal).asInstanceOf[Array[Byte]] } } @@ -441,7 +439,7 @@ private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) { row(ordinal) = SparkSqlSerializer.deserialize[Any](value) } - override def getField(row: Row, ordinal: Int): Array[Byte] = { + override def getField(row: InternalRow, ordinal: Int): Array[Byte] = { SparkSqlSerializer.serialize(row(ordinal)) } } @@ -449,17 +447,17 @@ private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) { private[sql] object ColumnType { def apply(dataType: DataType): ColumnType[_, _] = { dataType match { + case BooleanType => BOOLEAN + case ByteType => BYTE + case ShortType => SHORT case IntegerType => INT + case DateType => DATE case LongType => LONG + case TimestampType => TIMESTAMP case FloatType => FLOAT case DoubleType => DOUBLE - case BooleanType => BOOLEAN - case ByteType => BYTE - case ShortType => SHORT case StringType => STRING case BinaryType => BINARY - case DateType => DATE - case TimestampType => TIMESTAMP case DecimalType.Fixed(precision, scale) if precision < 19 => FIXED_DECIMAL(precision, scale) case _ => GENERIC diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 3db26fad2b92f..cb1fd4947fdbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -19,21 +19,16 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import org.apache.spark.{Accumulable, Accumulator, Accumulators} -import org.apache.spark.sql.catalyst.expressions - import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row -import org.apache.spark.SparkContext import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.execution.{LeafNode, SparkPlan} import org.apache.spark.storage.StorageLevel +import org.apache.spark.{Accumulable, Accumulator, Accumulators} private[sql] object InMemoryRelation { def apply( @@ -45,7 +40,7 @@ private[sql] object InMemoryRelation { new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)() } -private[sql] case class CachedBatch(buffers: Array[Array[Byte]], stats: Row) +private[sql] case class CachedBatch(buffers: Array[Array[Byte]], stats: InternalRow) private[sql] case class InMemoryRelation( output: Seq[Attribute], @@ -56,12 +51,12 @@ private[sql] case class InMemoryRelation( tableName: Option[String])( private var _cachedColumnBuffers: RDD[CachedBatch] = null, private var _statistics: Statistics = null, - private var _batchStats: Accumulable[ArrayBuffer[Row], Row] = null) + private var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null) extends LogicalPlan with MultiInstanceRelation { - private val batchStats: Accumulable[ArrayBuffer[Row], Row] = + private val batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = if (_batchStats == null) { - child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row]) + child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[InternalRow]) } else { _batchStats } @@ -151,7 +146,8 @@ private[sql] case class InMemoryRelation( rowCount += 1 } - val stats = Row.merge(columnBuilders.map(_.columnStats.collectedStatistics) : _*) + val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics) + .flatMap(_.toSeq)) batchStats += stats CachedBatch(columnBuilders.map(_.build().array()), stats) @@ -267,7 +263,7 @@ private[sql] case class InMemoryColumnarTableScan( private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { if (enableAccumulators) { readPartitions.setValue(0) readBatches.setValue(0) @@ -296,7 +292,7 @@ private[sql] case class InMemoryColumnarTableScan( val nextRow = new SpecificMutableRow(requestedColumnDataTypes) - def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]): Iterator[Row] = { + def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]): Iterator[InternalRow] = { val rows = cacheBatches.flatMap { cachedBatch => // Build column accessors val columnAccessors = requestedColumnIndices.map { batchColumnIndex => @@ -306,15 +302,15 @@ private[sql] case class InMemoryColumnarTableScan( } // Extract rows via column accessors - new Iterator[Row] { + new Iterator[InternalRow] { private[this] val rowLen = nextRow.length - override def next(): Row = { + override def next(): InternalRow = { var i = 0 while (i < rowLen) { columnAccessors(i).extractTo(nextRow, i) i += 1 } - if (attributes.isEmpty) Row.empty else nextRow + if (attributes.isEmpty) InternalRow.empty else nextRow } override def hasNext: Boolean = columnAccessors(0).hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala index f1f494ac26d0c..ba47bc783f31e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.columnar import java.nio.{ByteBuffer, ByteOrder} -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow /** * A stackable trait used for building byte buffer for a column containing null values. Memory @@ -52,7 +52,7 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder { super.initialize(initialSize, columnName, useCompression) } - abstract override def appendFrom(row: Row, ordinal: Int): Unit = { + abstract override def appendFrom(row: InternalRow, ordinal: Int): Unit = { columnStats.gatherStats(row, ordinal) if (row.isNullAt(ordinal)) { nulls = ColumnBuilder.ensureFreeSpace(nulls, 4) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala index 8e2a1af6dae78..39b21ddb47ba4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.columnar.compression import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.Logging -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder} import org.apache.spark.sql.types.AtomicType @@ -66,7 +66,7 @@ private[sql] trait CompressibleColumnBuilder[T <: AtomicType] encoder.compressionRatio < 0.8 } - private def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + private def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = { var i = 0 while (i < compressionEncoders.length) { compressionEncoders(i).gatherCompressibilityStats(row, ordinal) @@ -74,7 +74,7 @@ private[sql] trait CompressibleColumnBuilder[T <: AtomicType] } } - abstract override def appendFrom(row: Row, ordinal: Int): Unit = { + abstract override def appendFrom(row: InternalRow, ordinal: Int): Unit = { super.appendFrom(row, ordinal) if (!row.isNullAt(ordinal)) { gatherCompressibilityStats(row, ordinal) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala index 17c2d9b111188..4eaec6d853d4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala @@ -18,14 +18,13 @@ package org.apache.spark.sql.columnar.compression import java.nio.{ByteBuffer, ByteOrder} - -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType} import org.apache.spark.sql.types.AtomicType private[sql] trait Encoder[T <: AtomicType] { - def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {} + def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {} def compressedSize: Int diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala index 534ae90ddbc8b..5abc1259a19ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -22,8 +22,7 @@ import java.nio.ByteBuffer import scala.collection.mutable import scala.reflect.ClassTag import scala.reflect.runtime.universe.runtimeMirror - -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} import org.apache.spark.sql.columnar._ import org.apache.spark.sql.types._ @@ -96,7 +95,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { override def compressedSize: Int = _compressedSize - override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + override def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = { val value = columnType.getField(row, ordinal) val actualSize = columnType.actualSize(row, ordinal) _uncompressedSize += actualSize @@ -217,7 +216,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { // to store dictionary element count. private var dictionarySize = 4 - override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + override def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = { val value = columnType.getField(row, ordinal) if (!overflow) { @@ -310,7 +309,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme { class Encoder extends compression.Encoder[BooleanType.type] { private var _uncompressedSize = 0 - override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + override def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = { _uncompressedSize += BOOLEAN.defaultSize } @@ -404,7 +403,7 @@ private[sql] case object IntDelta extends CompressionScheme { private var prevValue: Int = _ - override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + override def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = { val value = row.getInt(ordinal) val delta = value - prevValue @@ -484,7 +483,7 @@ private[sql] case object LongDelta extends CompressionScheme { private var prevValue: Long = _ - override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + override def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = { val value = row.getLong(ordinal) val delta = value - prevValue diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index 8d16749697aa2..6e8a5ef18ab62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -20,12 +20,10 @@ package org.apache.spark.sql.execution import java.util.HashMap import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.SQLContext /** * :: DeveloperApi :: @@ -121,11 +119,11 @@ case class Aggregate( } } - protected override def doExecute(): RDD[Row] = attachTree(this, "execute") { + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { if (groupingExpressions.isEmpty) { child.execute().mapPartitions { iter => val buffer = newAggregateBuffer() - var currentRow: Row = null + var currentRow: InternalRow = null while (iter.hasNext) { currentRow = iter.next() var i = 0 @@ -147,10 +145,10 @@ case class Aggregate( } } else { child.execute().mapPartitions { iter => - val hashTable = new HashMap[Row, Array[AggregateFunction]] + val hashTable = new HashMap[InternalRow, Array[AggregateFunction]] val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output) - var currentRow: Row = null + var currentRow: InternalRow = null while (iter.hasNext) { currentRow = iter.next() val currentGroup = groupingProjection(currentRow) @@ -167,7 +165,7 @@ case class Aggregate( } } - new Iterator[Row] { + new Iterator[InternalRow] { private[this] val hashTableIter = hashTable.entrySet().iterator() private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length) private[this] val resultProjection = @@ -177,7 +175,7 @@ case class Aggregate( override final def hasNext: Boolean = hashTableIter.hasNext - override final def next(): Row = { + override final def next(): InternalRow = { val currentEntry = hashTableIter.next() val currentGroup = currentEntry.getKey val currentBuffer = currentEntry.getValue diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 5fcc48a67948b..a4b38d364d54a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -103,7 +103,7 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { sqlContext.conf.useCompression, sqlContext.conf.columnBatchSize, storageLevel, - query.queryExecution.executedPlan, + sqlContext.executePlan(query.logicalPlan).executedPlan, tableName)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index f25d10fec0411..edc64a03335d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -17,29 +17,20 @@ package org.apache.spark.sql.execution -import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.shuffle.unsafe.UnsafeShuffleManager +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.util.MutablePair - -object Exchange { - /** - * Returns true when the ordering expressions are a subset of the key. - * if true, ShuffledRDD can use `setKeyOrdering(orderingKey)` to sort within [[Exchange]]. - */ - def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = { - desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet) - } -} +import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv} /** * :: DeveloperApi :: @@ -91,11 +82,7 @@ case class Exchange( shuffleManager.isInstanceOf[UnsafeShuffleManager] val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) val serializeMapOutputs = conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) - if (newOrdering.nonEmpty) { - // If a new ordering is required, then records will be sorted with Spark's `ExternalSorter`, - // which requires a defensive copy. - true - } else if (sortBasedShuffleOn) { + if (sortBasedShuffleOn) { val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) { // If we're using the original SortShuffleManager and the number of output partitions is @@ -106,8 +93,11 @@ case class Exchange( } else if (serializeMapOutputs && serializer.supportsRelocationOfSerializedObjects) { // SPARK-4550 extended sort-based shuffle to serialize individual records prior to sorting // them. This optimization is guarded by a feature-flag and is only applied in cases where - // shuffle dependency does not specify an ordering and the record serializer has certain - // properties. If this optimization is enabled, we can safely avoid the copy. + // shuffle dependency does not specify an aggregator or ordering and the record serializer + // has certain properties. If this optimization is enabled, we can safely avoid the copy. + // + // Exchange never configures its ShuffledRDDs with aggregators or key orderings, so we only + // need to check whether the optimization is enabled and supported by our serializer. // // This optimization also applies to UnsafeShuffleManager (added in SPARK-7081). false @@ -118,9 +108,12 @@ case class Exchange( // both cases, we must copy. true } - } else { + } else if (shuffleManager.isInstanceOf[HashShuffleManager]) { // We're using hash-based shuffle, so we don't need to copy. false + } else { + // Catch-all case to safely handle any future ShuffleManager implementations. + true } } @@ -143,7 +136,6 @@ case class Exchange( private def getSerializer( keySchema: Array[DataType], valueSchema: Array[DataType], - hasKeyOrdering: Boolean, numPartitions: Int): Serializer = { // It is true when there is no field that needs to be write out. // For now, we will not use SparkSqlSerializer2 when noField is true. @@ -159,7 +151,7 @@ case class Exchange( val serializer = if (useSqlSerializer2) { logInfo("Using SparkSqlSerializer2.") - new SparkSqlSerializer2(keySchema, valueSchema, hasKeyOrdering) + new SparkSqlSerializer2(keySchema, valueSchema) } else { logInfo("Using SparkSqlSerializer.") new SparkSqlSerializer(sparkConf) @@ -168,12 +160,12 @@ case class Exchange( serializer } - protected override def doExecute(): RDD[Row] = attachTree(this , "execute") { + protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") { newPartitioning match { case HashPartitioning(expressions, numPartitions) => val keySchema = expressions.map(_.dataType).toArray val valueSchema = child.output.map(_.dataType).toArray - val serializer = getSerializer(keySchema, valueSchema, newOrdering.nonEmpty, numPartitions) + val serializer = getSerializer(keySchema, valueSchema, numPartitions) val part = new HashPartitioner(numPartitions) val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) { @@ -184,27 +176,24 @@ case class Exchange( } else { child.execute().mapPartitions { iter => val hashExpressions = newMutableProjection(expressions, child.output)() - val mutablePair = new MutablePair[Row, Row]() + val mutablePair = new MutablePair[InternalRow, InternalRow]() iter.map(r => mutablePair.update(hashExpressions(r), r)) } } - val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part) - if (newOrdering.nonEmpty) { - shuffled.setKeyOrdering(keyOrdering) - } + val shuffled = new ShuffledRDD[InternalRow, InternalRow, InternalRow](rdd, part) shuffled.setSerializer(serializer) shuffled.map(_._2) case RangePartitioning(sortingExpressions, numPartitions) => val keySchema = child.output.map(_.dataType).toArray - val serializer = getSerializer(keySchema, null, newOrdering.nonEmpty, numPartitions) + val serializer = getSerializer(keySchema, null, numPartitions) val childRdd = child.execute() val part: Partitioner = { // Internally, RangePartitioner runs a job on the RDD that samples keys to compute // partition bounds. To get accurate samples, we need to copy the mutable keys. val rddForSampling = childRdd.mapPartitions { iter => - val mutablePair = new MutablePair[Row, Null]() + val mutablePair = new MutablePair[InternalRow, Null]() iter.map(row => mutablePair.update(row.copy(), null)) } // TODO: RangePartitioner should take an Ordering. @@ -216,32 +205,31 @@ case class Exchange( childRdd.mapPartitions { iter => iter.map(row => (row.copy(), null))} } else { childRdd.mapPartitions { iter => - val mutablePair = new MutablePair[Row, Null]() + val mutablePair = new MutablePair[InternalRow, Null]() iter.map(row => mutablePair.update(row, null)) } } - val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part) - if (newOrdering.nonEmpty) { - shuffled.setKeyOrdering(keyOrdering) - } + val shuffled = new ShuffledRDD[InternalRow, Null, Null](rdd, part) shuffled.setSerializer(serializer) shuffled.map(_._1) case SinglePartition => val valueSchema = child.output.map(_.dataType).toArray - val serializer = getSerializer(null, valueSchema, hasKeyOrdering = false, 1) + val serializer = getSerializer(null, valueSchema, numPartitions = 1) val partitioner = new HashPartitioner(1) val rdd = if (needToCopyObjectsBeforeShuffle(partitioner, serializer)) { - child.execute().mapPartitions { iter => iter.map(r => (null, r.copy())) } + child.execute().mapPartitions { + iter => iter.map(r => (null, r.copy())) + } } else { child.execute().mapPartitions { iter => - val mutablePair = new MutablePair[Null, Row]() + val mutablePair = new MutablePair[Null, InternalRow]() iter.map(r => mutablePair.update(null, r)) } } - val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner) + val shuffled = new ShuffledRDD[Null, InternalRow, InternalRow](rdd, partitioner) shuffled.setSerializer(serializer) shuffled.map(_._2) @@ -306,29 +294,24 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ child: SparkPlan): SparkPlan = { val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering val needsShuffle = child.outputPartitioning != partitioning - val canSortWithShuffle = Exchange.canSortWithShuffle(partitioning, rowOrdering) - if (needSort && needsShuffle && canSortWithShuffle) { - Exchange(partitioning, rowOrdering, child) + val withShuffle = if (needsShuffle) { + Exchange(partitioning, Nil, child) } else { - val withShuffle = if (needsShuffle) { - Exchange(partitioning, Nil, child) - } else { - child - } + child + } - val withSort = if (needSort) { - if (sqlContext.conf.externalSortEnabled) { - ExternalSort(rowOrdering, global = false, withShuffle) - } else { - Sort(rowOrdering, global = false, withShuffle) - } + val withSort = if (needSort) { + if (sqlContext.conf.externalSortEnabled) { + ExternalSort(rowOrdering, global = false, withShuffle) } else { - withShuffle + Sort(rowOrdering, global = false, withShuffle) } - - withSort + } else { + withShuffle } + + withSort } if (meetsRequirements && compatible && !needsAnySort) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index f931dc95ef575..da27a753a710f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} @@ -31,7 +31,7 @@ import org.apache.spark.sql.{Row, SQLContext} */ @DeveloperApi object RDDConversions { - def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[Row] = { + def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = { data.mapPartitions { iterator => val numColumns = outputTypes.length val mutableRow = new GenericMutableRow(numColumns) @@ -51,7 +51,7 @@ object RDDConversions { /** * Convert the objects inside Row into the types Catalyst expected. */ - def rowToRowRdd(data: RDD[Row], outputTypes: Seq[DataType]): RDD[Row] = { + def rowToRowRdd(data: RDD[Row], outputTypes: Seq[DataType]): RDD[InternalRow] = { data.mapPartitions { iterator => val numColumns = outputTypes.length val mutableRow = new GenericMutableRow(numColumns) @@ -70,7 +70,9 @@ object RDDConversions { } /** Logical plan node for scanning data from an RDD. */ -private[sql] case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext) +private[sql] case class LogicalRDD( + output: Seq[Attribute], + rdd: RDD[InternalRow])(sqlContext: SQLContext) extends LogicalPlan with MultiInstanceRelation { override def children: Seq[LogicalPlan] = Nil @@ -91,13 +93,15 @@ private[sql] case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlCon } /** Physical plan node for scanning data from an RDD. */ -private[sql] case class PhysicalRDD(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { - protected override def doExecute(): RDD[Row] = rdd +private[sql] case class PhysicalRDD( + output: Seq[Attribute], + rdd: RDD[InternalRow]) extends LeafNode { + protected override def doExecute(): RDD[InternalRow] = rdd } /** Logical plan node for scanning data from a local collection. */ private[sql] -case class LogicalLocalTable(output: Seq[Attribute], rows: Seq[Row])(sqlContext: SQLContext) +case class LogicalLocalTable(output: Seq[Attribute], rows: Seq[InternalRow])(sqlContext: SQLContext) extends LogicalPlan with MultiInstanceRelation { override def children: Seq[LogicalPlan] = Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index f16ca36909fab..42a0c1be4f694 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -19,10 +19,9 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{UnknownPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} /** * Apply the all of the GroupExpressions to every input row, hence we will get @@ -34,7 +33,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{UnknownPartitioning, Partit */ @DeveloperApi case class Expand( - projections: Seq[GroupExpression], + projections: Seq[Seq[Expression]], output: Seq[Attribute], child: SparkPlan) extends UnaryNode { @@ -43,22 +42,22 @@ case class Expand( // as UNKNOWN partitioning override def outputPartitioning: Partitioning = UnknownPartitioning(0) - protected override def doExecute(): RDD[Row] = attachTree(this, "execute") { + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { child.execute().mapPartitions { iter => // TODO Move out projection objects creation and transfer to // workers via closure. However we can't assume the Projection // is serializable because of the code gen, so we have to // create the projections within each of the partition processing. - val groups = projections.map(ee => newProjection(ee.children, child.output)).toArray + val groups = projections.map(ee => newProjection(ee, child.output)).toArray - new Iterator[Row] { - private[this] var result: Row = _ + new Iterator[InternalRow] { + private[this] var result: InternalRow = _ private[this] var idx = -1 // -1 means the initial state - private[this] var input: Row = _ + private[this] var input: InternalRow = _ override final def hasNext: Boolean = (-1 < idx && idx < groups.length) || iter.hasNext - override final def next(): Row = { + override final def next(): InternalRow = { if (idx <= 0) { // in the initial (-1) or beginning(0) of a new input row, fetch the next input tuple input = iter.next() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index dd02c1f4573bb..c1665f78a960e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -25,12 +25,12 @@ import org.apache.spark.sql.catalyst.expressions._ * For lazy computing, be sure the generator.terminate() called in the very last * TODO reusing the CompletionIterator? */ -private[execution] sealed case class LazyIterator(func: () => TraversableOnce[Row]) - extends Iterator[Row] { +private[execution] sealed case class LazyIterator(func: () => TraversableOnce[InternalRow]) + extends Iterator[InternalRow] { lazy val results = func().toIterator override def hasNext: Boolean = results.hasNext - override def next(): Row = results.next() + override def next(): InternalRow = results.next() } /** @@ -58,11 +58,11 @@ case class Generate( val boundGenerator = BindReferences.bindReference(generator, child.output) - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { // boundGenerator.terminate() should be triggered after all of the rows in the partition if (join) { child.execute().mapPartitions { iter => - val generatorNullRow = Row.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null)) + val generatorNullRow = InternalRow.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null)) val joinedRow = new JoinedRow iter.flatMap { row => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 3e27c1bde2dfd..44930f82b53a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -66,7 +66,7 @@ case class GeneratedAggregate( override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { val aggregatesToCompute = aggregateExpressions.flatMap { a => a.collect { case agg: AggregateExpression => agg} } @@ -118,7 +118,7 @@ case class GeneratedAggregate( AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) case cs @ CombineSum(expr) => - val calcType = expr.dataType + val calcType = expr.dataType match { case DecimalType.Fixed(_, _) => DecimalType.Unlimited @@ -129,7 +129,7 @@ case class GeneratedAggregate( val currentSum = AttributeReference("currentSum", calcType, nullable = true)() val initialValue = Literal.create(null, calcType) - // Coalasce avoids double calculation... + // Coalesce avoids double calculation... // but really, common sub expression elimination would be better.... val zero = Cast(Literal(0), calcType) // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its @@ -214,18 +214,18 @@ case class GeneratedAggregate( }.toMap val namedGroups = groupingExpressions.zipWithIndex.map { - case (ne: NamedExpression, _) => (ne, ne) - case (e, i) => (e, Alias(e, s"GroupingExpr$i")()) + case (ne: NamedExpression, _) => (ne, ne.toAttribute) + case (e, i) => (e, Alias(e, s"GroupingExpr$i")().toAttribute) } - val groupMap: Map[Expression, Attribute] = - namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap - // The set of expressions that produce the final output given the aggregation buffer and the // grouping expressions. val resultExpressions = aggregateExpressions.map(_.transform { case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e)) - case e: Expression if groupMap.contains(e) => groupMap(e) + case e: Expression => + namedGroups.collectFirst { + case (expr, attr) if expr semanticEquals e => attr + }.getOrElse(e) }) val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema) @@ -238,11 +238,6 @@ case class GeneratedAggregate( StructType(fields) } - val schemaSupportsUnsafe: Boolean = { - UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && - UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema) - } - child.execute().mapPartitions { iter => // Builds a new custom class for holding the results of aggregation for a group. val initialValues = computeFunctions.flatMap(_.initialValues) @@ -265,7 +260,7 @@ case class GeneratedAggregate( val resultProjectionBuilder = newMutableProjection( resultExpressions, - (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq) + namedGroups.map(_._2) ++ computationSchema) log.info(s"Result Projection: ${resultExpressions.mkString(",")}") val joinedRow = new JoinedRow3 @@ -273,7 +268,7 @@ case class GeneratedAggregate( if (groupingExpressions.isEmpty) { // TODO: Codegening anything other than the updateProjection is probably over kill. val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] - var currentRow: Row = null + var currentRow: InternalRow = null updateProjection.target(buffer) while (iter.hasNext) { @@ -283,31 +278,31 @@ case class GeneratedAggregate( val resultProjection = resultProjectionBuilder() Iterator(resultProjection(buffer)) - } else if (unsafeEnabled && schemaSupportsUnsafe) { + } else if (unsafeEnabled) { log.info("Using Unsafe-based aggregator") val aggregationMap = new UnsafeFixedWidthAggregationMap( - newAggregationBuffer(EmptyRow), - aggregationBufferSchema, - groupKeySchema, + newAggregationBuffer, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggregationBufferSchema), TaskContext.get.taskMemoryManager(), 1024 * 16, // initial capacity false // disable tracking of performance metrics ) while (iter.hasNext) { - val currentRow: Row = iter.next() - val groupKey: Row = groupProjection(currentRow) + val currentRow: InternalRow = iter.next() + val groupKey: InternalRow = groupProjection(currentRow) val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey) updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow)) } - new Iterator[Row] { + new Iterator[InternalRow] { private[this] val mapIterator = aggregationMap.iterator() private[this] val resultProjection = resultProjectionBuilder() def hasNext: Boolean = mapIterator.hasNext - def next(): Row = { + def next(): InternalRow = { val entry = mapIterator.next() val result = resultProjection(joinedRow(entry.key, entry.value)) if (hasNext) { @@ -323,12 +318,9 @@ case class GeneratedAggregate( } } } else { - if (unsafeEnabled) { - log.info("Not using Unsafe-based aggregator because it is not supported for this schema") - } - val buffers = new java.util.HashMap[Row, MutableRow]() + val buffers = new java.util.HashMap[InternalRow, MutableRow]() - var currentRow: Row = null + var currentRow: InternalRow = null while (iter.hasNext) { currentRow = iter.next() val currentGroup = groupProjection(currentRow) @@ -342,13 +334,13 @@ case class GeneratedAggregate( updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow)) } - new Iterator[Row] { + new Iterator[InternalRow] { private[this] val resultIterator = buffers.entrySet.iterator() private[this] val resultProjection = resultProjectionBuilder() def hasNext: Boolean = resultIterator.hasNext - def next(): Row = { + def next(): InternalRow = { val currentGroup = resultIterator.next() resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index 03bee80ad7f38..cd341180b6100 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -19,18 +19,20 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} import org.apache.spark.sql.catalyst.expressions.Attribute /** * Physical plan node for scanning data from a local collection. */ -private[sql] case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNode { +private[sql] case class LocalTableScan( + output: Seq[Attribute], + rows: Seq[InternalRow]) extends LeafNode { private lazy val rdd = sqlContext.sparkContext.parallelize(rows) - protected override def doExecute(): RDD[Row] = rdd + protected override def doExecute(): RDD[InternalRow] = rdd override def executeCollect(): Array[Row] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 435ac011178de..47f56b2b7ebe6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ @@ -79,11 +80,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) /** - * Returns the result of this query as an RDD[Row] by delegating to doExecute + * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute * after adding query plan information to created RDDs for visualization. * Concrete implementations of SparkPlan should override doExecute instead. */ - final def execute(): RDD[Row] = { + final def execute(): RDD[InternalRow] = { RDDOperationScope.withScope(sparkContext, nodeName, false, true) { doExecute() } @@ -91,9 +92,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** * Overridden by concrete implementations of SparkPlan. - * Produces the result of the query as an RDD[Row] + * Produces the result of the query as an RDD[InternalRow] */ - protected def doExecute(): RDD[Row] + protected def doExecute(): RDD[InternalRow] /** * Runs this query returning the result as an array. @@ -117,7 +118,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ val childRDD = execute().map(_.copy()) - val buf = new ArrayBuffer[Row] + val buf = new ArrayBuffer[InternalRow] val totalParts = childRDD.partitions.length var partsScanned = 0 while (buf.size < n && partsScanned < totalParts) { @@ -140,7 +141,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) val sc = sqlContext.sparkContext val res = - sc.runJob(childRDD, (it: Iterator[Row]) => it.take(left).toArray, p, allowLocal = false) + sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p, + allowLocal = false) res.foreach(buf ++= _.take(n - buf.size)) partsScanned += numPartsToTry @@ -154,7 +156,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { log.debug( s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if (codegenEnabled) { + if (codegenEnabled && expressions.forall(_.isThreadSafe)) { GenerateProjection.generate(expressions, inputSchema) } else { new InterpretedProjection(expressions, inputSchema) @@ -166,7 +168,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ inputSchema: Seq[Attribute]): () => MutableProjection = { log.debug( s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if(codegenEnabled) { + if(codegenEnabled && expressions.forall(_.isThreadSafe)) { GenerateMutableProjection.generate(expressions, inputSchema) } else { () => new InterpretedMutableProjection(expressions, inputSchema) @@ -175,15 +177,17 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ protected def newPredicate( - expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = { - if (codegenEnabled) { + expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { + if (codegenEnabled && expression.isThreadSafe) { GeneratePredicate.generate(expression, inputSchema) } else { InterpretedPredicate.create(expression, inputSchema) } } - protected def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = { + protected def newOrdering( + order: Seq[SortOrder], + inputSchema: Seq[Attribute]): Ordering[InternalRow] = { if (codegenEnabled) { GenerateOrdering.generate(order, inputSchema) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index eea15aff5dbcf..b19ad4f1c563e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -20,22 +20,20 @@ package org.apache.spark.sql.execution import java.nio.ByteBuffer import java.util.{HashMap => JavaHashMap} -import org.apache.spark.sql.types.Decimal - import scala.reflect.ClassTag import com.clearspring.analytics.stream.cardinality.HyperLogLog import com.esotericsoftware.kryo.io.{Input, Output} -import com.esotericsoftware.kryo.{Serializer, Kryo} +import com.esotericsoftware.kryo.{Kryo, Serializer} import com.twitter.chill.ResourcePool -import org.apache.spark.{SparkEnv, SparkConf} -import org.apache.spark.serializer.{SerializerInstance, KryoSerializer} -import org.apache.spark.sql.catalyst.expressions.GenericRow -import org.apache.spark.util.collection.OpenHashSet -import org.apache.spark.util.MutablePair - +import org.apache.spark.serializer.{KryoSerializer, SerializerInstance} +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHashSet} +import org.apache.spark.sql.types.Decimal +import org.apache.spark.util.MutablePair +import org.apache.spark.util.collection.OpenHashSet +import org.apache.spark.{SparkConf, SparkEnv} private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { override def newKryo(): Kryo = { @@ -43,6 +41,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co kryo.setRegistrationRequired(false) kryo.register(classOf[MutablePair[_, _]]) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow]) + kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericInternalRow]) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow]) kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog], new HyperLogLogSerializer) @@ -139,7 +138,7 @@ private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] { val iterator = hs.iterator while(iterator.hasNext) { val row = iterator.next() - rowSerializer.write(kryo, output, row.asInstanceOf[GenericRow].values) + rowSerializer.write(kryo, output, row.asInstanceOf[GenericInternalRow].values) } } @@ -150,7 +149,7 @@ private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] { var i = 0 while (i < numItems) { val row = - new GenericRow(rowSerializer.read( + new GenericInternalRow(rowSerializer.read( kryo, input, classOf[Array[Any]].asInstanceOf[Class[Any]]).asInstanceOf[Array[Any]]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index 256d527d7b636..056d435eecd23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -20,15 +20,16 @@ package org.apache.spark.sql.execution import java.io._ import java.math.{BigDecimal, BigInteger} import java.nio.ByteBuffer -import java.sql.Timestamp import scala.reflect.ClassTag -import org.apache.spark.serializer._ import org.apache.spark.Logging +import org.apache.spark.serializer._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, MutableRow, GenericMutableRow} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * The serialization stream for [[SparkSqlSerializer2]]. It assumes that the object passed in @@ -86,7 +87,6 @@ private[sql] class Serializer2SerializationStream( private[sql] class Serializer2DeserializationStream( keySchema: Array[DataType], valueSchema: Array[DataType], - hasKeyOrdering: Boolean, in: InputStream) extends DeserializationStream with Logging { @@ -96,14 +96,9 @@ private[sql] class Serializer2DeserializationStream( if (schema == null) { () => null } else { - if (hasKeyOrdering) { - // We have key ordering specified in a ShuffledRDD, it is not safe to reuse a mutable row. - () => new GenericMutableRow(schema.length) - } else { - // It is safe to reuse the mutable row. - val mutableRow = new SpecificMutableRow(schema) - () => mutableRow - } + // It is safe to reuse the mutable row. + val mutableRow = new SpecificMutableRow(schema) + () => mutableRow } } @@ -133,8 +128,7 @@ private[sql] class Serializer2DeserializationStream( private[sql] class SparkSqlSerializer2Instance( keySchema: Array[DataType], - valueSchema: Array[DataType], - hasKeyOrdering: Boolean) + valueSchema: Array[DataType]) extends SerializerInstance { def serialize[T: ClassTag](t: T): ByteBuffer = @@ -151,7 +145,7 @@ private[sql] class SparkSqlSerializer2Instance( } def deserializeStream(s: InputStream): DeserializationStream = { - new Serializer2DeserializationStream(keySchema, valueSchema, hasKeyOrdering, s) + new Serializer2DeserializationStream(keySchema, valueSchema, s) } } @@ -164,14 +158,13 @@ private[sql] class SparkSqlSerializer2Instance( */ private[sql] class SparkSqlSerializer2( keySchema: Array[DataType], - valueSchema: Array[DataType], - hasKeyOrdering: Boolean) + valueSchema: Array[DataType]) extends Serializer with Logging with Serializable{ def newInstance(): SerializerInstance = - new SparkSqlSerializer2Instance(keySchema, valueSchema, hasKeyOrdering) + new SparkSqlSerializer2Instance(keySchema, valueSchema) override def supportsRelocationOfSerializedObjects: Boolean = { // SparkSqlSerializer2 is stateless and writes no stream headers @@ -244,7 +237,7 @@ private[sql] object SparkSqlSerializer2 { out.writeShort(row.getShort(i)) } - case IntegerType => + case IntegerType | DateType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { @@ -252,7 +245,7 @@ private[sql] object SparkSqlSerializer2 { out.writeInt(row.getInt(i)) } - case LongType => + case LongType | TimestampType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { @@ -276,59 +269,39 @@ private[sql] object SparkSqlSerializer2 { out.writeDouble(row.getDouble(i)) } - case decimal: DecimalType => + case StringType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val value = row.apply(i).asInstanceOf[Decimal] - val javaBigDecimal = value.toJavaBigDecimal - // First, write out the unscaled value. - val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray + val bytes = row.getAs[UTF8String](i).getBytes out.writeInt(bytes.length) out.write(bytes) - // Then, write out the scale. - out.writeInt(javaBigDecimal.scale()) } - case DateType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeInt(row.getAs[Int](i)) - } - - case TimestampType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - val timestamp = row.getAs[java.sql.Timestamp](i) - val time = timestamp.getTime - val nanos = timestamp.getNanos - out.writeLong(time - (nanos / 1000000)) // Write the milliseconds value. - out.writeInt(nanos) // Write the nanoseconds part. - } - - case StringType => + case BinaryType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val bytes = row.getAs[UTF8String](i).getBytes + val bytes = row.getAs[Array[Byte]](i) out.writeInt(bytes.length) out.write(bytes) } - case BinaryType => + case decimal: DecimalType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val bytes = row.getAs[Array[Byte]](i) + val value = row.apply(i).asInstanceOf[Decimal] + val javaBigDecimal = value.toJavaBigDecimal + // First, write out the unscaled value. + val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray out.writeInt(bytes.length) out.write(bytes) + // Then, write out the scale. + out.writeInt(javaBigDecimal.scale()) } } i += 1 @@ -341,7 +314,7 @@ private[sql] object SparkSqlSerializer2 { */ def createDeserializationFunction( schema: Array[DataType], - in: DataInputStream): (MutableRow) => Row = { + in: DataInputStream): (MutableRow) => InternalRow = { if (schema == null) { (mutableRow: MutableRow) => null } else { @@ -375,14 +348,14 @@ private[sql] object SparkSqlSerializer2 { mutableRow.setShort(i, in.readShort()) } - case IntegerType => + case IntegerType | DateType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { mutableRow.setInt(i, in.readInt()) } - case LongType => + case LongType | TimestampType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { @@ -403,57 +376,39 @@ private[sql] object SparkSqlSerializer2 { mutableRow.setDouble(i, in.readDouble()) } - case decimal: DecimalType => + case StringType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { - // First, read in the unscaled value. val length = in.readInt() val bytes = new Array[Byte](length) in.readFully(bytes) - val unscaledVal = new BigInteger(bytes) - // Then, read the scale. - val scale = in.readInt() - // Finally, create the Decimal object and set it in the row. - mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale))) - } - - case DateType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.update(i, in.readInt()) + mutableRow.update(i, UTF8String.fromBytes(bytes)) } - case TimestampType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - val time = in.readLong() // Read the milliseconds value. - val nanos = in.readInt() // Read the nanoseconds part. - val timestamp = new Timestamp(time) - timestamp.setNanos(nanos) - mutableRow.update(i, timestamp) - } - - case StringType => + case BinaryType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { val length = in.readInt() val bytes = new Array[Byte](length) in.readFully(bytes) - mutableRow.update(i, UTF8String(bytes)) + mutableRow.update(i, bytes) } - case BinaryType => + case decimal: DecimalType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { + // First, read in the unscaled value. val length = in.readInt() val bytes = new Array[Byte](length) in.readFully(bytes) - mutableRow.update(i, bytes) + val unscaledVal = new BigInteger(bytes) + // Then, read the scale. + val scale = in.readInt() + // Finally, create the Decimal object and set it in the row. + mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale))) } } i += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7a1331a39151a..5daf86d817586 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} @@ -52,6 +52,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Matches a plan whose output should be small enough to be used in broadcast join. + */ + object CanBroadcast { + def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match { + case BroadcastHint(p) => Some(p) + case p if sqlContext.conf.autoBroadcastJoinThreshold > 0 && + p.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => Some(p) + case _ => None + } + } + /** * Uses the ExtractEquiJoinKeys pattern to find joins where at least some of the predicates can be * evaluated by matching hash keys. @@ -80,15 +92,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if sqlContext.conf.autoBroadcastJoinThreshold > 0 && - right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight) - case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if sqlContext.conf.autoBroadcastJoinThreshold > 0 && - left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => - makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) => + makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) // If the sort merge join option is set, we want to use sort merge join prior to hashjoin // for now let's support inner join first, then add outer join @@ -202,13 +210,16 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - protected lazy val singleRowRdd = - sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1) + protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1) - object TakeOrdered extends Strategy { + object TakeOrderedAndProject extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => - execution.TakeOrdered(limit, order, planLater(child)) :: Nil + execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil + case logical.Limit( + IntegerLiteral(limit), + logical.Project(projectList, logical.Sort(order, true, child))) => + execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil case _ => Nil } } @@ -300,8 +311,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Project(projectList, planLater(child)) :: Nil case logical.Filter(condition, child) => execution.Filter(condition, planLater(child)) :: Nil - case logical.Expand(projections, output, child) => - execution.Expand(projections, output, planLater(child)) :: Nil + case e @ logical.Expand(_, _, _, child) => + execution.Expand(e.projections, e.output, planLater(child)) :: Nil case logical.Aggregate(group, agg, child) => execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil case logical.Window(projectList, windowExpressions, spec, child) => @@ -329,6 +340,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil + case BroadcastHint(child) => apply(child) case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index c4327ce262ac5..fd6f1d7ae1255 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -20,9 +20,8 @@ package org.apache.spark.sql.execution import java.util import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, ClusteredDistribution, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.util.collection.CompactBuffer /** @@ -112,16 +111,16 @@ case class Window( } } - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { child.execute().mapPartitions { iter => - new Iterator[Row] { + new Iterator[InternalRow] { // Although input rows are grouped based on windowSpec.partitionSpec, we need to // know when we have a new partition. // This is to manually construct an ordering that can be used to compare rows. // TODO: We may want to have a newOrdering that takes BoundReferences. // So, we can take advantave of code gen. - private val partitionOrdering: Ordering[Row] = + private val partitionOrdering: Ordering[InternalRow] = RowOrdering.forSchema(windowSpec.partitionSpec.map(_.dataType)) // This is used to project expressions for the partition specification. @@ -137,13 +136,13 @@ case class Window( // The number of buffered rows in the inputRowBuffer (the size of the current partition). var partitionSize: Int = 0 // The buffer used to buffer rows in a partition. - var inputRowBuffer: CompactBuffer[Row] = _ + var inputRowBuffer: CompactBuffer[InternalRow] = _ // The partition key of the current partition. - var currentPartitionKey: Row = _ + var currentPartitionKey: InternalRow = _ // The partition key of next partition. - var nextPartitionKey: Row = _ + var nextPartitionKey: InternalRow = _ // The first row of next partition. - var firstRowInNextPartition: Row = _ + var firstRowInNextPartition: InternalRow = _ // Indicates if this partition is the last one in the iter. var lastPartition: Boolean = false @@ -316,7 +315,7 @@ case class Window( !lastPartition || (rowPosition < partitionSize) } - override final def next(): Row = { + override final def next(): InternalRow = { if (hasNext) { if (rowPosition == partitionSize) { // All rows of this buffer have been consumed. @@ -353,7 +352,7 @@ case class Window( // Fetch the next partition. private def fetchNextPartition(): Unit = { // Create a new buffer for input rows. - inputRowBuffer = new CompactBuffer[Row]() + inputRowBuffer = new CompactBuffer[InternalRow]() // We already have the first row for this partition // (recorded in firstRowInNextPartition). Add it back. inputRowBuffer += firstRowInNextPartition diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index fb42072f9d5a7..647c4ab5cb651 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -17,16 +17,17 @@ package org.apache.spark.sql.execution -import org.apache.spark.{SparkEnv, HashPartitioner, SparkConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.util.{CompletionIterator, MutablePair} import org.apache.spark.util.collection.ExternalSorter +import org.apache.spark.util.{CompletionIterator, MutablePair} +import org.apache.spark.{HashPartitioner, SparkEnv} /** * :: DeveloperApi :: @@ -37,9 +38,9 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends @transient lazy val buildProjection = newMutableProjection(projectList, child.output) - protected override def doExecute(): RDD[Row] = child.execute().mapPartitions { iter => - val resuableProjection = buildProjection() - iter.map(resuableProjection) + protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => + val reusableProjection = buildProjection() + iter.map(reusableProjection) } override def outputOrdering: Seq[SortOrder] = child.outputOrdering @@ -52,9 +53,10 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - @transient lazy val conditionEvaluator: (Row) => Boolean = newPredicate(condition, child.output) + @transient lazy val conditionEvaluator: (InternalRow) => Boolean = + newPredicate(condition, child.output) - protected override def doExecute(): RDD[Row] = child.execute().mapPartitions { iter => + protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => iter.filter(conditionEvaluator) } @@ -83,7 +85,7 @@ case class Sample( override def output: Seq[Attribute] = child.output // TODO: How to pick seed? - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { if (withReplacement) { child.execute().map(_.copy()).sample(withReplacement, upperBound - lowerBound, seed) } else { @@ -99,7 +101,8 @@ case class Sample( case class Union(children: Seq[SparkPlan]) extends SparkPlan { // TODO: attributes output by union should be distinct for nullability purposes override def output: Seq[Attribute] = children.head.output - protected override def doExecute(): RDD[Row] = sparkContext.union(children.map(_.execute())) + protected override def doExecute(): RDD[InternalRow] = + sparkContext.union(children.map(_.execute())) } /** @@ -124,19 +127,19 @@ case class Limit(limit: Int, child: SparkPlan) override def executeCollect(): Array[Row] = child.executeTake(limit) - protected override def doExecute(): RDD[Row] = { - val rdd: RDD[_ <: Product2[Boolean, Row]] = if (sortBasedShuffleOn) { + protected override def doExecute(): RDD[InternalRow] = { + val rdd: RDD[_ <: Product2[Boolean, InternalRow]] = if (sortBasedShuffleOn) { child.execute().mapPartitions { iter => iter.take(limit).map(row => (false, row.copy())) } } else { child.execute().mapPartitions { iter => - val mutablePair = new MutablePair[Boolean, Row]() + val mutablePair = new MutablePair[Boolean, InternalRow]() iter.take(limit).map(row => mutablePair.update(false, row)) } } val part = new HashPartitioner(1) - val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part) + val shuffled = new ShuffledRDD[Boolean, InternalRow, InternalRow](rdd, part) shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf)) shuffled.mapPartitions(_.take(limit).map(_._2)) } @@ -144,12 +147,18 @@ case class Limit(limit: Int, child: SparkPlan) /** * :: DeveloperApi :: - * Take the first limit elements as defined by the sortOrder. This is logically equivalent to - * having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but - * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion. + * Take the first limit elements as defined by the sortOrder, and do projection if needed. + * This is logically equivalent to having a [[Limit]] operator after a [[Sort]] operator, + * or having a [[Project]] operator between them. + * This could have been named TopK, but Spark's top operator does the opposite in ordering + * so we name it TakeOrdered to avoid confusion. */ @DeveloperApi -case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode { +case class TakeOrderedAndProject( + limit: Int, + sortOrder: Seq[SortOrder], + projectList: Option[Seq[NamedExpression]], + child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output @@ -157,7 +166,13 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) private val ord: RowOrdering = new RowOrdering(sortOrder, child.output) - private def collectData(): Array[Row] = child.execute().map(_.copy()).takeOrdered(limit)(ord) + // TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable. + @transient private val projection = projectList.map(new InterpretedProjection(_, child.output)) + + private def collectData(): Array[InternalRow] = { + val data = child.execute().map(_.copy()).takeOrdered(limit)(ord) + projection.map(data.map(_)).getOrElse(data) + } override def executeCollect(): Array[Row] = { val converter = CatalystTypeConverters.createToScalaConverter(schema) @@ -166,7 +181,7 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. - protected override def doExecute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1) + protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1) override def outputOrdering: Seq[SortOrder] = sortOrder } @@ -186,7 +201,7 @@ case class Sort( override def requiredChildDistribution: Seq[Distribution] = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - protected override def doExecute(): RDD[Row] = attachTree(this, "sort") { + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { child.execute().mapPartitions( { iterator => val ordering = newOrdering(sortOrder, child.output) iterator.map(_.copy()).toArray.sorted(ordering).iterator @@ -214,14 +229,14 @@ case class ExternalSort( override def requiredChildDistribution: Seq[Distribution] = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - protected override def doExecute(): RDD[Row] = attachTree(this, "sort") { + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { child.execute().mapPartitions( { iterator => val ordering = newOrdering(sortOrder, child.output) - val sorter = new ExternalSorter[Row, Null, Row](ordering = Some(ordering)) + val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering)) sorter.insertAll(iterator.map(r => (r.copy, null))) val baseIterator = sorter.iterator.map(_._1) // TODO(marmbrus): The complex type signature below thwarts inference for no reason. - CompletionIterator[Row, Iterator[Row]](baseIterator, sorter.stop()) + CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop()) }, preservesPartitioning = true) } @@ -239,7 +254,7 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { child.execute().map(_.copy()).coalesce(numPartitions, shuffle) } } @@ -254,7 +269,7 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: SparkPlan) case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { left.execute().map(_.copy()).subtract(right.execute().map(_.copy())) } } @@ -268,7 +283,7 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = children.head.output - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { left.execute().map(_.copy()).intersection(right.execute().map(_.copy())) } } @@ -283,5 +298,5 @@ case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPlan { def children: Seq[SparkPlan] = child :: Nil - protected override def doExecute(): RDD[Row] = child.execute() + protected override def doExecute(): RDD[InternalRow] = child.execute() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 49b361e96b2d6..5e9951f248ff2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql.execution +import java.util.NoSuchElementException + import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} +import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext} /** * A logical command that is executed for its side-effects. `RunnableCommand`s are @@ -64,9 +66,9 @@ private[sql] case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan override def executeTake(limit: Int): Array[Row] = sideEffectResult.take(limit).toArray - protected override def doExecute(): RDD[Row] = { - val converted = sideEffectResult.map(r => - CatalystTypeConverters.convertToCatalyst(r, schema).asInstanceOf[Row]) + protected override def doExecute(): RDD[InternalRow] = { + val convert = CatalystTypeConverters.createToCatalystConverter(schema) + val converted = sideEffectResult.map(convert(_).asInstanceOf[InternalRow]) sqlContext.sparkContext.parallelize(converted, 1) } } @@ -75,48 +77,92 @@ private[sql] case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan * :: DeveloperApi :: */ @DeveloperApi -case class SetCommand( - kv: Option[(String, Option[String])], - override val output: Seq[Attribute]) - extends RunnableCommand with Logging { +case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableCommand with Logging { + + private def keyValueOutput: Seq[Attribute] = { + val schema = StructType( + StructField("key", StringType, false) :: + StructField("value", StringType, false) :: Nil) + schema.toAttributes + } - override def run(sqlContext: SQLContext): Seq[Row] = kv match { + private val (_output, runFunc): (Seq[Attribute], SQLContext => Seq[Row]) = kv match { // Configures the deprecated "mapred.reduce.tasks" property. case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, Some(value))) => - logWarning( - s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + - s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.") - if (value.toInt < 1) { - val msg = s"Setting negative ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} for automatically " + - "determining the number of reducers is not supported." - throw new IllegalArgumentException(msg) - } else { - sqlContext.setConf(SQLConf.SHUFFLE_PARTITIONS, value) - Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$value")) + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS.key} instead.") + if (value.toInt < 1) { + val msg = + s"Setting negative ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} for automatically " + + "determining the number of reducers is not supported." + throw new IllegalArgumentException(msg) + } else { + sqlContext.setConf(SQLConf.SHUFFLE_PARTITIONS.key, value) + Seq(Row(SQLConf.SHUFFLE_PARTITIONS.key, value)) + } } + (keyValueOutput, runFunc) // Configures a single property. case Some((key, Some(value))) => - sqlContext.setConf(key, value) - Seq(Row(s"$key=$value")) + val runFunc = (sqlContext: SQLContext) => { + sqlContext.setConf(key, value) + Seq(Row(key, value)) + } + (keyValueOutput, runFunc) - // Queries all key-value pairs that are set in the SQLConf of the sqlContext. - // Notice that different from Hive, here "SET -v" is an alias of "SET". // (In Hive, "SET" returns all changed properties while "SET -v" returns all properties.) - case Some(("-v", None)) | None => - sqlContext.getAllConfs.map { case (k, v) => Row(s"$k=$v") }.toSeq + // Queries all key-value pairs that are set in the SQLConf of the sqlContext. + case None => + val runFunc = (sqlContext: SQLContext) => { + sqlContext.getAllConfs.map { case (k, v) => Row(k, v) }.toSeq + } + (keyValueOutput, runFunc) + + // Queries all properties along with their default values and docs that are defined in the + // SQLConf of the sqlContext. + case Some(("-v", None)) => + val runFunc = (sqlContext: SQLContext) => { + sqlContext.conf.getAllDefinedConfs.map { case (key, defaultValue, doc) => + Row(key, defaultValue, doc) + } + } + val schema = StructType( + StructField("key", StringType, false) :: + StructField("default", StringType, false) :: + StructField("meaning", StringType, false) :: Nil) + (schema.toAttributes, runFunc) // Queries the deprecated "mapred.reduce.tasks" property. case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, None)) => - logWarning( - s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + - s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.") - Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${sqlContext.conf.numShufflePartitions}")) + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + + s"showing ${SQLConf.SHUFFLE_PARTITIONS.key} instead.") + Seq(Row(SQLConf.SHUFFLE_PARTITIONS.key, sqlContext.conf.numShufflePartitions.toString)) + } + (keyValueOutput, runFunc) // Queries a single property. case Some((key, None)) => - Seq(Row(s"$key=${sqlContext.getConf(key, "")}")) + val runFunc = (sqlContext: SQLContext) => { + val value = + try { + sqlContext.getConf(key) + } catch { + case _: NoSuchElementException => "" + } + Seq(Row(key, value)) + } + (keyValueOutput, runFunc) } + + override val output: Seq[Attribute] = _output + + override def run(sqlContext: SQLContext): Seq[Row] = runFunc(sqlContext) + } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index dffb265601bdb..2964edac1aba2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -18,13 +18,15 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.unsafe.types.UTF8String import scala.collection.mutable.HashSet import org.apache.spark.{AccumulatorParam, Accumulator} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.{SQLConf, SQLContext, DataFrame, Row} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.types._ @@ -46,7 +48,7 @@ package object debug { */ implicit class DebugSQLContext(sqlContext: SQLContext) { def debug(): Unit = { - sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false") + sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false) } } @@ -125,11 +127,11 @@ package object debug { } } - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { child.execute().mapPartitions { iter => - new Iterator[Row] { + new Iterator[InternalRow] { def hasNext: Boolean = iter.hasNext - def next(): Row = { + def next(): InternalRow = { val currentRow = iter.next() tupleCount += 1 var i = 0 @@ -154,7 +156,7 @@ package object debug { def typeCheck(data: Any, schema: DataType): Unit = (data, schema) match { case (null, _) => - case (row: Row, StructType(fields)) => + case (row: InternalRow, StructType(fields)) => row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } case (s: Seq[_], ArrayType(elemType, _)) => s.foreach(typeCheck(_, elemType)) @@ -170,6 +172,8 @@ package object debug { case (_: Short, ShortType) => case (_: Boolean, BooleanType) => case (_: Double, DoubleType) => + case (_: Int, DateType) => + case (_: Long, TimestampType) => case (v, udt: UserDefinedType[_]) => typeCheck(v, udt.sqlType) case (d, t) => sys.error(s"Invalid data found: got $d (${d.getClass}) expected $t") @@ -193,7 +197,7 @@ package object debug { def children: List[SparkPlan] = child :: Nil - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { child.execute().map { row => try typeCheck(row, child.schema) catch { case e: Exception => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala index e228a60c9029f..3b217348b7b7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.expressions import org.apache.spark.TaskContext -import org.apache.spark.sql.catalyst.expressions.{Row, LeafExpression} +import org.apache.spark.sql.catalyst.expressions.{InternalRow, LeafExpression} import org.apache.spark.sql.types.{LongType, DataType} /** @@ -43,9 +43,11 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { override def dataType: DataType = LongType - override def eval(input: Row): Long = { + override def eval(input: InternalRow): Long = { val currentCount = count count += 1 (TaskContext.get().partitionId().toLong << 33) + currentCount } + + override def isThreadSafe: Boolean = false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index 1272793f88cd0..12c2eed0d6b7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.expressions import org.apache.spark.TaskContext -import org.apache.spark.sql.catalyst.expressions.{LeafExpression, Row} +import org.apache.spark.sql.catalyst.expressions.{LeafExpression, InternalRow} import org.apache.spark.sql.types.{IntegerType, DataType} @@ -31,5 +31,5 @@ private[sql] case object SparkPartitionID extends LeafExpression { override def dataType: DataType = IntegerType - override def eval(input: Row): Int = TaskContext.get().partitionId() + override def eval(input: InternalRow): Int = TaskContext.get().partitionId() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index b8b12be8756f9..2d2e1b92b86be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -17,16 +17,15 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.rdd.RDD -import org.apache.spark.util.ThreadUtils - import scala.concurrent._ import scala.concurrent.duration._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.{Row, Expression} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{Expression, InternalRow} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.ThreadUtils /** * :: DeveloperApi :: @@ -61,12 +60,12 @@ case class BroadcastHashJoin( @transient private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types - val input: Array[Row] = buildPlan.execute().map(_.copy()).collect() + val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length) sparkContext.broadcast(hashed) }(BroadcastHashJoin.broadcastHashJoinExecutionContext) - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { val broadcastRelation = Await.result(broadcastFuture, timeout) streamedPlan.execute().mapPartitions { streamedIter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index a32e5fc4f7ea4..412a3d4178e12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} /** @@ -38,10 +38,10 @@ case class BroadcastLeftSemiJoinHash( override def output: Seq[Attribute] = left.output - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { val buildIter = buildPlan.execute().map(_.copy()).collect().toIterator - val hashSet = new java.util.HashSet[Row]() - var currentRow: Row = null + val hashSet = new java.util.HashSet[InternalRow]() + var currentRow: InternalRow = null // Create a Hash set of buildKeys while (buildIter.hasNext) { @@ -50,7 +50,8 @@ case class BroadcastLeftSemiJoinHash( if (!rowKey.anyNull) { val keyExists = hashSet.contains(rowKey) if (!keyExists) { - hashSet.add(rowKey) + // rowKey may be not serializable (from codegen) + hashSet.add(rowKey.copy()) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index caad3dfbe1c5e..0b2cf8e12a6c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -61,13 +61,14 @@ case class BroadcastNestedLoopJoin( @transient private lazy val boundCondition = newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map(_.copy()) + .collect().toIndexedSeq) /** All rows that either match both-way, or rows from streamed joined with nulls. */ val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => - val matchedRows = new CompactBuffer[Row] + val matchedRows = new CompactBuffer[InternalRow] // TODO: Use Spark's BitSet. val includedBroadcastTuples = new scala.collection.mutable.BitSet(broadcastedRelation.value.size) @@ -118,8 +119,8 @@ case class BroadcastNestedLoopJoin( val leftNulls = new GenericMutableRow(left.output.size) val rightNulls = new GenericMutableRow(right.output.size) /** Rows from broadcasted joined with nulls. */ - val broadcastRowsWithNulls: Seq[Row] = { - val buf: CompactBuffer[Row] = new CompactBuffer() + val broadcastRowsWithNulls: Seq[InternalRow] = { + val buf: CompactBuffer[InternalRow] = new CompactBuffer() var i = 0 val rel = broadcastedRelation.value while (i < rel.length) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index 191c00cb55da2..261b4724159fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { val leftResults = left.execute().map(_.copy()) val rightResults = right.execute().map(_.copy()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 851de1685509a..3a4196a90d14a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -49,11 +49,13 @@ trait HashJoin { @transient protected lazy val streamSideKeyGenerator: () => MutableProjection = newMutableProjection(streamedKeys, streamedPlan.output) - protected def hashJoin(streamIter: Iterator[Row], hashedRelation: HashedRelation): Iterator[Row] = + protected def hashJoin( + streamIter: Iterator[InternalRow], + hashedRelation: HashedRelation): Iterator[InternalRow] = { - new Iterator[Row] { - private[this] var currentStreamedRow: Row = _ - private[this] var currentHashMatches: CompactBuffer[Row] = _ + new Iterator[InternalRow] { + private[this] var currentStreamedRow: InternalRow = _ + private[this] var currentHashMatches: CompactBuffer[InternalRow] = _ private[this] var currentMatchPosition: Int = -1 // Mutable per row objects. @@ -65,7 +67,7 @@ trait HashJoin { (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || (streamIter.hasNext && fetchNext()) - override final def next(): Row = { + override final def next(): InternalRow = { val ret = buildSide match { case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 45574392996ca..e41538ec1fc1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -48,7 +48,8 @@ case class HashOuterJoin( case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) - case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + case x => + throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") } override def requiredChildDistribution: Seq[ClusteredDistribution] = @@ -63,24 +64,27 @@ case class HashOuterJoin( case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) case x => - throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") } } - @transient private[this] lazy val DUMMY_LIST = Seq[Row](null) - @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row] + @transient private[this] lazy val DUMMY_LIST = Seq[InternalRow](null) + @transient private[this] lazy val EMPTY_LIST = Seq.empty[InternalRow] - @transient private[this] lazy val leftNullRow = new GenericRow(left.output.length) - @transient private[this] lazy val rightNullRow = new GenericRow(right.output.length) + @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) + @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) @transient private[this] lazy val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + condition.map( + newPredicate(_, left.output ++ right.output)).getOrElse((row: InternalRow) => true) // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala // iterator for performance purpose. private[this] def leftOuterIterator( - key: Row, joinedRow: JoinedRow, rightIter: Iterable[Row]): Iterator[Row] = { - val ret: Iterable[Row] = { + key: InternalRow, + joinedRow: JoinedRow, + rightIter: Iterable[InternalRow]): Iterator[InternalRow] = { + val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = rightIter.collect { case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy() @@ -98,12 +102,15 @@ case class HashOuterJoin( } private[this] def rightOuterIterator( - key: Row, leftIter: Iterable[Row], joinedRow: JoinedRow): Iterator[Row] = { + key: InternalRow, + leftIter: Iterable[InternalRow], + joinedRow: JoinedRow): Iterator[InternalRow] = { - val ret: Iterable[Row] = { + val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = leftIter.collect { - case l if boundCondition(joinedRow.withLeft(l)) => joinedRow.copy + case l if boundCondition(joinedRow.withLeft(l)) => + joinedRow.copy } if (temp.size == 0) { joinedRow.withLeft(leftNullRow).copy :: Nil @@ -118,14 +125,14 @@ case class HashOuterJoin( } private[this] def fullOuterIterator( - key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row], - joinedRow: JoinedRow): Iterator[Row] = { + key: InternalRow, leftIter: Iterable[InternalRow], rightIter: Iterable[InternalRow], + joinedRow: JoinedRow): Iterator[InternalRow] = { if (!key.anyNull) { // Store the positions of records in right, if one of its associated row satisfy // the join condition. val rightMatchedSet = scala.collection.mutable.Set[Int]() - leftIter.iterator.flatMap[Row] { l => + leftIter.iterator.flatMap[InternalRow] { l => joinedRow.withLeft(l) var matched = false rightIter.zipWithIndex.collect { @@ -156,24 +163,25 @@ case class HashOuterJoin( joinedRow(leftNullRow, r).copy() } } else { - leftIter.iterator.map[Row] { l => + leftIter.iterator.map[InternalRow] { l => joinedRow(l, rightNullRow).copy() - } ++ rightIter.iterator.map[Row] { r => + } ++ rightIter.iterator.map[InternalRow] { r => joinedRow(leftNullRow, r).copy() } } } private[this] def buildHashTable( - iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = { - val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]() + iter: Iterator[InternalRow], + keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = { + val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]]() while (iter.hasNext) { val currentRow = iter.next() val rowKey = keyGenerator(currentRow) var existingMatchList = hashTable.get(rowKey) if (existingMatchList == null) { - existingMatchList = new CompactBuffer[Row]() + existingMatchList = new CompactBuffer[InternalRow]() hashTable.put(rowKey, existingMatchList) } @@ -183,7 +191,7 @@ case class HashOuterJoin( hashTable } - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { val joinedRow = new JoinedRow() left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => // TODO this probably can be replaced by external sort (sort merged join?) @@ -216,7 +224,8 @@ case class HashOuterJoin( rightHashTable.getOrElse(key, EMPTY_LIST), joinedRow) } - case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + case x => + throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index ab84c123e0c0b..e18c817975134 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins import java.io.{ObjectInput, ObjectOutput, Externalizable} import java.util.{HashMap => JavaHashMap} -import org.apache.spark.sql.catalyst.expressions.{Projection, Row} +import org.apache.spark.sql.catalyst.expressions.{Projection, InternalRow} import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.util.collection.CompactBuffer @@ -30,7 +30,7 @@ import org.apache.spark.util.collection.CompactBuffer * object. */ private[joins] sealed trait HashedRelation { - def get(key: Row): CompactBuffer[Row] + def get(key: InternalRow): CompactBuffer[InternalRow] // This is a helper method to implement Externalizable, and is used by // GeneralHashedRelation and UniqueKeyHashedRelation @@ -54,12 +54,12 @@ private[joins] sealed trait HashedRelation { * A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values. */ private[joins] final class GeneralHashedRelation( - private var hashTable: JavaHashMap[Row, CompactBuffer[Row]]) + private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]]) extends HashedRelation with Externalizable { def this() = this(null) // Needed for serialization - override def get(key: Row): CompactBuffer[Row] = hashTable.get(key) + override def get(key: InternalRow): CompactBuffer[InternalRow] = hashTable.get(key) override def writeExternal(out: ObjectOutput): Unit = { writeBytes(out, SparkSqlSerializer.serialize(hashTable)) @@ -75,17 +75,18 @@ private[joins] final class GeneralHashedRelation( * A specialized [[HashedRelation]] that maps key into a single value. This implementation * assumes the key is unique. */ -private[joins] final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[Row, Row]) +private[joins] +final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow]) extends HashedRelation with Externalizable { def this() = this(null) // Needed for serialization - override def get(key: Row): CompactBuffer[Row] = { + override def get(key: InternalRow): CompactBuffer[InternalRow] = { val v = hashTable.get(key) if (v eq null) null else CompactBuffer(v) } - def getValue(key: Row): Row = hashTable.get(key) + def getValue(key: InternalRow): InternalRow = hashTable.get(key) override def writeExternal(out: ObjectOutput): Unit = { writeBytes(out, SparkSqlSerializer.serialize(hashTable)) @@ -103,13 +104,13 @@ private[joins] final class UniqueKeyHashedRelation(private var hashTable: JavaHa private[joins] object HashedRelation { def apply( - input: Iterator[Row], + input: Iterator[InternalRow], keyGenerator: Projection, sizeEstimate: Int = 64): HashedRelation = { // TODO: Use Spark's HashMap implementation. - val hashTable = new JavaHashMap[Row, CompactBuffer[Row]](sizeEstimate) - var currentRow: Row = null + val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]](sizeEstimate) + var currentRow: InternalRow = null // Whether the join key is unique. If the key is unique, we can convert the underlying // hash map into one specialized for this. @@ -122,7 +123,7 @@ private[joins] object HashedRelation { if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) val matchList = if (existingMatchList == null) { - val newMatchList = new CompactBuffer[Row]() + val newMatchList = new CompactBuffer[InternalRow]() hashTable.put(rowKey, newMatchList) newMatchList } else { @@ -134,7 +135,7 @@ private[joins] object HashedRelation { } if (keyIsUnique) { - val uniqHashTable = new JavaHashMap[Row, Row](hashTable.size) + val uniqHashTable = new JavaHashMap[InternalRow, InternalRow](hashTable.size) val iter = hashTable.entrySet().iterator() while (iter.hasNext) { val entry = iter.next() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index 036423e6faea4..2a6d4d1ab08bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -47,7 +47,7 @@ case class LeftSemiJoinBNL( @transient private lazy val boundCondition = newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { val broadcastedRelation = sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 8ad27eae80ffb..20d74270afb48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, InternalRow} import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -42,10 +42,10 @@ case class LeftSemiJoinHash( override def output: Seq[Attribute] = left.output - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashSet = new java.util.HashSet[Row]() - var currentRow: Row = null + val hashSet = new java.util.HashSet[InternalRow]() + var currentRow: InternalRow = null // Create a Hash set of buildKeys while (buildIter.hasNext) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 219525d9d85f3..5439e10a60b2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -43,7 +43,7 @@ case class ShuffledHashJoin( override def requiredChildDistribution: Seq[ClusteredDistribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => val hashed = HashedRelation(buildIter, buildSideKeyGenerator) hashJoin(streamIter, hashed) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 1a39fb4b96608..2abe65a71813d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -21,9 +21,7 @@ import java.util.NoSuchElementException import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.util.collection.CompactBuffer @@ -60,29 +58,29 @@ case class SortMergeJoin( private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = keys.map(SortOrder(_, Ascending)) - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { val leftResults = left.execute().map(_.copy()) val rightResults = right.execute().map(_.copy()) leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => - new Iterator[Row] { + new Iterator[InternalRow] { // Mutable per row objects. private[this] val joinRow = new JoinedRow5 - private[this] var leftElement: Row = _ - private[this] var rightElement: Row = _ - private[this] var leftKey: Row = _ - private[this] var rightKey: Row = _ - private[this] var rightMatches: CompactBuffer[Row] = _ + private[this] var leftElement: InternalRow = _ + private[this] var rightElement: InternalRow = _ + private[this] var leftKey: InternalRow = _ + private[this] var rightKey: InternalRow = _ + private[this] var rightMatches: CompactBuffer[InternalRow] = _ private[this] var rightPosition: Int = -1 private[this] var stop: Boolean = false - private[this] var matchKey: Row = _ + private[this] var matchKey: InternalRow = _ // initialize iterator initialize() override final def hasNext: Boolean = nextMatchingPair() - override final def next(): Row = { + override final def next(): InternalRow = { if (hasNext) { // we are using the buffered right rows and run down left iterator val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) @@ -145,7 +143,7 @@ case class SortMergeJoin( fetchLeft() } } - rightMatches = new CompactBuffer[Row]() + rightMatches = new CompactBuffer[InternalRow]() if (stop) { stop = false // iterate the right side to buffer all rows that matches diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala similarity index 74% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 55f3ff4709013..9e1cff06c7eea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -24,18 +24,19 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle.{Pickler, Unpickler} +import org.apache.spark.{Accumulator, Logging => SparkLogging} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.{PythonBroadcast, PythonRDD} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -import org.apache.spark.{Accumulator, Logging => SparkLogging} +import org.apache.spark.unsafe.types.UTF8String /** * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. @@ -54,10 +55,10 @@ private[spark] case class PythonUDF( override def toString: String = s"PythonUDF#$name(${children.mkString(",")})" - def nullable: Boolean = true + override def nullable: Boolean = true - override def eval(input: Row): Any = { - sys.error("PythonUDFs can not be directly evaluated.") + override def eval(input: InternalRow): Any = { + throw new UnsupportedOperationException("PythonUDFs can not be directly evaluated.") } } @@ -68,46 +69,52 @@ private[spark] case class PythonUDF( * This has the limitation that the input to the Python UDF is not allowed include attributes from * multiple child operators. */ -private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { +private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Skip EvaluatePython nodes. - case p: EvaluatePython => p + case plan: EvaluatePython => plan - case l: LogicalPlan => + case plan: LogicalPlan if plan.resolved => // Extract any PythonUDFs from the current operator. - val udfs = l.expressions.flatMap(_.collect { case udf: PythonUDF => udf}) + val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf }) if (udfs.isEmpty) { // If there aren't any, we are done. - l + plan } else { // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time) // If there is more than one, we will add another evaluation operator in a subsequent pass. - val udf = udfs.head - - var evaluation: EvaluatePython = null - - // Rewrite the child that has the input required for the UDF - val newChildren = l.children.map { child => - // Check to make sure that the UDF can be evaluated with only the input of this child. - // Other cases are disallowed as they are ambiguous or would require a cartisian product. - if (udf.references.subsetOf(child.outputSet)) { - evaluation = EvaluatePython(udf, child) - evaluation - } else if (udf.references.intersect(child.outputSet).nonEmpty) { - sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") - } else { - child - } + udfs.find(_.resolved) match { + case Some(udf) => + var evaluation: EvaluatePython = null + + // Rewrite the child that has the input required for the UDF + val newChildren = plan.children.map { child => + // Check to make sure that the UDF can be evaluated with only the input of this child. + // Other cases are disallowed as they are ambiguous or would require a cartesian + // product. + if (udf.references.subsetOf(child.outputSet)) { + evaluation = EvaluatePython(udf, child) + evaluation + } else if (udf.references.intersect(child.outputSet).nonEmpty) { + sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") + } else { + child + } + } + + assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.") + + // Trim away the new UDF value if it was only used for filtering or something. + logical.Project( + plan.output, + plan.transformExpressions { + case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute + }.withNewChildren(newChildren)) + + case None => + // If there is no Python UDF that is resolved, skip this round. + plan } - - assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.") - - // Trim away the new UDF value if it was only used for filtering or something. - logical.Project( - l.output, - l.transformExpressions { - case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute - }.withNewChildren(newChildren)) } } } @@ -141,7 +148,8 @@ object EvaluatePython { case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType) - case (date: Int, DateType) => DateUtils.toJavaDate(date) + case (date: Int, DateType) => DateTimeUtils.toJavaDate(date) + case (t: Long, TimestampType) => DateTimeUtils.toJavaTimestamp(t) case (s: UTF8String, StringType) => s.toString // Pyrolite can handle Timestamp and Decimal @@ -175,15 +183,17 @@ object EvaluatePython { }.toMap case (c, StructType(fields)) if c.getClass.isArray => - new GenericRow(c.asInstanceOf[Array[_]].zip(fields).map { + new GenericInternalRow(c.asInstanceOf[Array[_]].zip(fields).map { case (e, f) => fromJava(e, f.dataType) - }): Row + }) case (c: java.util.Calendar, DateType) => - DateUtils.fromJavaDate(new java.sql.Date(c.getTime().getTime())) + DateTimeUtils.fromJavaDate(new java.sql.Date(c.getTimeInMillis)) case (c: java.util.Calendar, TimestampType) => - new java.sql.Timestamp(c.getTime().getTime()) + c.getTimeInMillis * 10000L + case (t: java.sql.Timestamp, TimestampType) => + DateTimeUtils.fromJavaTimestamp(t) case (_, udt: UserDefinedType[_]) => fromJava(obj, udt.sqlType) @@ -195,8 +205,10 @@ object EvaluatePython { case (c: Long, IntegerType) => c.toInt case (c: Int, LongType) => c.toLong case (c: Double, FloatType) => c.toFloat - case (c: String, StringType) => UTF8String(c) - case (c, StringType) if !c.isInstanceOf[String] => UTF8String(c.toString) + case (c: String, StringType) => UTF8String.fromString(c) + case (c, StringType) => + // If we get here, c is not a string. Call toString on it. + UTF8String.fromString(c.toString) case (c, _) => c } @@ -230,7 +242,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: def children: Seq[SparkPlan] = child :: Nil - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { val childResults = child.execute().map(_.copy()) val parent = childResults.mapPartitions { iter => @@ -265,7 +277,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: val row = new GenericMutableRow(1) iter.map { result => row(0) = EvaluatePython.fromJava(result, udf.dataType) - row: Row + row: InternalRow } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index c41c21c0eeb50..3ebbf96090a55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -20,9 +20,10 @@ package org.apache.spark.sql.execution.stat import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.Logging -import org.apache.spark.sql.{Column, DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types.{ArrayType, StructField, StructType} +import org.apache.spark.sql.{Column, DataFrame} private[sql] object FrequentItems extends Logging { @@ -89,7 +90,7 @@ private[sql] object FrequentItems extends Logging { (name, originalSchema.fields(index).dataType) } - val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)( + val freqItems = df.select(cols.map(Column(_)) : _*).internalRowRdd.aggregate(countMaps)( seqOp = (counts, row) => { var i = 0 while (i < numCols) { @@ -110,7 +111,7 @@ private[sql] object FrequentItems extends Logging { } ) val justItems = freqItems.map(m => m.baseMap.keys.toSeq) - val resultRow = Row(justItems : _*) + val resultRow = InternalRow(justItems : _*) // append frequent Items to the column name for easy debugging val outputCols = colInfo.map { v => StructField(v._1 + "_freqItems", ArrayType(v._2, false)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 93383e5a62f11..b624ef7e8fa1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String private[sql] object StatFunctions extends Logging { @@ -81,7 +82,7 @@ private[sql] object StatFunctions extends Logging { s"with dataType ${data.get.dataType} not supported.") } val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) - df.select(columns: _*).rdd.aggregate(new CovarianceCounter)( + df.select(columns: _*).internalRowRdd.aggregate(new CovarianceCounter)( seqOp = (counter, row) => { counter.add(row.getDouble(0), row.getDouble(1)) }, @@ -110,7 +111,7 @@ private[sql] object StatFunctions extends Logging { "the pairs. Please try reducing the amount of distinct items in your columns.") } // get the distinct values of column 2, so that we can make them the column names - val distinctCol2 = counts.map(_.get(1)).distinct.zipWithIndex.toMap + val distinctCol2: Map[Any, Int] = counts.map(_.get(1)).distinct.zipWithIndex.toMap val columnSize = distinctCol2.size require(columnSize < 1e4, s"The number of distinct values for $col2, can't " + s"exceed 1e4. Currently $columnSize") @@ -119,14 +120,16 @@ private[sql] object StatFunctions extends Logging { rows.foreach { (row: Row) => // row.get(0) is column 1 // row.get(1) is column 2 - // row.get(3) is the frequency + // row.get(2) is the frequency countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2)) } // the value of col1 is the first value, the rest are the counts - countsRow.setString(0, col1Item.toString) + countsRow.update(0, UTF8String.fromString(col1Item.toString)) countsRow }.toSeq - val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq + // In the map, the column names (._1) are not ordered by the index (._2). This was the bug in + // SPARK-8681. We need to explicitly sort by the column index and assign the column names. + val headerNames = distinctCol2.toSeq.sortBy(_._2).map(r => StructField(r._1.toString, LongType)) val schema = StructType(StructField(tableName, StringType) +: headerNames) new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 454af47913bf1..4d9a019058228 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -24,6 +24,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -36,7 +37,9 @@ import org.apache.spark.util.Utils * @groupname sort_funcs Sorting functions * @groupname normal_funcs Non-aggregate functions * @groupname math_funcs Math functions + * @groupname misc_funcs Misc functions * @groupname window_funcs Window functions + * @groupname string_funcs String functions * @groupname Ungrouped Support functions for DataFrames. * @since 1.3.0 */ @@ -563,6 +566,22 @@ object functions { array((colName +: colNames).map(col) : _*) } + /** + * Marks a DataFrame as small enough for use in broadcast joins. + * + * The following example marks the right DataFrame for broadcast hash join using `joinKey`. + * {{{ + * // left and right are DataFrames + * left.join(broadcast(right), "joinKey") + * }}} + * + * @group normal_funcs + * @since 1.5.0 + */ + def broadcast(df: DataFrame): DataFrame = { + DataFrame(df.sqlContext, BroadcastHint(df.logicalPlan)) + } + /** * Returns the first column that is not null. * {{{ @@ -706,11 +725,19 @@ object functions { /** * Computes the square root of the specified float value. * - * @group normal_funcs + * @group math_funcs * @since 1.3.0 */ def sqrt(e: Column): Column = Sqrt(e.expr) + /** + * Computes the square root of the specified float value. + * + * @group math_funcs + * @since 1.5.0 + */ + def sqrt(colName: String): Column = sqrt(Column(colName)) + /** * Creates a new struct column. The input column must be a column in a [[DataFrame]], or * a derived column expression that is named (i.e. aliased). @@ -880,6 +907,24 @@ object functions { */ def atan2(l: Double, rightName: String): Column = atan2(l, Column(rightName)) + /** + * An expression that returns the string representation of the binary value of the given long + * column. For example, bin("12") returns "1100". + * + * @group math_funcs + * @since 1.5.0 + */ + def bin(e: Column): Column = Bin(e.expr) + + /** + * An expression that returns the string representation of the binary value of the given long + * column. For example, bin("12") returns "1100". + * + * @group math_funcs + * @since 1.5.0 + */ + def bin(columnName: String): Column = bin(Column(columnName)) + /** * Computes the cube-root of the given value. * @@ -944,6 +989,15 @@ object functions { */ def cosh(columnName: String): Column = cosh(Column(columnName)) + /** + * Returns the double value that is closer than any other to e, the base of the natural + * logarithms. + * + * @group math_funcs + * @since 1.5.0 + */ + def e(): Column = EulerNumber() + /** * Computes the exponential of the given value. * @@ -992,6 +1046,22 @@ object functions { */ def floor(columnName: String): Column = floor(Column(columnName)) + /** + * Computes hex value of the given column + * + * @group math_funcs + * @since 1.5.0 + */ + def hex(column: Column): Column = Hex(column.expr) + + /** + * Computes hex value of the given input + * + * @group math_funcs + * @since 1.5.0 + */ + def hex(colName: String): Column = hex(Column(colName)) + /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. * @@ -1074,7 +1144,23 @@ object functions { def log(columnName: String): Column = log(Column(columnName)) /** - * Computes the logarithm of the given value in Base 10. + * Returns the first argument-base logarithm of the second argument. + * + * @group math_funcs + * @since 1.4.0 + */ + def log(base: Double, a: Column): Column = Logarithm(lit(base).expr, a.expr) + + /** + * Returns the first argument-base logarithm of the second argument. + * + * @group math_funcs + * @since 1.4.0 + */ + def log(base: Double, columnName: String): Column = log(base, Column(columnName)) + + /** + * Computes the logarithm of the given value in base 10. * * @group math_funcs * @since 1.4.0 @@ -1082,7 +1168,7 @@ object functions { def log10(e: Column): Column = Log10(e.expr) /** - * Computes the logarithm of the given value in Base 10. + * Computes the logarithm of the given value in base 10. * * @group math_funcs * @since 1.4.0 @@ -1105,6 +1191,31 @@ object functions { */ def log1p(columnName: String): Column = log1p(Column(columnName)) + /** + * Returns the double value that is closer than any other to pi, the ratio of the circumference + * of a circle to its diameter. + * + * @group math_funcs + * @since 1.5.0 + */ + def pi(): Column = Pi() + + /** + * Computes the logarithm of the given column in base 2. + * + * @group math_funcs + * @since 1.5.0 + */ + def log2(expr: Column): Column = Log2(expr.expr) + + /** + * Computes the logarithm of the given value in base 2. + * + * @group math_funcs + * @since 1.5.0 + */ + def log2(columnName: String): Column = log2(Column(columnName)) + /** * Returns the value of the first argument raised to the power of the second argument. * @@ -1299,6 +1410,79 @@ object functions { */ def toRadians(columnName: String): Column = toRadians(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// + // Misc functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Calculates the MD5 digest and returns the value as a 32 character hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def md5(e: Column): Column = Md5(e.expr) + + /** + * Calculates the MD5 digest and returns the value as a 32 character hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def md5(columnName: String): Column = md5(Column(columnName)) + + /** + * Calculates the SHA-1 digest and returns the value as a 40 character hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha1(e: Column): Column = Sha1(e.expr) + + /** + * Calculates the SHA-1 digest and returns the value as a 40 character hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha1(columnName: String): Column = sha1(Column(columnName)) + + /** + * Calculates the SHA-2 family of hash functions and returns the value as a hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha2(e: Column, numBits: Int): Column = { + require(Seq(0, 224, 256, 384, 512).contains(numBits), + s"numBits $numBits is not in the permitted values (0, 224, 256, 384, 512)") + Sha2(e.expr, lit(numBits).expr) + } + + /** + * Calculates the SHA-2 family of hash functions and returns the value as a hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha2(columnName: String, numBits: Int): Column = sha2(Column(columnName), numBits) + + ////////////////////////////////////////////////////////////////////////////////////////////// + // String functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Computes the length of a given string value + * @group string_funcs + * @since 1.5.0 + */ + def strlen(e: Column): Column = StringLength(e.expr) + + /** + * Computes the length of a given string column + * @group string_funcs + * @since 1.5.0 + */ + def strlen(columnName: String): Column = strlen(Column(columnName)) ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// @@ -1325,7 +1509,7 @@ object functions { (0 to 10).map { x => val args = (1 to x).map(i => s"arg$i: Column").mkString(", ") val fTypes = Seq.fill(x + 1)("_").mkString(", ") - val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ") + val argsInUDF = (1 to x).map(i => s"arg$i.expr").mkString(", ") println(s""" /** * Call a Scala function of ${x} arguments as user-defined function (UDF). This requires @@ -1333,9 +1517,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = { - ScalaUdf(f, returnType, Seq($argsInUdf)) + ScalaUDF(f, returnType, Seq($argsInUDF)) }""") } } @@ -1469,9 +1655,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function0[_], returnType: DataType): Column = { - ScalaUdf(f, returnType, Seq()) + ScalaUDF(f, returnType, Seq()) } /** @@ -1480,9 +1668,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr)) } /** @@ -1491,9 +1681,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr)) } /** @@ -1502,9 +1694,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) } /** @@ -1513,9 +1707,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) } /** @@ -1524,9 +1720,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) } /** @@ -1535,9 +1733,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) } /** @@ -1546,9 +1746,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) } /** @@ -1557,9 +1759,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) } /** @@ -1568,9 +1772,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) } /** @@ -1579,9 +1785,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) } // scalastyle:on @@ -1594,13 +1802,34 @@ object functions { * * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") * val sqlContext = df.sqlContext - * sqlContext.udf.register("simpleUdf", (v: Int) => v * v) - * df.select($"id", callUdf("simpleUdf", $"value")) + * sqlContext.udf.register("simpleUDF", (v: Int) => v * v) + * df.select($"id", callUDF("simpleUDF", $"value")) + * }}} + * + * @group udf_funcs + * @since 1.5.0 + */ + def callUDF(udfName: String, cols: Column*): Column = { + UnresolvedFunction(udfName, cols.map(_.expr)) + } + + /** + * Call an user-defined function. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + * val sqlContext = df.sqlContext + * sqlContext.udf.register("simpleUDF", (v: Int) => v * v) + * df.select($"id", callUdf("simpleUDF", $"value")) * }}} * * @group udf_funcs * @since 1.4.0 + * @deprecated As of 1.5.0, since it was not coherent to have two functions callUdf and callUDF */ + @deprecated("Use callUDF", "1.5.0") def callUdf(udfName: String, cols: Column*): Column = { UnresolvedFunction(udfName, cols.map(_.expr)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index db68b9c86db1b..30c5f4ca3e1b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -22,12 +22,13 @@ import java.util.Properties import org.apache.commons.lang3.StringUtils -import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{Row, SpecificMutableRow} -import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.expressions.{InternalRow, SpecificMutableRow} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} /** * Data corresponding to one partition of a JDBCRDD. @@ -54,7 +55,7 @@ private[sql] object JDBCRDD extends Logging { val answer = sqlType match { // scalastyle:off case java.sql.Types.ARRAY => null - case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType.Unlimited } + case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType(20,0) } case java.sql.Types.BINARY => BinaryType case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks case java.sql.Types.BLOB => BinaryType @@ -210,7 +211,7 @@ private[sql] object JDBCRDD extends Logging { fqTable: String, requiredColumns: Array[String], filters: Array[Filter], - parts: Array[Partition]): RDD[Row] = { + parts: Array[Partition]): RDD[InternalRow] = { val dialect = JdbcDialects.get(url) val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName)) new JDBCRDD( @@ -239,7 +240,7 @@ private[sql] class JDBCRDD( filters: Array[Filter], partitions: Array[Partition], properties: Properties) - extends RDD[Row](sc, Nil) { + extends RDD[InternalRow](sc, Nil) { /** * Retrieve the list of partitions corresponding to this RDD. @@ -347,12 +348,12 @@ private[sql] class JDBCRDD( /** * Runs the SQL query against the JDBC driver. */ - override def compute(thePart: Partition, context: TaskContext): Iterator[Row] = new Iterator[Row] - { + override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] = + new Iterator[InternalRow] { var closed = false var finished = false var gotNext = false - var nextValue: Row = null + var nextValue: InternalRow = null context.addTaskCompletionListener{ context => close() } val part = thePart.asInstanceOf[JDBCPartition] @@ -374,7 +375,7 @@ private[sql] class JDBCRDD( val conversions = getConversions(schema) val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType)) - def getNext(): Row = { + def getNext(): InternalRow = { if (rs.next()) { var i = 0 while (i < conversions.length) { @@ -382,10 +383,10 @@ private[sql] class JDBCRDD( conversions(i) match { case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos)) case DateConversion => - // DateUtils.fromJavaDate does not handle null value, so we need to check it. + // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it. val dateVal = rs.getDate(pos) if (dateVal != null) { - mutableRow.update(i, DateUtils.fromJavaDate(dateVal)) + mutableRow.setInt(i, DateTimeUtils.fromJavaDate(dateVal)) } else { mutableRow.update(i, null) } @@ -416,8 +417,14 @@ private[sql] class JDBCRDD( case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos)) case LongConversion => mutableRow.setLong(i, rs.getLong(pos)) // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 - case StringConversion => mutableRow.setString(i, rs.getString(pos)) - case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos)) + case StringConversion => mutableRow.update(i, UTF8String.fromString(rs.getString(pos))) + case TimestampConversion => + val t = rs.getTimestamp(pos) + if (t != null) { + mutableRow.setLong(i, DateTimeUtils.fromJavaTimestamp(t)) + } else { + mutableRow.update(i, null) + } case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) case BinaryLongConversion => { val bytes = rs.getBytes(pos) @@ -436,7 +443,7 @@ private[sql] class JDBCRDD( mutableRow } else { finished = true - null.asInstanceOf[Row] + null.asInstanceOf[InternalRow] } } @@ -479,7 +486,7 @@ private[sql] class JDBCRDD( !finished } - override def next(): Row = { + override def next(): InternalRow = { if (!hasNext) { throw new NoSuchElementException("End of stream") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index 30f9190d45bf8..4d3aac464c538 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -23,10 +23,9 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.Partition import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} /** * Instructions on how to partition the table among workers. @@ -138,7 +137,7 @@ private[sql] case class JDBCRelation( table, requiredColumns, filters, - parts) + parts).map(_.asInstanceOf[Row]) } override def insert(data: DataFrame, overwrite: Boolean): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala index 565d10247f10e..afe2c6c11ac69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala @@ -43,7 +43,7 @@ private[sql] object InferSchema { } // perform schema inference on each row and merge afterwards - schemaData.mapPartitions { iter => + val rootType = schemaData.mapPartitions { iter => val factory = new JsonFactory() iter.map { row => try { @@ -55,8 +55,13 @@ private[sql] object InferSchema { StructType(Seq(StructField(columnNameOfCorruptRecords, StringType))) } } - }.treeAggregate[DataType](StructType(Seq()))(compatibleRootType, compatibleRootType) match { - case st: StructType => nullTypeToStringType(st) + }.treeAggregate[DataType](StructType(Seq()))(compatibleRootType, compatibleRootType) + + canonicalizeType(rootType) match { + case Some(st: StructType) => st + case _ => + // canonicalizeType erases all empty structs, including the only one we want to keep + StructType(Seq()) } } @@ -116,22 +121,35 @@ private[sql] object InferSchema { } } - private def nullTypeToStringType(struct: StructType): StructType = { - val fields = struct.fields.map { - case StructField(fieldName, dataType, nullable, _) => - val newType = dataType match { - case NullType => StringType - case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) - case ArrayType(struct: StructType, containsNull) => - ArrayType(nullTypeToStringType(struct), containsNull) - case struct: StructType => nullTypeToStringType(struct) - case other: DataType => other - } + /** + * Convert NullType to StringType and remove StructTypes with no fields + */ + private def canonicalizeType: DataType => Option[DataType] = { + case at@ArrayType(elementType, _) => + for { + canonicalType <- canonicalizeType(elementType) + } yield { + at.copy(canonicalType) + } - StructField(fieldName, newType, nullable) - } + case StructType(fields) => + val canonicalFields = for { + field <- fields + if field.name.nonEmpty + canonicalType <- canonicalizeType(field.dataType) + } yield { + field.copy(dataType = canonicalType) + } + + if (canonicalFields.nonEmpty) { + Some(StructType(canonicalFields)) + } else { + // per SPARK-8093: empty structs should be deleted + None + } - StructType(fields) + case NullType => Some(StringType) + case other => Some(other) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index c772cd1f53e53..69bf13e1e5a6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -22,10 +22,10 @@ import java.io.IOException import org.apache.hadoop.fs.Path import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute, Row} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{StructField, StructType} -import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} private[sql] class DefaultSource @@ -154,12 +154,12 @@ private[sql] class JSONRelation( JacksonParser( baseRDD(), schema, - sqlContext.conf.columnNameOfCorruptRecord) + sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) } else { JsonRDD.jsonStringToRow( baseRDD(), schema, - sqlContext.conf.columnNameOfCorruptRecord) + sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) } } @@ -168,12 +168,12 @@ private[sql] class JSONRelation( JacksonParser( baseRDD(), StructType.fromAttributes(requiredColumns), - sqlContext.conf.columnNameOfCorruptRecord) + sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) } else { JsonRDD.jsonStringToRow( baseRDD(), StructType.fromAttributes(requiredColumns), - sqlContext.conf.columnNameOfCorruptRecord) + sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala index 325f54b6808a8..1e6b1198d245b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala @@ -21,7 +21,7 @@ import scala.collection.Map import com.fasterxml.jackson.core._ -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.Row import org.apache.spark.sql.types._ private[sql] object JacksonGenerator { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala index 0e223758051a6..6222addc9aa3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.json import java.io.ByteArrayOutputStream -import java.sql.Timestamp import scala.collection.Map @@ -26,15 +25,17 @@ import com.fasterxml.jackson.core._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.json.JacksonUtils.nextUntil import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + private[sql] object JacksonParser { def apply( json: RDD[String], schema: StructType, - columnNameOfCorruptRecords: String): RDD[Row] = { + columnNameOfCorruptRecords: String): RDD[InternalRow] = { parseJson(json, schema, columnNameOfCorruptRecords) } @@ -55,27 +56,27 @@ private[sql] object JacksonParser { convertField(factory, parser, schema) case (VALUE_STRING, StringType) => - UTF8String(parser.getText) + UTF8String.fromString(parser.getText) case (VALUE_STRING, _) if parser.getTextLength < 1 => // guard the non string type null case (VALUE_STRING, DateType) => - DateUtils.millisToDays(DateUtils.stringToTime(parser.getText).getTime) + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime) case (VALUE_STRING, TimestampType) => - new Timestamp(DateUtils.stringToTime(parser.getText).getTime) + DateTimeUtils.stringToTime(parser.getText).getTime * 10000L case (VALUE_NUMBER_INT, TimestampType) => - new Timestamp(parser.getLongValue) + parser.getLongValue * 10000L case (_, StringType) => val writer = new ByteArrayOutputStream() val generator = factory.createGenerator(writer, JsonEncoding.UTF8) generator.copyCurrentStructure(parser) generator.close() - UTF8String(writer.toByteArray) + UTF8String.fromBytes(writer.toByteArray) case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, FloatType) => parser.getFloatValue @@ -129,7 +130,10 @@ private[sql] object JacksonParser { * * Fields in the json that are not defined in the requested schema will be dropped. */ - private def convertObject(factory: JsonFactory, parser: JsonParser, schema: StructType): Row = { + private def convertObject( + factory: JsonFactory, + parser: JsonParser, + schema: StructType): InternalRow = { val row = new GenericMutableRow(schema.length) while (nextUntil(parser, JsonToken.END_OBJECT)) { schema.getFieldIndex(parser.getCurrentName) match { @@ -153,7 +157,8 @@ private[sql] object JacksonParser { valueType: DataType): Map[UTF8String, Any] = { val builder = Map.newBuilder[UTF8String, Any] while (nextUntil(parser, JsonToken.END_OBJECT)) { - builder += UTF8String(parser.getCurrentName) -> convertField(factory, parser, valueType) + builder += + UTF8String.fromString(parser.getCurrentName) -> convertField(factory, parser, valueType) } builder.result() @@ -174,14 +179,14 @@ private[sql] object JacksonParser { private def parseJson( json: RDD[String], schema: StructType, - columnNameOfCorruptRecords: String): RDD[Row] = { + columnNameOfCorruptRecords: String): RDD[InternalRow] = { - def failedRecord(record: String): Seq[Row] = { + def failedRecord(record: String): Seq[InternalRow] = { // create a row even if no corrupt record column is present val row = new GenericMutableRow(schema.length) for (corruptIndex <- schema.getFieldIndex(columnNameOfCorruptRecords)) { require(schema(corruptIndex).dataType == StringType) - row.update(corruptIndex, UTF8String(record)) + row.update(corruptIndex, UTF8String.fromString(record)) } Seq(row) @@ -200,7 +205,7 @@ private[sql] object JacksonParser { // convertField wrap an object into a single value array when necessary. convertField(factory, parser, ArrayType(schema)) match { case null => failedRecord(record) - case list: Seq[Row @unchecked] => list + case list: Seq[InternalRow @unchecked] => list case _ => sys.error( s"Failed to parse record $record. Please make sure that each line of the file " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 7e1e21f5fbb99..73d9520d6f53f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.json -import java.sql.Timestamp - import scala.collection.Map import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} @@ -30,15 +28,17 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + private[sql] object JsonRDD extends Logging { private[sql] def jsonStringToRow( json: RDD[String], schema: StructType, - columnNameOfCorruptRecords: String): RDD[Row] = { + columnNameOfCorruptRecords: String): RDD[InternalRow] = { parseJson(json, columnNameOfCorruptRecords).map(parsed => asRow(parsed, schema)) } @@ -319,7 +319,7 @@ private[sql] object JsonRDD extends Logging { parsed } catch { case e: JsonProcessingException => - Map(columnNameOfCorruptRecords -> UTF8String(record)) :: Nil + Map(columnNameOfCorruptRecords -> UTF8String.fromString(record)) :: Nil } } }) @@ -393,16 +393,16 @@ private[sql] object JsonRDD extends Logging { value match { // only support string as date case value: java.lang.String => - DateUtils.millisToDays(DateUtils.stringToTime(value).getTime) - case value: java.sql.Date => DateUtils.fromJavaDate(value) + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(value).getTime) + case value: java.sql.Date => DateTimeUtils.fromJavaDate(value) } } - private def toTimestamp(value: Any): Timestamp = { + private def toTimestamp(value: Any): Long = { value match { - case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong) - case value: java.lang.Long => new Timestamp(value) - case value: java.lang.String => toTimestamp(DateUtils.stringToTime(value).getTime) + case value: java.lang.Integer => value.asInstanceOf[Int].toLong * 10000L + case value: java.lang.Long => value * 10000L + case value: java.lang.String => DateTimeUtils.stringToTime(value).getTime * 10000L } } @@ -411,7 +411,7 @@ private[sql] object JsonRDD extends Logging { null } else { desiredType match { - case StringType => UTF8String(toString(value)) + case StringType => UTF8String.fromString(toString(value)) case _ if value == null || value == "" => null // guard the non string type case IntegerType => value.asInstanceOf[IntegerType.InternalType] case LongType => toLong(value) @@ -425,7 +425,7 @@ private[sql] object JsonRDD extends Logging { val map = value.asInstanceOf[Map[String, Any]] map.map { case (k, v) => - (UTF8String(k), enforceCorrectType(v, valueType)) + (UTF8String.fromString(k), enforceCorrectType(v, valueType)) }.map(identity) case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct) case DateType => toDate(value) @@ -434,7 +434,7 @@ private[sql] object JsonRDD extends Logging { } } - private def asRow(json: Map[String, Any], schema: StructType): Row = { + private def asRow(json: Map[String, Any], schema: StructType): InternalRow = { // TODO: Reuse the row instead of creating a new one for every record. val row = new GenericMutableRow(schema.fields.length) schema.fields.zipWithIndex.foreach { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala new file mode 100644 index 0000000000000..2be7c64612cd2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala @@ -0,0 +1,565 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.schema.OriginalType._ +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ +import org.apache.parquet.schema.Type.Repetition._ +import org.apache.parquet.schema._ + +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{AnalysisException, SQLConf} + +/** + * This converter class is used to convert Parquet [[MessageType]] to Spark SQL [[StructType]] and + * vice versa. + * + * Parquet format backwards-compatibility rules are respected when converting Parquet + * [[MessageType]] schemas. + * + * @see https://github.com/apache/parquet-format/blob/master/LogicalTypes.md + * + * @constructor + * @param assumeBinaryIsString Whether unannotated BINARY fields should be assumed to be Spark SQL + * [[StringType]] fields when converting Parquet a [[MessageType]] to Spark SQL + * [[StructType]]. + * @param assumeInt96IsTimestamp Whether unannotated INT96 fields should be assumed to be Spark SQL + * [[TimestampType]] fields when converting Parquet a [[MessageType]] to Spark SQL + * [[StructType]]. Note that Spark SQL [[TimestampType]] is similar to Hive timestamp, which + * has optional nanosecond precision, but different from `TIME_MILLS` and `TIMESTAMP_MILLIS` + * described in Parquet format spec. + * @param followParquetFormatSpec Whether to generate standard DECIMAL, LIST, and MAP structure when + * converting Spark SQL [[StructType]] to Parquet [[MessageType]]. For Spark 1.4.x and + * prior versions, Spark SQL only supports decimals with a max precision of 18 digits, and + * uses non-standard LIST and MAP structure. Note that the current Parquet format spec is + * backwards-compatible with these settings. If this argument is set to `false`, we fallback + * to old style non-standard behaviors. + */ +private[parquet] class CatalystSchemaConverter( + private val assumeBinaryIsString: Boolean, + private val assumeInt96IsTimestamp: Boolean, + private val followParquetFormatSpec: Boolean) { + + // Only used when constructing converter for converting Spark SQL schema to Parquet schema, in + // which case `assumeInt96IsTimestamp` and `assumeBinaryIsString` are irrelevant. + def this() = this( + assumeBinaryIsString = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, + assumeInt96IsTimestamp = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, + followParquetFormatSpec = SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.defaultValue.get) + + def this(conf: SQLConf) = this( + assumeBinaryIsString = conf.isParquetBinaryAsString, + assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp, + followParquetFormatSpec = conf.followParquetFormatSpec) + + def this(conf: Configuration) = this( + assumeBinaryIsString = + conf.getBoolean( + SQLConf.PARQUET_BINARY_AS_STRING.key, + SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get), + assumeInt96IsTimestamp = + conf.getBoolean( + SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, + SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get), + followParquetFormatSpec = + conf.getBoolean( + SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key, + SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.defaultValue.get)) + + /** + * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]]. + */ + def convert(parquetSchema: MessageType): StructType = convert(parquetSchema.asGroupType()) + + private def convert(parquetSchema: GroupType): StructType = { + val fields = parquetSchema.getFields.map { field => + field.getRepetition match { + case OPTIONAL => + StructField(field.getName, convertField(field), nullable = true) + + case REQUIRED => + StructField(field.getName, convertField(field), nullable = false) + + case REPEATED => + throw new AnalysisException( + s"REPEATED not supported outside LIST or MAP. Type: $field") + } + } + + StructType(fields) + } + + /** + * Converts a Parquet [[Type]] to a Spark SQL [[DataType]]. + */ + def convertField(parquetType: Type): DataType = parquetType match { + case t: PrimitiveType => convertPrimitiveField(t) + case t: GroupType => convertGroupField(t.asGroupType()) + } + + private def convertPrimitiveField(field: PrimitiveType): DataType = { + val typeName = field.getPrimitiveTypeName + val originalType = field.getOriginalType + + def typeString = + if (originalType == null) s"$typeName" else s"$typeName ($originalType)" + + def typeNotImplemented() = + throw new AnalysisException(s"Parquet type not yet supported: $typeString") + + def illegalType() = + throw new AnalysisException(s"Illegal Parquet type: $typeString") + + // When maxPrecision = -1, we skip precision range check, and always respect the precision + // specified in field.getDecimalMetadata. This is useful when interpreting decimal types stored + // as binaries with variable lengths. + def makeDecimalType(maxPrecision: Int = -1): DecimalType = { + val precision = field.getDecimalMetadata.getPrecision + val scale = field.getDecimalMetadata.getScale + + CatalystSchemaConverter.analysisRequire( + maxPrecision == -1 || 1 <= precision && precision <= maxPrecision, + s"Invalid decimal precision: $typeName cannot store $precision digits (max $maxPrecision)") + + DecimalType(precision, scale) + } + + field.getPrimitiveTypeName match { + case BOOLEAN => BooleanType + + case FLOAT => FloatType + + case DOUBLE => DoubleType + + case INT32 => + field.getOriginalType match { + case INT_8 => ByteType + case INT_16 => ShortType + case INT_32 | null => IntegerType + case DATE => DateType + case DECIMAL => makeDecimalType(maxPrecisionForBytes(4)) + case TIME_MILLIS => typeNotImplemented() + case _ => illegalType() + } + + case INT64 => + field.getOriginalType match { + case INT_64 | null => LongType + case DECIMAL => makeDecimalType(maxPrecisionForBytes(8)) + case TIMESTAMP_MILLIS => typeNotImplemented() + case _ => illegalType() + } + + case INT96 => + CatalystSchemaConverter.analysisRequire( + assumeInt96IsTimestamp, + "INT96 is not supported unless it's interpreted as timestamp. " + + s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.") + TimestampType + + case BINARY => + field.getOriginalType match { + case UTF8 | ENUM => StringType + case null if assumeBinaryIsString => StringType + case null => BinaryType + case DECIMAL => makeDecimalType() + case _ => illegalType() + } + + case FIXED_LEN_BYTE_ARRAY => + field.getOriginalType match { + case DECIMAL => makeDecimalType(maxPrecisionForBytes(field.getTypeLength)) + case INTERVAL => typeNotImplemented() + case _ => illegalType() + } + + case _ => illegalType() + } + } + + private def convertGroupField(field: GroupType): DataType = { + Option(field.getOriginalType).fold(convert(field): DataType) { + // A Parquet list is represented as a 3-level structure: + // + // group (LIST) { + // repeated group list { + // element; + // } + // } + // + // However, according to the most recent Parquet format spec (not released yet up until + // writing), some 2-level structures are also recognized for backwards-compatibility. Thus, + // we need to check whether the 2nd level or the 3rd level refers to list element type. + // + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists + case LIST => + CatalystSchemaConverter.analysisRequire( + field.getFieldCount == 1, s"Invalid list type $field") + + val repeatedType = field.getType(0) + CatalystSchemaConverter.analysisRequire( + repeatedType.isRepetition(REPEATED), s"Invalid list type $field") + + if (isElementType(repeatedType, field.getName)) { + ArrayType(convertField(repeatedType), containsNull = false) + } else { + val elementType = repeatedType.asGroupType().getType(0) + val optional = elementType.isRepetition(OPTIONAL) + ArrayType(convertField(elementType), containsNull = optional) + } + + // scalastyle:off + // `MAP_KEY_VALUE` is for backwards-compatibility + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules-1 + // scalastyle:on + case MAP | MAP_KEY_VALUE => + CatalystSchemaConverter.analysisRequire( + field.getFieldCount == 1 && !field.getType(0).isPrimitive, + s"Invalid map type: $field") + + val keyValueType = field.getType(0).asGroupType() + CatalystSchemaConverter.analysisRequire( + keyValueType.isRepetition(REPEATED) && keyValueType.getFieldCount == 2, + s"Invalid map type: $field") + + val keyType = keyValueType.getType(0) + CatalystSchemaConverter.analysisRequire( + keyType.isPrimitive, + s"Map key type is expected to be a primitive type, but found: $keyType") + + val valueType = keyValueType.getType(1) + val valueOptional = valueType.isRepetition(OPTIONAL) + MapType( + convertField(keyType), + convertField(valueType), + valueContainsNull = valueOptional) + + case _ => + throw new AnalysisException(s"Unrecognized Parquet type: $field") + } + } + + // scalastyle:off + // Here we implement Parquet LIST backwards-compatibility rules. + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules + // scalastyle:on + private def isElementType(repeatedType: Type, parentName: String) = { + { + // For legacy 2-level list types with primitive element type, e.g.: + // + // // List (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated int32 element; + // } + // + repeatedType.isPrimitive + } || { + // For legacy 2-level list types whose element type is a group type with 2 or more fields, + // e.g.: + // + // // List> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group element { + // required binary str (UTF8); + // required int32 num; + // }; + // } + // + repeatedType.asGroupType().getFieldCount > 1 + } || { + // For legacy 2-level list types generated by parquet-avro (Parquet version < 1.6.0), e.g.: + // + // // List> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group array { + // required binary str (UTF8); + // }; + // } + // + repeatedType.getName == "array" + } || { + // For Parquet data generated by parquet-thrift, e.g.: + // + // // List> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group my_list_tuple { + // required binary str (UTF8); + // }; + // } + // + repeatedType.getName == s"${parentName}_tuple" + } + } + + /** + * Converts a Spark SQL [[StructType]] to a Parquet [[MessageType]]. + */ + def convert(catalystSchema: StructType): MessageType = { + Types.buildMessage().addFields(catalystSchema.map(convertField): _*).named("root") + } + + /** + * Converts a Spark SQL [[StructField]] to a Parquet [[Type]]. + */ + def convertField(field: StructField): Type = { + convertField(field, if (field.nullable) OPTIONAL else REQUIRED) + } + + private def convertField(field: StructField, repetition: Type.Repetition): Type = { + CatalystSchemaConverter.checkFieldName(field.name) + + field.dataType match { + // =================== + // Simple atomic types + // =================== + + case BooleanType => + Types.primitive(BOOLEAN, repetition).named(field.name) + + case ByteType => + Types.primitive(INT32, repetition).as(INT_8).named(field.name) + + case ShortType => + Types.primitive(INT32, repetition).as(INT_16).named(field.name) + + case IntegerType => + Types.primitive(INT32, repetition).named(field.name) + + case LongType => + Types.primitive(INT64, repetition).named(field.name) + + case FloatType => + Types.primitive(FLOAT, repetition).named(field.name) + + case DoubleType => + Types.primitive(DOUBLE, repetition).named(field.name) + + case StringType => + Types.primitive(BINARY, repetition).as(UTF8).named(field.name) + + case DateType => + Types.primitive(INT32, repetition).as(DATE).named(field.name) + + // NOTE: !! This timestamp type is not specified in Parquet format spec !! + // However, Impala and older versions of Spark SQL use INT96 to store timestamps with + // nanosecond precision (not TIME_MILLIS or TIMESTAMP_MILLIS described in the spec). + case TimestampType => + Types.primitive(INT96, repetition).named(field.name) + + case BinaryType => + Types.primitive(BINARY, repetition).named(field.name) + + // ===================================== + // Decimals (for Spark version <= 1.4.x) + // ===================================== + + // Spark 1.4.x and prior versions only support decimals with a maximum precision of 18 and + // always store decimals in fixed-length byte arrays. + case DecimalType.Fixed(precision, scale) + if precision <= maxPrecisionForBytes(8) && !followParquetFormatSpec => + Types + .primitive(FIXED_LEN_BYTE_ARRAY, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .length(minBytesForPrecision(precision)) + .named(field.name) + + case dec @ DecimalType() if !followParquetFormatSpec => + throw new AnalysisException( + s"Data type $dec is not supported. " + + s"When ${SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key} is set to false," + + "decimal precision and scale must be specified, " + + "and precision must be less than or equal to 18.") + + // ===================================== + // Decimals (follow Parquet format spec) + // ===================================== + + // Uses INT32 for 1 <= precision <= 9 + case DecimalType.Fixed(precision, scale) + if precision <= maxPrecisionForBytes(4) && followParquetFormatSpec => + Types + .primitive(INT32, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .named(field.name) + + // Uses INT64 for 1 <= precision <= 18 + case DecimalType.Fixed(precision, scale) + if precision <= maxPrecisionForBytes(8) && followParquetFormatSpec => + Types + .primitive(INT64, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .named(field.name) + + // Uses FIXED_LEN_BYTE_ARRAY for all other precisions + case DecimalType.Fixed(precision, scale) if followParquetFormatSpec => + Types + .primitive(FIXED_LEN_BYTE_ARRAY, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .length(minBytesForPrecision(precision)) + .named(field.name) + + case dec @ DecimalType.Unlimited if followParquetFormatSpec => + throw new AnalysisException( + s"Data type $dec is not supported. Decimal precision and scale must be specified.") + + // =================================================== + // ArrayType and MapType (for Spark versions <= 1.4.x) + // =================================================== + + // Spark 1.4.x and prior versions convert ArrayType with nullable elements into a 3-level + // LIST structure. This behavior mimics parquet-hive (1.6.0rc3). Note that this case is + // covered by the backwards-compatibility rules implemented in `isElementType()`. + case ArrayType(elementType, nullable @ true) if !followParquetFormatSpec => + // group (LIST) { + // optional group bag { + // repeated element; + // } + // } + ConversionPatterns.listType( + repetition, + field.name, + Types + .buildGroup(REPEATED) + .addField(convertField(StructField("element", elementType, nullable))) + .named(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME)) + + // Spark 1.4.x and prior versions convert ArrayType with non-nullable elements into a 2-level + // LIST structure. This behavior mimics parquet-avro (1.6.0rc3). Note that this case is + // covered by the backwards-compatibility rules implemented in `isElementType()`. + case ArrayType(elementType, nullable @ false) if !followParquetFormatSpec => + // group (LIST) { + // repeated element; + // } + ConversionPatterns.listType( + repetition, + field.name, + convertField(StructField("element", elementType, nullable), REPEATED)) + + // Spark 1.4.x and prior versions convert MapType into a 3-level group annotated by + // MAP_KEY_VALUE. This is covered by `convertGroupField(field: GroupType): DataType`. + case MapType(keyType, valueType, valueContainsNull) if !followParquetFormatSpec => + // group (MAP) { + // repeated group map (MAP_KEY_VALUE) { + // required key; + // value; + // } + // } + ConversionPatterns.mapType( + repetition, + field.name, + convertField(StructField("key", keyType, nullable = false)), + convertField(StructField("value", valueType, valueContainsNull))) + + // ================================================== + // ArrayType and MapType (follow Parquet format spec) + // ================================================== + + case ArrayType(elementType, containsNull) if followParquetFormatSpec => + // group (LIST) { + // repeated group list { + // element; + // } + // } + Types + .buildGroup(repetition).as(LIST) + .addField( + Types.repeatedGroup() + .addField(convertField(StructField("element", elementType, containsNull))) + .named("list")) + .named(field.name) + + case MapType(keyType, valueType, valueContainsNull) => + // group (MAP) { + // repeated group key_value { + // required key; + // value; + // } + // } + Types + .buildGroup(repetition).as(MAP) + .addField( + Types + .repeatedGroup() + .addField(convertField(StructField("key", keyType, nullable = false))) + .addField(convertField(StructField("value", valueType, valueContainsNull))) + .named("key_value")) + .named(field.name) + + // =========== + // Other types + // =========== + + case StructType(fields) => + fields.foldLeft(Types.buildGroup(repetition)) { (builder, field) => + builder.addField(convertField(field)) + }.named(field.name) + + case udt: UserDefinedType[_] => + convertField(field.copy(dataType = udt.sqlType)) + + case _ => + throw new AnalysisException(s"Unsupported data type $field.dataType") + } + } + + // Max precision of a decimal value stored in `numBytes` bytes + private def maxPrecisionForBytes(numBytes: Int): Int = { + Math.round( // convert double to long + Math.floor(Math.log10( // number of base-10 digits + Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes + .asInstanceOf[Int] + } + + // Min byte counts needed to store decimals with various precisions + private val minBytesForPrecision: Array[Int] = Array.tabulate(38) { precision => + var numBytes = 1 + while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) { + numBytes += 1 + } + numBytes + } +} + + +private[parquet] object CatalystSchemaConverter { + def checkFieldName(name: String): Unit = { + // ,;{}()\n\t= and space are special characters in Parquet schema + analysisRequire( + !name.matches(".*[ ,;{}()\n\t=].*"), + s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=". + |Please use alias to rename it. + """.stripMargin.split("\n").mkString(" ")) + } + + def analysisRequire(f: => Boolean, message: String): Unit = { + if (!f) { + throw new AnalysisException(message) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala index 62c4e92ebec68..1551afd7b7bf2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala @@ -17,19 +17,35 @@ package org.apache.spark.sql.parquet +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter - +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.parquet.Log import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat} +/** + * An output committer for writing Parquet files. In stead of writing to the `_temporary` folder + * like what [[ParquetOutputCommitter]] does, this output committer writes data directly to the + * destination folder. This can be useful for data stored in S3, where directory operations are + * relatively expensive. + * + * To enable this output committer, users may set the "spark.sql.parquet.output.committer.class" + * property via Hadoop [[Configuration]]. Not that this property overrides + * "spark.sql.sources.outputCommitterClass". + * + * *NOTE* + * + * NEVER use [[DirectParquetOutputCommitter]] when appending data, because currently there's + * no safe way undo a failed appending job (that's why both `abortTask()` and `abortJob()` are + * left * empty). + */ private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) extends ParquetOutputCommitter(outputPath, context) { val LOG = Log.getLog(classOf[ParquetOutputCommitter]) - override def getWorkPath(): Path = outputPath + override def getWorkPath: Path = outputPath override def abortTask(taskContext: TaskAttemptContext): Unit = {} override def commitTask(taskContext: TaskAttemptContext): Unit = {} override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = true @@ -46,13 +62,11 @@ private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: T val footers = ParquetFileReader.readAllFootersInParallel(configuration, outputStatus) try { ParquetFileWriter.writeMetadataFile(configuration, outputPath, footers) - } catch { - case e: Exception => { - LOG.warn("could not write summary file for " + outputPath, e) - val metadataPath = new Path(outputPath, ParquetFileWriter.PARQUET_METADATA_FILE) - if (fileSystem.exists(metadataPath)) { - fileSystem.delete(metadataPath, true) - } + } catch { case e: Exception => + LOG.warn("could not write summary file for " + outputPath, e) + val metadataPath = new Path(outputPath, ParquetFileWriter.PARQUET_METADATA_FILE) + if (fileSystem.exists(metadataPath)) { + fileSystem.delete(metadataPath, true) } } } catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 85c2ce740fe52..ae7cbf0624dc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -17,20 +17,20 @@ package org.apache.spark.sql.parquet -import java.sql.Timestamp -import java.util.{TimeZone, Calendar} +import java.nio.ByteOrder -import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap} +import scala.collection.mutable.{ArrayBuffer, Buffer, HashMap} -import jodd.datetime.JDateTime +import org.apache.parquet.Preconditions import org.apache.parquet.column.Dictionary -import org.apache.parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} +import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} import org.apache.parquet.schema.MessageType import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.parquet.CatalystConverter.FieldType import org.apache.spark.sql.types._ -import org.apache.spark.sql.parquet.timestamp.NanoTime +import org.apache.spark.unsafe.types.UTF8String /** * Collection of converters of Parquet types (group and primitive types) that @@ -77,7 +77,7 @@ private[sql] object CatalystConverter { // TODO: consider using Array[T] for arrays to avoid boxing of primitive types type ArrayScalaType[T] = Seq[T] - type StructScalaType[T] = Row + type StructScalaType[T] = InternalRow type MapScalaType[K, V] = Map[K, V] protected[parquet] def createConverter( @@ -221,7 +221,7 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { updateField(fieldIndex, value.getBytes) protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = - updateField(fieldIndex, UTF8String(value)) + updateField(fieldIndex, UTF8String.fromBytes(value)) protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = updateField(fieldIndex, readTimestamp(value)) @@ -238,7 +238,7 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { * * @return */ - def getCurrentRecord: Row = throw new UnsupportedOperationException + def getCurrentRecord: InternalRow = throw new UnsupportedOperationException /** * Read a decimal value from a Parquet Binary into "dest". Only supports decimals that fit in @@ -266,14 +266,19 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { /** * Read a Timestamp value from a Parquet Int96Value */ - protected[parquet] def readTimestamp(value: Binary): Timestamp = { - CatalystTimestampConverter.convertToTimestamp(value) + protected[parquet] def readTimestamp(value: Binary): Long = { + Preconditions.checkArgument(value.length() == 12, "Must be 12 bytes") + val buf = value.toByteBuffer + buf.order(ByteOrder.LITTLE_ENDIAN) + val timeOfDayNanos = buf.getLong + val julianDay = buf.getInt + DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos) } } /** * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record - * to a [[org.apache.spark.sql.catalyst.expressions.Row]] object. + * to a [[org.apache.spark.sql.catalyst.expressions.InternalRow]] object. * * @param schema The corresponding Catalyst schema in the form of a list of attributes. */ @@ -282,7 +287,7 @@ private[parquet] class CatalystGroupConverter( protected[parquet] val index: Int, protected[parquet] val parent: CatalystConverter, protected[parquet] var current: ArrayBuffer[Any], - protected[parquet] var buffer: ArrayBuffer[Row]) + protected[parquet] var buffer: ArrayBuffer[InternalRow]) extends CatalystConverter { def this(schema: Array[FieldType], index: Int, parent: CatalystConverter) = @@ -291,7 +296,7 @@ private[parquet] class CatalystGroupConverter( index, parent, current = null, - buffer = new ArrayBuffer[Row]( + buffer = new ArrayBuffer[InternalRow]( CatalystArrayConverter.INITIAL_ARRAY_SIZE)) /** @@ -307,13 +312,13 @@ private[parquet] class CatalystGroupConverter( override val size = schema.size - override def getCurrentRecord: Row = { + override def getCurrentRecord: InternalRow = { assert(isRootConverter, "getCurrentRecord should only be called in root group converter!") // TODO: use iterators if possible // Note: this will ever only be called in the root converter when the record has been // fully processed. Therefore it will be difficult to use mutable rows instead, since // any non-root converter never would be sure when it would be safe to re-use the buffer. - new GenericRow(current.toArray) + new GenericInternalRow(current.toArray) } override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) @@ -337,15 +342,15 @@ private[parquet] class CatalystGroupConverter( override def end(): Unit = { if (!isRootConverter) { assert(current != null) // there should be no empty groups - buffer.append(new GenericRow(current.toArray)) - parent.updateField(index, new GenericRow(buffer.toArray.asInstanceOf[Array[Any]])) + buffer.append(new GenericInternalRow(current.toArray)) + parent.updateField(index, new GenericInternalRow(buffer.toArray.asInstanceOf[Array[Any]])) } } } /** * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record - * to a [[org.apache.spark.sql.catalyst.expressions.Row]] object. Note that his + * to a [[org.apache.spark.sql.catalyst.expressions.InternalRow]] object. Note that his * converter is optimized for rows of primitive types (non-nested records). */ private[parquet] class CatalystPrimitiveRowConverter( @@ -371,7 +376,7 @@ private[parquet] class CatalystPrimitiveRowConverter( override val parent = null // Should be only called in root group converter! - override def getCurrentRecord: Row = current + override def getCurrentRecord: InternalRow = current override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) @@ -401,7 +406,7 @@ private[parquet] class CatalystPrimitiveRowConverter( current.setInt(fieldIndex, value) override protected[parquet] def updateDate(fieldIndex: Int, value: Int): Unit = - current.update(fieldIndex, value) + current.setInt(fieldIndex, value) override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = current.setLong(fieldIndex, value) @@ -422,10 +427,10 @@ private[parquet] class CatalystPrimitiveRowConverter( current.update(fieldIndex, value.getBytes) override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = - current.update(fieldIndex, UTF8String(value)) + current.update(fieldIndex, UTF8String.fromBytes(value)) override protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = - current.update(fieldIndex, readTimestamp(value)) + current.setLong(fieldIndex, readTimestamp(value)) override protected[parquet] def updateDecimal( fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { @@ -496,73 +501,6 @@ private[parquet] object CatalystArrayConverter { val INITIAL_ARRAY_SIZE = 20 } -private[parquet] object CatalystTimestampConverter { - // TODO most part of this comes from Hive-0.14 - // Hive code might have some issues, so we need to keep an eye on it. - // Also we use NanoTime and Int96Values from parquet-examples. - // We utilize jodd to convert between NanoTime and Timestamp - val parquetTsCalendar = new ThreadLocal[Calendar] - def getCalendar: Calendar = { - // this is a cache for the calendar instance. - if (parquetTsCalendar.get == null) { - parquetTsCalendar.set(Calendar.getInstance(TimeZone.getTimeZone("GMT"))) - } - parquetTsCalendar.get - } - val NANOS_PER_SECOND: Long = 1000000000 - val SECONDS_PER_MINUTE: Long = 60 - val MINUTES_PER_HOUR: Long = 60 - val NANOS_PER_MILLI: Long = 1000000 - - def convertToTimestamp(value: Binary): Timestamp = { - val nt = NanoTime.fromBinary(value) - val timeOfDayNanos = nt.getTimeOfDayNanos - val julianDay = nt.getJulianDay - val jDateTime = new JDateTime(julianDay.toDouble) - val calendar = getCalendar - calendar.set(Calendar.YEAR, jDateTime.getYear) - calendar.set(Calendar.MONTH, jDateTime.getMonth - 1) - calendar.set(Calendar.DAY_OF_MONTH, jDateTime.getDay) - - // written in command style - var remainder = timeOfDayNanos - calendar.set( - Calendar.HOUR_OF_DAY, - (remainder / (NANOS_PER_SECOND * SECONDS_PER_MINUTE * MINUTES_PER_HOUR)).toInt) - remainder = remainder % (NANOS_PER_SECOND * SECONDS_PER_MINUTE * MINUTES_PER_HOUR) - calendar.set( - Calendar.MINUTE, (remainder / (NANOS_PER_SECOND * SECONDS_PER_MINUTE)).toInt) - remainder = remainder % (NANOS_PER_SECOND * SECONDS_PER_MINUTE) - calendar.set(Calendar.SECOND, (remainder / NANOS_PER_SECOND).toInt) - val nanos = remainder % NANOS_PER_SECOND - val ts = new Timestamp(calendar.getTimeInMillis) - ts.setNanos(nanos.toInt) - ts - } - - def convertFromTimestamp(ts: Timestamp): Binary = { - val calendar = getCalendar - calendar.setTime(ts) - val jDateTime = new JDateTime(calendar.get(Calendar.YEAR), - calendar.get(Calendar.MONTH) + 1, calendar.get(Calendar.DAY_OF_MONTH)) - // Hive-0.14 didn't set hour before get day number, while the day number should - // has something to do with hour, since julian day number grows at 12h GMT - // here we just follow what hive does. - val julianDay = jDateTime.getJulianDayNumber - - val hour = calendar.get(Calendar.HOUR_OF_DAY) - val minute = calendar.get(Calendar.MINUTE) - val second = calendar.get(Calendar.SECOND) - val nanos = ts.getNanos - // Hive-0.14 would use hours directly, that might be wrong, since the day starts - // from 12h in Julian. here we just follow what hive does. - val nanosOfDay = nanos + second * NANOS_PER_SECOND + - minute * NANOS_PER_SECOND * SECONDS_PER_MINUTE + - hour * NANOS_PER_SECOND * SECONDS_PER_MINUTE * MINUTES_PER_HOUR - NanoTime(julianDay, nanosOfDay).toBinary - } -} - /** * A `parquet.io.api.GroupConverter` that converts a single-element groups that * match the characteristics of an array (see @@ -718,7 +656,7 @@ private[parquet] class CatalystNativeArrayConverter( override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = { checkGrowBuffer() - buffer(elements) = UTF8String(value).asInstanceOf[NativeType] + buffer(elements) = UTF8String.fromBytes(value).asInstanceOf[NativeType] elements += 1 } @@ -850,7 +788,7 @@ private[parquet] class CatalystStructConverter( // here we need to make sure to use StructScalaType // Note: we need to actually make a copy of the array since we // may be in a nested field - parent.updateField(index, new GenericRow(current.toArray)) + parent.updateField(index, new GenericInternalRow(current.toArray)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index 88ae88e9684c8..d57b789f5c1c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.parquet +import java.io.Serializable import java.nio.ByteBuffer import com.google.common.io.BaseEncoding @@ -24,13 +25,15 @@ import org.apache.hadoop.conf.Configuration import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.compat.FilterCompat._ import org.apache.parquet.filter2.predicate.FilterApi._ -import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} +import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Statistics} +import org.apache.parquet.filter2.predicate.UserDefinedPredicate import org.apache.parquet.io.api.Binary import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.sources import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String private[sql] object ParquetFilters { val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter" @@ -41,6 +44,18 @@ private[sql] object ParquetFilters { }.reduceOption(FilterApi.and).map(FilterCompat.get) } + case class SetInFilter[T <: Comparable[T]]( + valueSet: Set[T]) extends UserDefinedPredicate[T] with Serializable { + + override def keep(value: T): Boolean = { + value != null && valueSet.contains(value) + } + + override def canDrop(statistics: Statistics[T]): Boolean = false + + override def inverseCanDrop(statistics: Statistics[T]): Boolean = false + } + private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { case BooleanType => (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) @@ -153,6 +168,29 @@ private[sql] object ParquetFilters { FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) } + private val makeInSet: PartialFunction[DataType, (String, Set[Any]) => FilterPredicate] = { + case IntegerType => + (n: String, v: Set[Any]) => + FilterApi.userDefined(intColumn(n), SetInFilter(v.asInstanceOf[Set[java.lang.Integer]])) + case LongType => + (n: String, v: Set[Any]) => + FilterApi.userDefined(longColumn(n), SetInFilter(v.asInstanceOf[Set[java.lang.Long]])) + case FloatType => + (n: String, v: Set[Any]) => + FilterApi.userDefined(floatColumn(n), SetInFilter(v.asInstanceOf[Set[java.lang.Float]])) + case DoubleType => + (n: String, v: Set[Any]) => + FilterApi.userDefined(doubleColumn(n), SetInFilter(v.asInstanceOf[Set[java.lang.Double]])) + case StringType => + (n: String, v: Set[Any]) => + FilterApi.userDefined(binaryColumn(n), + SetInFilter(v.map(e => Binary.fromByteArray(e.asInstanceOf[UTF8String].getBytes)))) + case BinaryType => + (n: String, v: Set[Any]) => + FilterApi.userDefined(binaryColumn(n), + SetInFilter(v.map(e => Binary.fromByteArray(e.asInstanceOf[Array[Byte]])))) + } + /** * Converts data sources filters to Parquet filter predicates. */ @@ -284,6 +322,9 @@ private[sql] object ParquetFilters { case Not(pred) => createFilter(pred).map(FilterApi.not) + case InSet(NamedExpression(name, dataType), valueSet) => + makeInSet.lift(dataType).map(_(name, valueSet)) + case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 1e694f2feabee..b30fc171c0af1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -46,15 +46,16 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row, _} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, InternalRow, _} import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} import org.apache.spark.sql.types.StructType -import org.apache.spark.{Logging, SerializableWritable, TaskContext} +import org.apache.spark.{Logging, TaskContext} +import org.apache.spark.util.SerializableConfiguration /** * :: DeveloperApi :: * Parquet table scan operator. Imports the file that backs the given - * [[org.apache.spark.sql.parquet.ParquetRelation]] as a ``RDD[Row]``. + * [[org.apache.spark.sql.parquet.ParquetRelation]] as a ``RDD[InternalRow]``. */ private[sql] case class ParquetTableScan( attributes: Seq[Attribute], @@ -77,7 +78,7 @@ private[sql] case class ParquetTableScan( } }.toArray - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { import org.apache.parquet.filter2.compat.FilterCompat.FilterPredicateCompat val sc = sqlContext.sparkContext @@ -113,16 +114,19 @@ private[sql] case class ParquetTableScan( .foreach(ParquetInputFormat.setFilterPredicate(conf, _)) // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata - conf.set( - SQLConf.PARQUET_CACHE_METADATA, - sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true")) + conf.setBoolean( + SQLConf.PARQUET_CACHE_METADATA.key, + sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, true)) + + // Use task side metadata in parquet + conf.setBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true) val baseRDD = new org.apache.spark.rdd.NewHadoopRDD( sc, classOf[FilteringParquetRowInputFormat], classOf[Void], - classOf[Row], + classOf[InternalRow], conf) if (requestedPartitionOrdinals.nonEmpty) { @@ -151,9 +155,9 @@ private[sql] case class ParquetTableScan( .map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow)) if (primitiveRow) { - new Iterator[Row] { + new Iterator[InternalRow] { def hasNext: Boolean = iter.hasNext - def next(): Row = { + def next(): InternalRow = { // We are using CatalystPrimitiveRowConverter and it returns a SpecificMutableRow. val row = iter.next()._2.asInstanceOf[SpecificMutableRow] @@ -170,12 +174,12 @@ private[sql] case class ParquetTableScan( } else { // Create a mutable row since we need to fill in values from partition columns. val mutableRow = new GenericMutableRow(outputSize) - new Iterator[Row] { + new Iterator[InternalRow] { def hasNext: Boolean = iter.hasNext - def next(): Row = { + def next(): InternalRow = { // We are using CatalystGroupConverter and it returns a GenericRow. // Since GenericRow is not mutable, we just cast it to a Row. - val row = iter.next()._2.asInstanceOf[Row] + val row = iter.next()._2.asInstanceOf[InternalRow] var i = 0 while (i < row.size) { @@ -255,7 +259,7 @@ private[sql] case class InsertIntoParquetTable( /** * Inserts all rows into the Parquet file. */ - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { // TODO: currently we do not check whether the "schema"s are compatible // That means if one first creates a table and then INSERTs data with // and incompatible schema the execution will fail. It would be nice @@ -318,15 +322,15 @@ private[sql] case class InsertIntoParquetTable( * @param conf A [[org.apache.hadoop.conf.Configuration]]. */ private def saveAsHadoopFile( - rdd: RDD[Row], + rdd: RDD[InternalRow], path: String, conf: Configuration) { val job = new Job(conf) val keyType = classOf[Void] job.setOutputKeyClass(keyType) - job.setOutputValueClass(classOf[Row]) + job.setOutputValueClass(classOf[InternalRow]) NewFileOutputFormat.setOutputPath(job, new Path(path)) - val wrappedConf = new SerializableWritable(job.getConfiguration) + val wrappedConf = new SerializableConfiguration(job.getConfiguration) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = sqlContext.sparkContext.newRddId() @@ -339,7 +343,7 @@ private[sql] case class InsertIntoParquetTable( .findMaxTaskId(NewFileOutputFormat.getOutputPath(job).toString, job.getConfiguration) + 1 } - def writeShard(context: TaskContext, iter: Iterator[Row]): Int = { + def writeShard(context: TaskContext, iter: Iterator[InternalRow]): Int = { /* "reduce task" */ val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, context.attemptNumber) @@ -378,7 +382,7 @@ private[sql] case class InsertIntoParquetTable( * to imported ones. */ private[parquet] class AppendingParquetOutputFormat(offset: Int) - extends org.apache.parquet.hadoop.ParquetOutputFormat[Row] { + extends org.apache.parquet.hadoop.ParquetOutputFormat[InternalRow] { // override to accept existing directories as valid output directory override def checkOutputSpecs(job: JobContext): Unit = {} var committer: OutputCommitter = null @@ -431,210 +435,26 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) * RecordFilter we want to use. */ private[parquet] class FilteringParquetRowInputFormat - extends org.apache.parquet.hadoop.ParquetInputFormat[Row] with Logging { + extends org.apache.parquet.hadoop.ParquetInputFormat[InternalRow] with Logging { private var fileStatuses = Map.empty[Path, FileStatus] override def createRecordReader( inputSplit: InputSplit, - taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = { + taskAttemptContext: TaskAttemptContext): RecordReader[Void, InternalRow] = { import org.apache.parquet.filter2.compat.FilterCompat.NoOpFilter - val readSupport: ReadSupport[Row] = new RowReadSupport() + val readSupport: ReadSupport[InternalRow] = new RowReadSupport() val filter = ParquetInputFormat.getFilter(ContextUtil.getConfiguration(taskAttemptContext)) if (!filter.isInstanceOf[NoOpFilter]) { - new ParquetRecordReader[Row]( + new ParquetRecordReader[InternalRow]( readSupport, filter) } else { - new ParquetRecordReader[Row](readSupport) - } - } - - // This is only a temporary solution sicne we need to use fileStatuses in - // both getClientSideSplits and getTaskSideSplits. It can be removed once we get rid of these - // two methods. - override def getSplits(jobContext: JobContext): JList[InputSplit] = { - // First set fileStatuses. - val statuses = listStatus(jobContext) - fileStatuses = statuses.map(file => file.getPath -> file).toMap - - super.getSplits(jobContext) - } - - // TODO Remove this method and related code once PARQUET-16 is fixed - // This method together with the `getFooters` method and the `fileStatuses` field are just used - // to mimic this PR: https://github.com/apache/incubator-parquet-mr/pull/17 - override def getSplits( - configuration: Configuration, - footers: JList[Footer]): JList[ParquetInputSplit] = { - - // Use task side strategy by default - val taskSideMetaData = configuration.getBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true) - val maxSplitSize: JLong = configuration.getLong("mapred.max.split.size", Long.MaxValue) - val minSplitSize: JLong = - Math.max(getFormatMinSplitSize, configuration.getLong("mapred.min.split.size", 0L)) - if (maxSplitSize < 0 || minSplitSize < 0) { - throw new ParquetDecodingException( - s"maxSplitSize or minSplitSie should not be negative: maxSplitSize = $maxSplitSize;" + - s" minSplitSize = $minSplitSize") - } - - // Uses strict type checking by default - val getGlobalMetaData = - classOf[ParquetFileWriter].getDeclaredMethod("getGlobalMetaData", classOf[JList[Footer]]) - getGlobalMetaData.setAccessible(true) - var globalMetaData = getGlobalMetaData.invoke(null, footers).asInstanceOf[GlobalMetaData] - - if (globalMetaData == null) { - val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] - return splits - } - - val metadata = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) - val mergedMetadata = globalMetaData - .getKeyValueMetaData - .updated(RowReadSupport.SPARK_METADATA_KEY, setAsJavaSet(Set(metadata))) - - globalMetaData = new GlobalMetaData(globalMetaData.getSchema, - mergedMetadata, globalMetaData.getCreatedBy) - - val readContext = ParquetInputFormat.getReadSupportInstance(configuration).init( - new InitContext(configuration, - globalMetaData.getKeyValueMetaData, - globalMetaData.getSchema)) - - if (taskSideMetaData){ - logInfo("Using Task Side Metadata Split Strategy") - getTaskSideSplits(configuration, - footers, - maxSplitSize, - minSplitSize, - readContext) - } else { - logInfo("Using Client Side Metadata Split Strategy") - getClientSideSplits(configuration, - footers, - maxSplitSize, - minSplitSize, - readContext) + new ParquetRecordReader[InternalRow](readSupport) } - - } - - def getClientSideSplits( - configuration: Configuration, - footers: JList[Footer], - maxSplitSize: JLong, - minSplitSize: JLong, - readContext: ReadContext): JList[ParquetInputSplit] = { - - import org.apache.parquet.filter2.compat.FilterCompat.Filter - import org.apache.parquet.filter2.compat.RowGroupFilter - - import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.blockLocationCache - - val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true) - - val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] - val filter: Filter = ParquetInputFormat.getFilter(configuration) - var rowGroupsDropped: Long = 0 - var totalRowGroups: Long = 0 - - // Ugly hack, stuck with it until PR: - // https://github.com/apache/incubator-parquet-mr/pull/17 - // is resolved - val generateSplits = - Class.forName("org.apache.parquet.hadoop.ClientSideMetadataSplitStrategy") - .getDeclaredMethods.find(_.getName == "generateSplits").getOrElse( - sys.error(s"Failed to reflectively invoke ClientSideMetadataSplitStrategy.generateSplits")) - generateSplits.setAccessible(true) - - for (footer <- footers) { - val fs = footer.getFile.getFileSystem(configuration) - val file = footer.getFile - val status = fileStatuses.getOrElse(file, fs.getFileStatus(file)) - val parquetMetaData = footer.getParquetMetadata - val blocks = parquetMetaData.getBlocks - totalRowGroups = totalRowGroups + blocks.size - val filteredBlocks = RowGroupFilter.filterRowGroups( - filter, - blocks, - parquetMetaData.getFileMetaData.getSchema) - rowGroupsDropped = rowGroupsDropped + (blocks.size - filteredBlocks.size) - - if (!filteredBlocks.isEmpty){ - var blockLocations: Array[BlockLocation] = null - if (!cacheMetadata) { - blockLocations = fs.getFileBlockLocations(status, 0, status.getLen) - } else { - blockLocations = blockLocationCache.get(status, new Callable[Array[BlockLocation]] { - def call(): Array[BlockLocation] = fs.getFileBlockLocations(status, 0, status.getLen) - }) - } - splits.addAll( - generateSplits.invoke( - null, - filteredBlocks, - blockLocations, - status, - readContext.getRequestedSchema.toString, - readContext.getReadSupportMetadata, - minSplitSize, - maxSplitSize).asInstanceOf[JList[ParquetInputSplit]]) - } - } - - if (rowGroupsDropped > 0 && totalRowGroups > 0){ - val percentDropped = ((rowGroupsDropped/totalRowGroups.toDouble) * 100).toInt - logInfo(s"Dropping $rowGroupsDropped row groups that do not pass filter predicate " - + s"($percentDropped %) !") - } - else { - logInfo("There were no row groups that could be dropped due to filter predicates") - } - splits - - } - - def getTaskSideSplits( - configuration: Configuration, - footers: JList[Footer], - maxSplitSize: JLong, - minSplitSize: JLong, - readContext: ReadContext): JList[ParquetInputSplit] = { - - val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] - - // Ugly hack, stuck with it until PR: - // https://github.com/apache/incubator-parquet-mr/pull/17 - // is resolved - val generateSplits = - Class.forName("org.apache.parquet.hadoop.TaskSideMetadataSplitStrategy") - .getDeclaredMethods.find(_.getName == "generateTaskSideMDSplits").getOrElse( - sys.error( - s"Failed to reflectively invoke TaskSideMetadataSplitStrategy.generateTaskSideMDSplits")) - generateSplits.setAccessible(true) - - for (footer <- footers) { - val file = footer.getFile - val fs = file.getFileSystem(configuration) - val status = fileStatuses.getOrElse(file, fs.getFileStatus(file)) - val blockLocations = fs.getFileBlockLocations(status, 0, status.getLen) - splits.addAll( - generateSplits.invoke( - null, - blockLocations, - status, - readContext.getRequestedSchema.toString, - readContext.getReadSupportMetadata, - minSplitSize, - maxSplitSize).asInstanceOf[JList[ParquetInputSplit]]) - } - - splits } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 89db408b1c382..df2a96dfeb619 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.parquet +import java.nio.{ByteOrder, ByteBuffer} import java.util.{HashMap => JHashMap} import org.apache.hadoop.conf.Configuration @@ -28,8 +29,10 @@ import org.apache.parquet.io.api._ import org.apache.parquet.schema.MessageType import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} +import org.apache.spark.sql.catalyst.expressions.{Attribute, InternalRow} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * A `parquet.io.api.RecordMaterializer` for Rows. @@ -37,12 +40,12 @@ import org.apache.spark.sql.types._ *@param root The root group converter for the record. */ private[parquet] class RowRecordMaterializer(root: CatalystConverter) - extends RecordMaterializer[Row] { + extends RecordMaterializer[InternalRow] { def this(parquetSchema: MessageType, attributes: Seq[Attribute]) = this(CatalystConverter.createRootConverter(parquetSchema, attributes)) - override def getCurrentRecord: Row = root.getCurrentRecord + override def getCurrentRecord: InternalRow = root.getCurrentRecord override def getRootConverter: GroupConverter = root.asInstanceOf[GroupConverter] } @@ -50,13 +53,13 @@ private[parquet] class RowRecordMaterializer(root: CatalystConverter) /** * A `parquet.hadoop.api.ReadSupport` for Row objects. */ -private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging { +private[parquet] class RowReadSupport extends ReadSupport[InternalRow] with Logging { override def prepareForRead( conf: Configuration, stringMap: java.util.Map[String, String], fileSchema: MessageType, - readContext: ReadContext): RecordMaterializer[Row] = { + readContext: ReadContext): RecordMaterializer[InternalRow] = { log.debug(s"preparing for read with Parquet file schema $fileSchema") // Note: this very much imitates AvroParquet val parquetSchema = readContext.getRequestedSchema @@ -83,8 +86,7 @@ private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging { // TODO: Why it can be null? if (schema == null) { log.debug("falling back to Parquet read schema") - schema = ParquetTypesConverter.convertToAttributes( - parquetSchema, false, true) + schema = ParquetTypesConverter.convertToAttributes(parquetSchema, false, true) } log.debug(s"list of attributes that will be read: $schema") new RowRecordMaterializer(parquetSchema, schema) @@ -102,8 +104,7 @@ private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging { // If the parquet file is thrift derived, there is a good chance that // it will have the thrift class in metadata. val isThriftDerived = keyValueMetaData.keySet().contains("thrift.class") - parquetSchema = ParquetTypesConverter - .convertFromAttributes(requestedAttributes, isThriftDerived) + parquetSchema = ParquetTypesConverter.convertFromAttributes(requestedAttributes) metadata.put( RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, ParquetTypesConverter.convertToString(requestedAttributes)) @@ -131,7 +132,7 @@ private[parquet] object RowReadSupport { /** * A `parquet.hadoop.api.WriteSupport` for Row objects. */ -private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { +private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Logging { private[parquet] var writer: RecordConsumer = null private[parquet] var attributes: Array[Attribute] = null @@ -155,7 +156,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { log.debug(s"preparing for write with schema $attributes") } - override def write(record: Row): Unit = { + override def write(record: InternalRow): Unit = { val attributesSize = attributes.size if (attributesSize > record.size) { throw new IndexOutOfBoundsException( @@ -197,19 +198,18 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { private[parquet] def writePrimitive(schema: DataType, value: Any): Unit = { if (value != null) { schema match { + case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) + case ByteType => writer.addInteger(value.asInstanceOf[Byte]) + case ShortType => writer.addInteger(value.asInstanceOf[Short]) + case IntegerType | DateType => writer.addInteger(value.asInstanceOf[Int]) + case LongType => writer.addLong(value.asInstanceOf[Long]) + case TimestampType => writeTimestamp(value.asInstanceOf[Long]) + case FloatType => writer.addFloat(value.asInstanceOf[Float]) + case DoubleType => writer.addDouble(value.asInstanceOf[Double]) case StringType => writer.addBinary( Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes)) case BinaryType => writer.addBinary( Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) - case IntegerType => writer.addInteger(value.asInstanceOf[Int]) - case ShortType => writer.addInteger(value.asInstanceOf[Short]) - case LongType => writer.addLong(value.asInstanceOf[Long]) - case TimestampType => writeTimestamp(value.asInstanceOf[java.sql.Timestamp]) - case ByteType => writer.addInteger(value.asInstanceOf[Byte]) - case DoubleType => writer.addDouble(value.asInstanceOf[Double]) - case FloatType => writer.addFloat(value.asInstanceOf[Float]) - case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) - case DateType => writer.addInteger(value.asInstanceOf[Int]) case d: DecimalType => if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") @@ -296,7 +296,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { } // Scratch array used to write decimals as fixed-length binary - private val scratchBytes = new Array[Byte](8) + private[this] val scratchBytes = new Array[Byte](8) private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = { val numBytes = ParquetTypesConverter.BYTES_FOR_PRECISION(precision) @@ -311,15 +311,22 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes)) } - private[parquet] def writeTimestamp(ts: java.sql.Timestamp): Unit = { - val binaryNanoTime = CatalystTimestampConverter.convertFromTimestamp(ts) - writer.addBinary(binaryNanoTime) + // array used to write Timestamp as Int96 (fixed-length binary) + private[this] val int96buf = new Array[Byte](12) + + private[parquet] def writeTimestamp(ts: Long): Unit = { + val (julianDay, timeOfDayNanos) = DateTimeUtils.toJulianDay(ts) + val buf = ByteBuffer.wrap(int96buf) + buf.order(ByteOrder.LITTLE_ENDIAN) + buf.putLong(timeOfDayNanos) + buf.putInt(julianDay) + writer.addBinary(Binary.fromByteArray(int96buf)) } } // Optimized for non-nested rows private[parquet] class MutableRowWriteSupport extends RowWriteSupport { - override def write(record: Row): Unit = { + override def write(record: InternalRow): Unit = { val attributesSize = attributes.size if (attributesSize > record.size) { throw new IndexOutOfBoundsException( @@ -342,22 +349,21 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { private def consumeType( ctype: DataType, - record: Row, + record: InternalRow, index: Int): Unit = { ctype match { + case BooleanType => writer.addBoolean(record.getBoolean(index)) + case ByteType => writer.addInteger(record.getByte(index)) + case ShortType => writer.addInteger(record.getShort(index)) + case IntegerType | DateType => writer.addInteger(record.getInt(index)) + case LongType => writer.addLong(record.getLong(index)) + case TimestampType => writeTimestamp(record.getLong(index)) + case FloatType => writer.addFloat(record.getFloat(index)) + case DoubleType => writer.addDouble(record.getDouble(index)) case StringType => writer.addBinary( Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes)) case BinaryType => writer.addBinary( Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) - case IntegerType => writer.addInteger(record.getInt(index)) - case ShortType => writer.addInteger(record.getShort(index)) - case LongType => writer.addLong(record.getLong(index)) - case ByteType => writer.addInteger(record.getByte(index)) - case DoubleType => writer.addDouble(record.getDouble(index)) - case FloatType => writer.addFloat(record.getFloat(index)) - case BooleanType => writer.addBoolean(record.getBoolean(index)) - case DateType => writer.addInteger(record.getInt(index)) - case TimestampType => writeTimestamp(record(index).asInstanceOf[java.sql.Timestamp]) case d: DecimalType => if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index ba2a35b74ef82..e748bd7857bd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -29,212 +29,17 @@ import org.apache.parquet.format.converter.ParquetMetadataConverter import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} -import org.apache.parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} -import org.apache.parquet.schema.Type.Repetition -import org.apache.parquet.schema.{ConversionPatterns, DecimalMetadata, GroupType => ParquetGroupType, MessageType, OriginalType => ParquetOriginalType, PrimitiveType => ParquetPrimitiveType, Type => ParquetType, Types => ParquetTypes} +import org.apache.parquet.schema.MessageType import org.apache.spark.Logging -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.types._ -/** A class representing Parquet info fields we care about, for passing back to Parquet */ -private[parquet] case class ParquetTypeInfo( - primitiveType: ParquetPrimitiveTypeName, - originalType: Option[ParquetOriginalType] = None, - decimalMetadata: Option[DecimalMetadata] = None, - length: Option[Int] = None) - private[parquet] object ParquetTypesConverter extends Logging { def isPrimitiveType(ctype: DataType): Boolean = ctype match { - case _: NumericType | BooleanType | StringType | BinaryType => true - case _: DataType => false - } - - def toPrimitiveDataType( - parquetType: ParquetPrimitiveType, - binaryAsString: Boolean, - int96AsTimestamp: Boolean): DataType = { - val originalType = parquetType.getOriginalType - val decimalInfo = parquetType.getDecimalMetadata - parquetType.getPrimitiveTypeName match { - case ParquetPrimitiveTypeName.BINARY - if (originalType == ParquetOriginalType.UTF8 || binaryAsString) => StringType - case ParquetPrimitiveTypeName.BINARY => BinaryType - case ParquetPrimitiveTypeName.BOOLEAN => BooleanType - case ParquetPrimitiveTypeName.DOUBLE => DoubleType - case ParquetPrimitiveTypeName.FLOAT => FloatType - case ParquetPrimitiveTypeName.INT32 - if originalType == ParquetOriginalType.DATE => DateType - case ParquetPrimitiveTypeName.INT32 => IntegerType - case ParquetPrimitiveTypeName.INT64 => LongType - case ParquetPrimitiveTypeName.INT96 if int96AsTimestamp => TimestampType - case ParquetPrimitiveTypeName.INT96 => - // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? - throw new AnalysisException("Potential loss of precision: cannot convert INT96") - case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY - if (originalType == ParquetOriginalType.DECIMAL && decimalInfo.getPrecision <= 18) => - // TODO: for now, our reader only supports decimals that fit in a Long - DecimalType(decimalInfo.getPrecision, decimalInfo.getScale) - case _ => throw new AnalysisException(s"Unsupported parquet datatype $parquetType") - } - } - - /** - * Converts a given Parquet `Type` into the corresponding - * [[org.apache.spark.sql.types.DataType]]. - * - * We apply the following conversion rules: - *
    - *
  • Primitive types are converter to the corresponding primitive type.
  • - *
  • Group types that have a single field that is itself a group, which has repetition - * level `REPEATED`, are treated as follows:
      - *
    • If the nested group has name `values`, the surrounding group is converted - * into an [[ArrayType]] with the corresponding field type (primitive or - * complex) as element type.
    • - *
    • If the nested group has name `map` and two fields (named `key` and `value`), - * the surrounding group is converted into a [[MapType]] - * with the corresponding key and value (value possibly complex) types. - * Note that we currently assume map values are not nullable.
    • - *
    • Other group types are converted into a [[StructType]] with the corresponding - * field types.
  • - *
- * Note that fields are determined to be `nullable` if and only if their Parquet repetition - * level is not `REQUIRED`. - * - * @param parquetType The type to convert. - * @return The corresponding Catalyst type. - */ - def toDataType(parquetType: ParquetType, - isBinaryAsString: Boolean, - isInt96AsTimestamp: Boolean): DataType = { - def correspondsToMap(groupType: ParquetGroupType): Boolean = { - if (groupType.getFieldCount != 1 || groupType.getFields.apply(0).isPrimitive) { - false - } else { - // This mostly follows the convention in ``parquet.schema.ConversionPatterns`` - val keyValueGroup = groupType.getFields.apply(0).asGroupType() - keyValueGroup.getRepetition == Repetition.REPEATED && - keyValueGroup.getName == CatalystConverter.MAP_SCHEMA_NAME && - keyValueGroup.getFieldCount == 2 && - keyValueGroup.getFields.apply(0).getName == CatalystConverter.MAP_KEY_SCHEMA_NAME && - keyValueGroup.getFields.apply(1).getName == CatalystConverter.MAP_VALUE_SCHEMA_NAME - } - } - - def correspondsToArray(groupType: ParquetGroupType): Boolean = { - groupType.getFieldCount == 1 && - groupType.getFieldName(0) == CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME && - groupType.getFields.apply(0).getRepetition == Repetition.REPEATED - } - - if (parquetType.isPrimitive) { - toPrimitiveDataType(parquetType.asPrimitiveType, isBinaryAsString, isInt96AsTimestamp) - } else { - val groupType = parquetType.asGroupType() - parquetType.getOriginalType match { - // if the schema was constructed programmatically there may be hints how to convert - // it inside the metadata via the OriginalType field - case ParquetOriginalType.LIST => { // TODO: check enums! - assert(groupType.getFieldCount == 1) - val field = groupType.getFields.apply(0) - if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) { - val bag = field.asGroupType() - assert(bag.getFieldCount == 1) - ArrayType( - toDataType(bag.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp), - containsNull = true) - } else { - ArrayType( - toDataType(field, isBinaryAsString, isInt96AsTimestamp), containsNull = false) - } - } - case ParquetOriginalType.MAP => { - assert( - !groupType.getFields.apply(0).isPrimitive, - "Parquet Map type malformatted: expected nested group for map!") - val keyValueGroup = groupType.getFields.apply(0).asGroupType() - assert( - keyValueGroup.getFieldCount == 2, - "Parquet Map type malformatted: nested group should have 2 (key, value) fields!") - assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - - val keyType = - toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp) - val valueType = - toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString, isInt96AsTimestamp) - MapType(keyType, valueType, - keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED) - } - case _ => { - // Note: the order of these checks is important! - if (correspondsToMap(groupType)) { // MapType - val keyValueGroup = groupType.getFields.apply(0).asGroupType() - assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - - val keyType = - toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp) - val valueType = - toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString, isInt96AsTimestamp) - MapType(keyType, valueType, - keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED) - } else if (correspondsToArray(groupType)) { // ArrayType - val field = groupType.getFields.apply(0) - if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) { - val bag = field.asGroupType() - assert(bag.getFieldCount == 1) - ArrayType( - toDataType(bag.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp), - containsNull = true) - } else { - ArrayType( - toDataType(field, isBinaryAsString, isInt96AsTimestamp), containsNull = false) - } - } else { // everything else: StructType - val fields = groupType - .getFields - .map(ptype => new StructField( - ptype.getName, - toDataType(ptype, isBinaryAsString, isInt96AsTimestamp), - ptype.getRepetition != Repetition.REQUIRED)) - StructType(fields) - } - } - } - } - } - - /** - * For a given Catalyst [[org.apache.spark.sql.types.DataType]] return - * the name of the corresponding Parquet primitive type or None if the given type - * is not primitive. - * - * @param ctype The type to convert - * @return The name of the corresponding Parquet type properties - */ - def fromPrimitiveDataType(ctype: DataType): Option[ParquetTypeInfo] = ctype match { - case StringType => Some(ParquetTypeInfo( - ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8))) - case BinaryType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BINARY)) - case BooleanType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BOOLEAN)) - case DoubleType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.DOUBLE)) - case FloatType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FLOAT)) - case IntegerType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) - // There is no type for Byte or Short so we promote them to INT32. - case ShortType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) - case ByteType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) - case DateType => Some(ParquetTypeInfo( - ParquetPrimitiveTypeName.INT32, Some(ParquetOriginalType.DATE))) - case LongType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT64)) - case TimestampType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT96)) - case DecimalType.Fixed(precision, scale) if precision <= 18 => - // TODO: for now, our writer only supports decimals that fit in a Long - Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, - Some(ParquetOriginalType.DECIMAL), - Some(new DecimalMetadata(precision, scale)), - Some(BYTES_FOR_PRECISION(precision)))) - case _ => None + case _: NumericType | BooleanType | DateType | TimestampType | StringType | BinaryType => true + case _ => false } /** @@ -248,177 +53,29 @@ private[parquet] object ParquetTypesConverter extends Logging { length } - /** - * Converts a given Catalyst [[org.apache.spark.sql.types.DataType]] into - * the corresponding Parquet `Type`. - * - * The conversion follows the rules below: - *
    - *
  • Primitive types are converted into Parquet's primitive types.
  • - *
  • [[org.apache.spark.sql.types.StructType]]s are converted - * into Parquet's `GroupType` with the corresponding field types.
  • - *
  • [[org.apache.spark.sql.types.ArrayType]]s are converted - * into a 2-level nested group, where the outer group has the inner - * group as sole field. The inner group has name `values` and - * repetition level `REPEATED` and has the element type of - * the array as schema. We use Parquet's `ConversionPatterns` for this - * purpose.
  • - *
  • [[org.apache.spark.sql.types.MapType]]s are converted - * into a nested (2-level) Parquet `GroupType` with two fields: a key - * type and a value type. The nested group has repetition level - * `REPEATED` and name `map`. We use Parquet's `ConversionPatterns` - * for this purpose
  • - *
- * Parquet's repetition level is generally set according to the following rule: - *
    - *
  • If the call to `fromDataType` is recursive inside an enclosing `ArrayType` or - * `MapType`, then the repetition level is set to `REPEATED`.
  • - *
  • Otherwise, if the attribute whose type is converted is `nullable`, the Parquet - * type gets repetition level `OPTIONAL` and otherwise `REQUIRED`.
  • - *
- * - *@param ctype The type to convert - * @param name The name of the [[org.apache.spark.sql.catalyst.expressions.Attribute]] - * whose type is converted - * @param nullable When true indicates that the attribute is nullable - * @param inArray When true indicates that this is a nested attribute inside an array. - * @return The corresponding Parquet type. - */ - def fromDataType( - ctype: DataType, - name: String, - nullable: Boolean = true, - inArray: Boolean = false, - toThriftSchemaNames: Boolean = false): ParquetType = { - val repetition = - if (inArray) { - Repetition.REPEATED - } else { - if (nullable) Repetition.OPTIONAL else Repetition.REQUIRED - } - val arraySchemaName = if (toThriftSchemaNames) { - name + CatalystConverter.THRIFT_ARRAY_ELEMENTS_SCHEMA_NAME_SUFFIX - } else { - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME - } - val typeInfo = fromPrimitiveDataType(ctype) - typeInfo.map { - case ParquetTypeInfo(primitiveType, originalType, decimalMetadata, length) => - val builder = ParquetTypes.primitive(primitiveType, repetition).as(originalType.orNull) - for (len <- length) { - builder.length(len) - } - for (metadata <- decimalMetadata) { - builder.precision(metadata.getPrecision).scale(metadata.getScale) - } - builder.named(name) - }.getOrElse { - ctype match { - case udt: UserDefinedType[_] => { - fromDataType(udt.sqlType, name, nullable, inArray, toThriftSchemaNames) - } - case ArrayType(elementType, false) => { - val parquetElementType = fromDataType( - elementType, - arraySchemaName, - nullable = false, - inArray = true, - toThriftSchemaNames) - ConversionPatterns.listType(repetition, name, parquetElementType) - } - case ArrayType(elementType, true) => { - val parquetElementType = fromDataType( - elementType, - arraySchemaName, - nullable = true, - inArray = false, - toThriftSchemaNames) - ConversionPatterns.listType( - repetition, - name, - new ParquetGroupType( - Repetition.REPEATED, - CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, - parquetElementType)) - } - case StructType(structFields) => { - val fields = structFields.map { - field => fromDataType(field.dataType, field.name, field.nullable, - inArray = false, toThriftSchemaNames) - } - new ParquetGroupType(repetition, name, fields.toSeq) - } - case MapType(keyType, valueType, valueContainsNull) => { - val parquetKeyType = - fromDataType( - keyType, - CatalystConverter.MAP_KEY_SCHEMA_NAME, - nullable = false, - inArray = false, - toThriftSchemaNames) - val parquetValueType = - fromDataType( - valueType, - CatalystConverter.MAP_VALUE_SCHEMA_NAME, - nullable = valueContainsNull, - inArray = false, - toThriftSchemaNames) - ConversionPatterns.mapType( - repetition, - name, - parquetKeyType, - parquetValueType) - } - case _ => throw new AnalysisException(s"Unsupported datatype $ctype") - } - } - } - - def convertToAttributes(parquetSchema: ParquetType, - isBinaryAsString: Boolean, - isInt96AsTimestamp: Boolean): Seq[Attribute] = { - parquetSchema - .asGroupType() - .getFields - .map( - field => - new AttributeReference( - field.getName, - toDataType(field, isBinaryAsString, isInt96AsTimestamp), - field.getRepetition != Repetition.REQUIRED)()) + def convertToAttributes( + parquetSchema: MessageType, + isBinaryAsString: Boolean, + isInt96AsTimestamp: Boolean): Seq[Attribute] = { + val converter = new CatalystSchemaConverter( + isBinaryAsString, isInt96AsTimestamp, followParquetFormatSpec = false) + converter.convert(parquetSchema).toAttributes } - def convertFromAttributes(attributes: Seq[Attribute], - toThriftSchemaNames: Boolean = false): MessageType = { - checkSpecialCharacters(attributes) - val fields = attributes.map( - attribute => - fromDataType(attribute.dataType, attribute.name, attribute.nullable, - toThriftSchemaNames = toThriftSchemaNames)) - new MessageType("root", fields) + def convertFromAttributes(attributes: Seq[Attribute]): MessageType = { + val converter = new CatalystSchemaConverter() + converter.convert(StructType.fromAttributes(attributes)) } def convertFromString(string: String): Seq[Attribute] = { Try(DataType.fromJson(string)).getOrElse(DataType.fromCaseClassString(string)) match { case s: StructType => s.toAttributes - case other => throw new AnalysisException(s"Can convert $string to row") - } - } - - private def checkSpecialCharacters(schema: Seq[Attribute]) = { - // ,;{}()\n\t= and space character are special characters in Parquet schema - schema.map(_.name).foreach { name => - if (name.matches(".*[ ,;{}()\n\t=].*")) { - throw new AnalysisException( - s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=". - |Please use alias to rename it. - """.stripMargin.split("\n").mkString(" ")) - } + case other => sys.error(s"Can convert $string to row") } } def convertToString(schema: Seq[Attribute]): String = { - checkSpecialCharacters(schema) + schema.map(_.name).foreach(CatalystSchemaConverter.checkFieldName) StructType.fromAttributes(schema).json } @@ -450,8 +107,7 @@ private[parquet] object ParquetTypesConverter extends Logging { ParquetTypesConverter.convertToString(attributes)) // TODO: add extra data, e.g., table name, date, etc.? - val parquetSchema: MessageType = - ParquetTypesConverter.convertFromAttributes(attributes) + val parquetSchema: MessageType = ParquetTypesConverter.convertFromAttributes(attributes) val metaData: FileMetaData = new FileMetaData( parquetSchema, extraMetadata, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 5dda440240e60..bc39fae2bcfde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConversions._ import scala.util.Try import com.google.common.base.Objects -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ @@ -34,16 +33,16 @@ import org.apache.parquet.hadoop._ import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.parquet.hadoop.util.ContextUtil -import org.apache.spark.{Partition => SparkPartition, SerializableWritable, Logging, SparkException} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.rdd.RDD._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.AnalysisException +import org.apache.spark.rdd.RDD._ +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.sql.{Row, SQLConf, SQLContext} -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} private[sql] class DefaultSource extends HadoopFsRelationProvider { override def createRelation( @@ -60,51 +59,22 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider { private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext) extends OutputWriter { - private val recordWriter: RecordWriter[Void, Row] = { - val conf = context.getConfiguration + private val recordWriter: RecordWriter[Void, InternalRow] = { val outputFormat = { - // When appending new Parquet files to an existing Parquet file directory, to avoid - // overwriting existing data files, we need to find out the max task ID encoded in these data - // file names. - // TODO Make this snippet a utility function for other data source developers - val maxExistingTaskId = { - // Note that `path` may point to a temporary location. Here we retrieve the real - // destination path from the configuration - val outputPath = new Path(conf.get("spark.sql.sources.output.path")) - val fs = outputPath.getFileSystem(conf) - - if (fs.exists(outputPath)) { - // Pattern used to match task ID in part file names, e.g.: - // - // part-r-00001.gz.parquet - // ^~~~~ - val partFilePattern = """part-.-(\d{1,}).*""".r - - fs.listStatus(outputPath).map(_.getPath.getName).map { - case partFilePattern(id) => id.toInt - case name if name.startsWith("_") => 0 - case name if name.startsWith(".") => 0 - case name => throw new AnalysisException( - s"Trying to write Parquet files to directory $outputPath, " + - s"but found items with illegal name '$name'.") - }.reduceOption(_ max _).getOrElse(0) - } else { - 0 - } - } - - new ParquetOutputFormat[Row]() { + new ParquetOutputFormat[InternalRow]() { // Here we override `getDefaultWorkFile` for two reasons: // - // 1. To allow appending. We need to generate output file name based on the max available - // task ID computed above. + // 1. To allow appending. We need to generate unique output file names to avoid + // overwriting existing files (either exist before the write job, or are just written + // by other tasks within the same write job). // // 2. To allow dynamic partitioning. Default `getDefaultWorkFile` uses // `FileOutputCommitter.getWorkPath()`, which points to the base directory of all // partitions in the case of dynamic partitioning. override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val split = context.getTaskAttemptID.getTaskID.getId + maxExistingTaskId + 1 - new Path(path, f"part-r-$split%05d$extension") + val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") + val split = context.getTaskAttemptID.getTaskID.getId + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") } } } @@ -112,7 +82,7 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext outputFormat.getRecordWriter(context) } - override def write(row: Row): Unit = recordWriter.write(null, row) + override def write(row: Row): Unit = recordWriter.write(null, row.asInstanceOf[InternalRow]) override def close(): Unit = recordWriter.close(context) } @@ -208,15 +178,28 @@ private[sql] class ParquetRelation2( val committerClass = conf.getClass( - "spark.sql.parquet.output.committer.class", + SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, classOf[ParquetOutputCommitter], classOf[ParquetOutputCommitter]) + if (conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) == null) { + logInfo("Using default output committer for Parquet: " + + classOf[ParquetOutputCommitter].getCanonicalName) + } else { + logInfo("Using user defined output committer for Parquet: " + committerClass.getCanonicalName) + } + conf.setClass( - SQLConf.OUTPUT_COMMITTER_CLASS, + SQLConf.OUTPUT_COMMITTER_CLASS.key, committerClass, classOf[ParquetOutputCommitter]) + // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override + // it in `ParquetOutputWriter` to support appending and dynamic partitioning. The reason why + // we set it here is to setup the output committer class to `ParquetOutputCommitter`, which is + // bundled with `ParquetOutputFormat[Row]`. + job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) + // TODO There's no need to use two kinds of WriteSupport // We should unify them. `SpecificMutableRow` can process both atomic (primitive) types and // complex types. @@ -251,8 +234,8 @@ private[sql] class ParquetRelation2( requiredColumns: Array[String], filters: Array[Filter], inputFiles: Array[FileStatus], - broadcastedConf: Broadcast[SerializableWritable[Configuration]]): RDD[Row] = { - val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true").toBoolean + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { + val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA) val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown // Create the function to set variable Parquet confs at both driver and executor side. val initLocalJobFuncOpt = @@ -279,7 +262,7 @@ private[sql] class ParquetRelation2( initLocalJobFuncOpt = Some(initLocalJobFuncOpt), inputFormatClass = classOf[FilteringParquetRowInputFormat], keyClass = classOf[Void], - valueClass = classOf[Row]) { + valueClass = classOf[InternalRow]) { val cacheMetadata = useMetadataCache @@ -324,7 +307,7 @@ private[sql] class ParquetRelation2( new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) } } - }.values + }.values.map(_.asInstanceOf[Row]) } } @@ -491,7 +474,7 @@ private[sql] object ParquetRelation2 extends Logging { ParquetTypesConverter.convertToString(dataSchema.toAttributes)) // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata - conf.set(SQLConf.PARQUET_CACHE_METADATA, useMetadataCache.toString) + conf.setBoolean(SQLConf.PARQUET_CACHE_METADATA.key, useMetadataCache) } /** This closure sets input paths at the driver side. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala deleted file mode 100644 index 4d5ed211ad0c0..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet.timestamp - -import java.nio.{ByteBuffer, ByteOrder} - -import org.apache.parquet.Preconditions -import org.apache.parquet.io.api.{Binary, RecordConsumer} - -private[parquet] class NanoTime extends Serializable { - private var julianDay = 0 - private var timeOfDayNanos = 0L - - def set(julianDay: Int, timeOfDayNanos: Long): this.type = { - this.julianDay = julianDay - this.timeOfDayNanos = timeOfDayNanos - this - } - - def getJulianDay: Int = julianDay - - def getTimeOfDayNanos: Long = timeOfDayNanos - - def toBinary: Binary = { - val buf = ByteBuffer.allocate(12) - buf.order(ByteOrder.LITTLE_ENDIAN) - buf.putLong(timeOfDayNanos) - buf.putInt(julianDay) - buf.flip() - Binary.fromByteBuffer(buf) - } - - def writeValue(recordConsumer: RecordConsumer): Unit = { - recordConsumer.addBinary(toBinary) - } - - override def toString: String = - "NanoTime{julianDay=" + julianDay + ", timeOfDayNanos=" + timeOfDayNanos + "}" -} - -private[sql] object NanoTime { - def fromBinary(bytes: Binary): NanoTime = { - Preconditions.checkArgument(bytes.length() == 12, "Must be 12 bytes") - val buf = bytes.toByteBuffer - buf.order(ByteOrder.LITTLE_ENDIAN) - val timeOfDayNanos = buf.getLong - val julianDay = buf.getInt - new NanoTime().set(julianDay, timeOfDayNanos) - } - - def apply(julianDay: Int, timeOfDayNanos: Long): NanoTime = { - new NanoTime().set(julianDay, timeOfDayNanos) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index c6a4dabbab05e..ce16e050c56ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -17,47 +17,48 @@ package org.apache.spark.sql.sources -import org.apache.spark.{Logging, SerializableWritable, TaskContext} +import org.apache.spark.{Logging, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.types.{StringType, StructType, UTF8String} +import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.sql.{SaveMode, Strategy, execution, sources} -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.unsafe.types.UTF8String /** * A Strategy for planning scans over data sources defined using the sources API. */ private[sql] object DataSourceStrategy extends Strategy with Logging { def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { - case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: CatalystScan)) => + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan)) => pruneFilterProjectRaw( l, - projectList, + projects, filters, - (a, f) => t.buildScan(a, f)) :: Nil + (a, f) => toCatalystRDD(l, a, t.buildScan(a, f))) :: Nil - case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: PrunedFilteredScan)) => + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedFilteredScan)) => pruneFilterProject( l, - projectList, + projects, filters, - (a, f) => t.buildScan(a, f)) :: Nil + (a, f) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f))) :: Nil - case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: PrunedScan)) => + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedScan)) => pruneFilterProject( l, - projectList, + projects, filters, - (a, _) => t.buildScan(a)) :: Nil + (a, _) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray))) :: Nil // Scanning partitioned HadoopFsRelation - case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: HadoopFsRelation)) + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation)) if t.partitionSpec.partitionColumns.nonEmpty => val selectedPartitions = prunePartitions(filters, t.partitionSpec).toArray @@ -79,26 +80,27 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { buildPartitionedTableScan( l, - projectList, + projects, pushedFilters, t.partitionSpec.partitionColumns, selectedPartitions) :: Nil // Scanning non-partitioned HadoopFsRelation - case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: HadoopFsRelation)) => + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation)) => // See buildPartitionedTableScan for the reason that we need to create a shard // broadcast HadoopConf. val sharedHadoopConf = SparkHadoopUtil.get.conf val confBroadcast = - t.sqlContext.sparkContext.broadcast(new SerializableWritable(sharedHadoopConf)) + t.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf)) pruneFilterProject( l, - projectList, + projects, filters, - (a, f) => t.buildScan(a, f, t.paths, confBroadcast)) :: Nil + (a, f) => + toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f, t.paths, confBroadcast))) :: Nil case l @ LogicalRelation(t: TableScan) => - createPhysicalRDD(l.relation, l.output, t.buildScan()) :: Nil + execution.PhysicalRDD(l.output, toCatalystRDD(l, t.buildScan())) :: Nil case i @ logical.InsertIntoTable( l @ LogicalRelation(t: InsertableRelation), part, query, overwrite, false) if part.isEmpty => @@ -118,14 +120,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { filters: Seq[Expression], partitionColumns: StructType, partitions: Array[Partition]) = { - val output = projections.map(_.toAttribute) val relation = logicalRelation.relation.asInstanceOf[HadoopFsRelation] // Because we are creating one RDD per partition, we need to have a shared HadoopConf. // Otherwise, the cost of broadcasting HadoopConf in every RDD will be high. val sharedHadoopConf = SparkHadoopUtil.get.conf val confBroadcast = - relation.sqlContext.sparkContext.broadcast(new SerializableWritable(sharedHadoopConf)) + relation.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf)) // Builds RDD[Row]s for each selected partition. val perPartitionRows = partitions.map { case Partition(partitionValues, dir) => @@ -137,23 +138,23 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { logicalRelation, projections, filters, - (requiredColumns, filters) => { + (columns: Seq[Attribute], filters) => { val partitionColNames = partitionColumns.fieldNames // Don't scan any partition columns to save I/O. Here we are being optimistic and // assuming partition columns data stored in data files are always consistent with those // partition values encoded in partition directory paths. - val nonPartitionColumns = requiredColumns.filterNot(partitionColNames.contains) + val needed = columns.filterNot(a => partitionColNames.contains(a.name)) val dataRows = - relation.buildScan(nonPartitionColumns, filters, Array(dir), confBroadcast) + relation.buildScan(needed.map(_.name).toArray, filters, Array(dir), confBroadcast) // Merges data values with partition values. mergeWithPartitionValues( relation.schema, - requiredColumns, + columns.map(_.name).toArray, partitionColNames, partitionValues, - dataRows) + toCatalystRDD(logicalRelation, needed, dataRows)) }) scan.execute() @@ -166,15 +167,15 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { new UnionRDD(relation.sqlContext.sparkContext, perPartitionRows) } - createPhysicalRDD(logicalRelation.relation, output, unionedRows) + execution.PhysicalRDD(projections.map(_.toAttribute), unionedRows) } private def mergeWithPartitionValues( schema: StructType, requiredColumns: Array[String], partitionColumns: Array[String], - partitionValues: Row, - dataRows: RDD[Row]): RDD[Row] = { + partitionValues: InternalRow, + dataRows: RDD[InternalRow]): RDD[InternalRow] = { val nonPartitionColumns = requiredColumns.filterNot(partitionColumns.contains) // If output columns contain any partition column(s), we need to merge scanned data @@ -185,13 +186,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val i = partitionColumns.indexOf(name) if (i != -1) { // If yes, gets column value from partition values. - (mutableRow: MutableRow, dataRow: expressions.Row, ordinal: Int) => { + (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => { mutableRow(ordinal) = partitionValues(i) } } else { // Otherwise, inherits the value from scanned data. val i = nonPartitionColumns.indexOf(name) - (mutableRow: MutableRow, dataRow: expressions.Row, ordinal: Int) => { + (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => { mutableRow(ordinal) = dataRow(i) } } @@ -200,7 +201,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // Since we know for sure that this closure is serializable, we can avoid the overhead // of cleaning a closure for each RDD by creating our own MapPartitionsRDD. Functionally // this is equivalent to calling `dataRows.mapPartitions(mapPartitionsFunc)` (SPARK-7718). - val mapPartitionsFunc = (_: TaskContext, _: Int, iterator: Iterator[Row]) => { + val mapPartitionsFunc = (_: TaskContext, _: Int, iterator: Iterator[InternalRow]) => { val dataTypes = requiredColumns.map(schema(_).dataType) val mutableRow = new SpecificMutableRow(dataTypes) iterator.map { dataRow => @@ -209,7 +210,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { mergers(i)(mutableRow, dataRow, i) i += 1 } - mutableRow.asInstanceOf[expressions.Row] + mutableRow.asInstanceOf[InternalRow] } } @@ -255,26 +256,26 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // Based on Public API. protected def pruneFilterProject( relation: LogicalRelation, - projectList: Seq[NamedExpression], + projects: Seq[NamedExpression], filterPredicates: Seq[Expression], - scanBuilder: (Array[String], Array[Filter]) => RDD[Row]) = { + scanBuilder: (Seq[Attribute], Array[Filter]) => RDD[InternalRow]) = { pruneFilterProjectRaw( relation, - projectList, + projects, filterPredicates, (requestedColumns, pushedFilters) => { - scanBuilder(requestedColumns.map(_.name).toArray, selectFilters(pushedFilters).toArray) + scanBuilder(requestedColumns, selectFilters(pushedFilters).toArray) }) } // Based on Catalyst expressions. protected def pruneFilterProjectRaw( relation: LogicalRelation, - projectList: Seq[NamedExpression], + projects: Seq[NamedExpression], filterPredicates: Seq[Expression], - scanBuilder: (Seq[Attribute], Seq[Expression]) => RDD[Row]) = { + scanBuilder: (Seq[Attribute], Seq[Expression]) => RDD[InternalRow]) = { - val projectSet = AttributeSet(projectList.flatMap(_.references)) + val projectSet = AttributeSet(projects.flatMap(_.references)) val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) val filterCondition = filterPredicates.reduceLeftOption(expressions.And) @@ -282,38 +283,47 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes. }} - if (projectList.map(_.toAttribute) == projectList && - projectSet.size == projectList.size && + if (projects.map(_.toAttribute) == projects && + projectSet.size == projects.size && filterSet.subsetOf(projectSet)) { // When it is possible to just use column pruning to get the right projection and // when the columns of this projection are enough to evaluate all filter conditions, // just do a scan followed by a filter, with no extra project. val requestedColumns = - projectList.asInstanceOf[Seq[Attribute]] // Safe due to if above. + projects.asInstanceOf[Seq[Attribute]] // Safe due to if above. .map(relation.attributeMap) // Match original case of attributes. - val scan = createPhysicalRDD(relation.relation, projectList.map(_.toAttribute), - scanBuilder(requestedColumns, pushedFilters)) + val scan = execution.PhysicalRDD(projects.map(_.toAttribute), + scanBuilder(requestedColumns, pushedFilters)) filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) } else { val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq - val scan = createPhysicalRDD(relation.relation, requestedColumns, + val scan = execution.PhysicalRDD(requestedColumns, scanBuilder(requestedColumns, pushedFilters)) - execution.Project(projectList, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) + execution.Project(projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) } } - private[this] def createPhysicalRDD( - relation: BaseRelation, + /** + * Convert RDD of Row into RDD of InternalRow with objects in catalyst types + */ + private[this] def toCatalystRDD( + relation: LogicalRelation, output: Seq[Attribute], - rdd: RDD[Row]): SparkPlan = { - val converted = if (relation.needConversion) { + rdd: RDD[Row]): RDD[InternalRow] = { + if (relation.relation.needConversion) { execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType)) } else { - rdd + rdd.map(_.asInstanceOf[InternalRow]) } - execution.PhysicalRDD(output, converted) + } + + /** + * Convert RDD of Row into RDD of InternalRow with objects in catalyst types + */ + private[this] def toCatalystRDD(relation: LogicalRelation, rdd: RDD[Row]): RDD[InternalRow] = { + toCatalystRDD(relation, relation.output, rdd) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala index c4c99de5a38dc..8b2a45d8e970a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources -import java.lang.{Double => JDouble, Float => JFloat, Long => JLong} +import java.lang.{Double => JDouble, Float => JFloat, Integer => JInteger, Long => JLong} import java.math.{BigDecimal => JBigDecimal} import scala.collection.mutable.ArrayBuffer @@ -25,12 +25,11 @@ import scala.util.Try import org.apache.hadoop.fs.Path import org.apache.hadoop.util.Shell - -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types._ -private[sql] case class Partition(values: Row, path: String) +private[sql] case class Partition(values: InternalRow, path: String) private[sql] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition]) @@ -72,10 +71,11 @@ private[sql] object PartitioningUtils { */ private[sql] def parsePartitions( paths: Seq[Path], - defaultPartitionName: String): PartitionSpec = { + defaultPartitionName: String, + typeInference: Boolean): PartitionSpec = { // First, we need to parse every partition's path and see if we can find partition values. val pathsWithPartitionValues = paths.flatMap { path => - parsePartition(path, defaultPartitionName).map(path -> _) + parsePartition(path, defaultPartitionName, typeInference).map(path -> _) } if (pathsWithPartitionValues.isEmpty) { @@ -84,7 +84,7 @@ private[sql] object PartitioningUtils { } else { // This dataset is partitioned. We need to check whether all partitions have the same // partition columns and resolve potential type conflicts. - val resolvedPartitionValues = resolvePartitions(pathsWithPartitionValues.map(_._2)) + val resolvedPartitionValues = resolvePartitions(pathsWithPartitionValues) // Creates the StructType which represents the partition columns. val fields = { @@ -99,7 +99,7 @@ private[sql] object PartitioningUtils { // Finally, we create `Partition`s based on paths and resolved partition values. val partitions = resolvedPartitionValues.zip(pathsWithPartitionValues).map { case (PartitionValues(_, literals), (path, _)) => - Partition(Row.fromSeq(literals.map(_.value)), path.toString) + Partition(InternalRow.fromSeq(literals.map(_.value)), path.toString) } PartitionSpec(StructType(fields), partitions) @@ -124,7 +124,8 @@ private[sql] object PartitioningUtils { */ private[sql] def parsePartition( path: Path, - defaultPartitionName: String): Option[PartitionValues] = { + defaultPartitionName: String, + typeInference: Boolean): Option[PartitionValues] = { val columns = ArrayBuffer.empty[(String, Literal)] // Old Hadoop versions don't have `Path.isRoot` var finished = path.getParent == null @@ -137,7 +138,7 @@ private[sql] object PartitioningUtils { return None } - val maybeColumn = parsePartitionColumn(chopped.getName, defaultPartitionName) + val maybeColumn = parsePartitionColumn(chopped.getName, defaultPartitionName, typeInference) maybeColumn.foreach(columns += _) chopped = chopped.getParent finished = maybeColumn.isEmpty || chopped.getParent == null @@ -153,7 +154,8 @@ private[sql] object PartitioningUtils { private def parsePartitionColumn( columnSpec: String, - defaultPartitionName: String): Option[(String, Literal)] = { + defaultPartitionName: String, + typeInference: Boolean): Option[(String, Literal)] = { val equalSignIndex = columnSpec.indexOf('=') if (equalSignIndex == -1) { None @@ -164,7 +166,7 @@ private[sql] object PartitioningUtils { val rawColumnValue = columnSpec.drop(equalSignIndex + 1) assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'") - val literal = inferPartitionColumnValue(rawColumnValue, defaultPartitionName) + val literal = inferPartitionColumnValue(rawColumnValue, defaultPartitionName, typeInference) Some(columnName -> literal) } } @@ -175,23 +177,22 @@ private[sql] object PartitioningUtils { * {{{ * NullType -> * IntegerType -> LongType -> - * FloatType -> DoubleType -> DecimalType.Unlimited -> + * DoubleType -> DecimalType.Unlimited -> * StringType * }}} */ - private[sql] def resolvePartitions(values: Seq[PartitionValues]): Seq[PartitionValues] = { - // Column names of all partitions must match - val distinctPartitionsColNames = values.map(_.columnNames).distinct - - if (distinctPartitionsColNames.isEmpty) { + private[sql] def resolvePartitions( + pathsWithPartitionValues: Seq[(Path, PartitionValues)]): Seq[PartitionValues] = { + if (pathsWithPartitionValues.isEmpty) { Seq.empty } else { - assert(distinctPartitionsColNames.size == 1, { - val list = distinctPartitionsColNames.mkString("\t", "\n\t", "") - s"Conflicting partition column names detected:\n$list" - }) + val distinctPartColNames = pathsWithPartitionValues.map(_._2.columnNames).distinct + assert( + distinctPartColNames.size == 1, + listConflictingPartitionColumns(pathsWithPartitionValues)) // Resolves possible type conflicts for each column + val values = pathsWithPartitionValues.map(_._2) val columnCount = values.head.columnNames.size val resolvedValues = (0 until columnCount).map { i => resolveTypeConflicts(values.map(_.literals(i))) @@ -204,26 +205,65 @@ private[sql] object PartitioningUtils { } } + private[sql] def listConflictingPartitionColumns( + pathWithPartitionValues: Seq[(Path, PartitionValues)]): String = { + val distinctPartColNames = pathWithPartitionValues.map(_._2.columnNames).distinct + + def groupByKey[K, V](seq: Seq[(K, V)]): Map[K, Iterable[V]] = + seq.groupBy { case (key, _) => key }.mapValues(_.map { case (_, value) => value }) + + val partColNamesToPaths = groupByKey(pathWithPartitionValues.map { + case (path, partValues) => partValues.columnNames -> path + }) + + val distinctPartColLists = distinctPartColNames.map(_.mkString(", ")).zipWithIndex.map { + case (names, index) => + s"Partition column name list #$index: $names" + } + + // Lists out those non-leaf partition directories that also contain files + val suspiciousPaths = distinctPartColNames.sortBy(_.length).flatMap(partColNamesToPaths) + + s"Conflicting partition column names detected:\n" + + distinctPartColLists.mkString("\n\t", "\n\t", "\n\n") + + "For partitioned table directories, data files should only live in leaf directories.\n" + + "And directories at the same level should have the same partition column name.\n" + + "Please check the following directories for unexpected files or " + + "inconsistent partition column names:\n" + + suspiciousPaths.map("\t" + _).mkString("\n", "\n", "") + } + /** - * Converts a string to a `Literal` with automatic type inference. Currently only supports - * [[IntegerType]], [[LongType]], [[FloatType]], [[DoubleType]], [[DecimalType.Unlimited]], and + * Converts a string to a [[Literal]] with automatic type inference. Currently only supports + * [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.Unlimited]], and * [[StringType]]. */ private[sql] def inferPartitionColumnValue( raw: String, - defaultPartitionName: String): Literal = { - // First tries integral types - Try(Literal.create(Integer.parseInt(raw), IntegerType)) - .orElse(Try(Literal.create(JLong.parseLong(raw), LongType))) - // Then falls back to fractional types - .orElse(Try(Literal.create(JFloat.parseFloat(raw), FloatType))) - .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) - .orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited))) - // Then falls back to string - .getOrElse { - if (raw == defaultPartitionName) Literal.create(null, NullType) - else Literal.create(unescapePathName(raw), StringType) + defaultPartitionName: String, + typeInference: Boolean): Literal = { + if (typeInference) { + // First tries integral types + Try(Literal.create(Integer.parseInt(raw), IntegerType)) + .orElse(Try(Literal.create(JLong.parseLong(raw), LongType))) + // Then falls back to fractional types + .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) + .orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited))) + // Then falls back to string + .getOrElse { + if (raw == defaultPartitionName) { + Literal.create(null, NullType) + } else { + Literal.create(unescapePathName(raw), StringType) + } + } + } else { + if (raw == defaultPartitionName) { + Literal.create(null, NullType) + } else { + Literal.create(unescapePathName(raw), StringType) } + } } private val upCastingOrder: Seq[DataType] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala index ebad0c1564ec0..2bdc341021256 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala @@ -34,7 +34,7 @@ import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.{RDD, HadoopRDD} import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} import scala.reflect.ClassTag @@ -65,7 +65,7 @@ private[spark] class SqlNewHadoopPartition( */ private[sql] class SqlNewHadoopRDD[K, V]( @transient sc : SparkContext, - broadcastedConf: Broadcast[SerializableWritable[Configuration]], + broadcastedConf: Broadcast[SerializableConfiguration], @transient initDriverSideJobFuncOpt: Option[Job => Unit], initLocalJobFuncOpt: Option[Job => Unit], inputFormatClass: Class[_ <: InputFormat[K, V]], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index bd3aad6631748..54c8eeb41a8ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.sources -import java.util.Date +import java.util.{Date, UUID} import scala.collection.mutable import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter, FileOutputFormat} -import org.apache.parquet.hadoop.util.ContextUtil import org.apache.spark._ import org.apache.spark.mapred.SparkHadoopMapRedUtil @@ -36,7 +35,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext, SaveMode} +import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext, SaveMode} +import org.apache.spark.util.SerializableConfiguration private[sql] case class InsertIntoDataSource( logicalRelation: LogicalRelation, @@ -48,8 +48,7 @@ private[sql] case class InsertIntoDataSource( val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] val data = DataFrame(sqlContext, query) // Apply the schema of the existing table to the new data. - val df = sqlContext.createDataFrame( - data.queryExecution.toRdd, logicalRelation.schema, needsConversion = false) + val df = sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) relation.insert(df, overwrite) // Invalidate the cache. @@ -59,6 +58,28 @@ private[sql] case class InsertIntoDataSource( } } +/** + * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. + * Writing to dynamic partitions is also supported. Each [[InsertIntoHadoopFsRelation]] issues a + * single write job, and owns a UUID that identifies this job. Each concrete implementation of + * [[HadoopFsRelation]] should use this UUID together with task id to generate unique file path for + * each task output file. This UUID is passed to executor side via a property named + * `spark.sql.sources.writeJobUUID`. + * + * Different writer containers, [[DefaultWriterContainer]] and [[DynamicPartitionWriterContainer]] + * are used to write to normal tables and tables with dynamic partitions. + * + * Basic work flow of this command is: + * + * 1. Driver side setup, including output committer initialization and data source specific + * preparation work for the write job to be issued. + * 2. Issues a write job consists of one or more executor side tasks, each of which writes all + * rows within an RDD partition. + * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any + * exception is thrown during task commitment, also aborts that task. + * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is + * thrown during job commitment, also aborts the job. + */ private[sql] case class InsertIntoHadoopFsRelation( @transient relation: HadoopFsRelation, @transient query: LogicalPlan, @@ -75,7 +96,8 @@ private[sql] case class InsertIntoHadoopFsRelation( val fs = outputPath.getFileSystem(hadoopConf) val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - val doInsertion = (mode, fs.exists(qualifiedOutputPath)) match { + val pathExists = fs.exists(qualifiedOutputPath) + val doInsertion = (mode, pathExists) match { case (SaveMode.ErrorIfExists, true) => sys.error(s"path $qualifiedOutputPath already exists.") case (SaveMode.Overwrite, true) => @@ -86,11 +108,13 @@ private[sql] case class InsertIntoHadoopFsRelation( case (SaveMode.Ignore, exists) => !exists } + // If we are appending data to an existing dir. + val isAppend = (pathExists) && (mode == SaveMode.Append) if (doInsertion) { val job = new Job(hadoopConf) job.setOutputKeyClass(classOf[Void]) - job.setOutputValueClass(classOf[Row]) + job.setOutputValueClass(classOf[InternalRow]) FileOutputFormat.setOutputPath(job, qualifiedOutputPath) // We create a DataFrame by applying the schema of relation to the data to make sure. @@ -103,23 +127,21 @@ private[sql] case class InsertIntoHadoopFsRelation( val project = Project( relation.schema.map(field => new UnresolvedAttribute(Seq(field.name))), query) - sqlContext.createDataFrame( - DataFrame(sqlContext, project).queryExecution.toRdd, - relation.schema, - needsConversion = false) + sqlContext.internalCreateDataFrame( + DataFrame(sqlContext, project).queryExecution.toRdd, relation.schema) } val partitionColumns = relation.partitionColumns.fieldNames if (partitionColumns.isEmpty) { - insert(new DefaultWriterContainer(relation, job), df) + insert(new DefaultWriterContainer(relation, job, isAppend), df) } else { val writerContainer = new DynamicPartitionWriterContainer( - relation, job, partitionColumns, PartitioningUtils.DEFAULT_PARTITION_NAME) + relation, job, partitionColumns, PartitioningUtils.DEFAULT_PARTITION_NAME, isAppend) insertWithDynamicPartitions(sqlContext, writerContainer, df, partitionColumns) } } - Seq.empty[Row] + Seq.empty[InternalRow] } private def insert(writerContainer: BaseWriterContainer, df: DataFrame): Unit = { @@ -132,7 +154,7 @@ private[sql] case class InsertIntoHadoopFsRelation( writerContainer.driverSideSetup() try { - df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _) + df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _) writerContainer.commitJob() relation.refresh() } catch { case cause: Throwable => @@ -141,22 +163,19 @@ private[sql] case class InsertIntoHadoopFsRelation( throw new SparkException("Job aborted.", cause) } - def writeRows(taskContext: TaskContext, iterator: Iterator[Row]): Unit = { + def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { // If anything below fails, we should abort the task. try { writerContainer.executorSideSetup(taskContext) - if (needsConversion) { - val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) - while (iterator.hasNext) { - val row = converter(iterator.next()).asInstanceOf[Row] - writerContainer.outputWriterForRow(row).write(row) - } + val converter = if (needsConversion) { + CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] } else { - while (iterator.hasNext) { - val row = iterator.next() - writerContainer.outputWriterForRow(row).write(row) - } + r: InternalRow => r.asInstanceOf[Row] + } + while (iterator.hasNext) { + val row = converter(iterator.next()) + writerContainer.outputWriterForRow(row).write(row) } writerContainer.commitTask() @@ -201,7 +220,7 @@ private[sql] case class InsertIntoHadoopFsRelation( writerContainer.driverSideSetup() try { - df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _) + df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _) writerContainer.commitJob() relation.refresh() } catch { case cause: Throwable => @@ -210,7 +229,7 @@ private[sql] case class InsertIntoHadoopFsRelation( throw new SparkException("Job aborted.", cause) } - def writeRows(taskContext: TaskContext, iterator: Iterator[Row]): Unit = { + def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { // If anything below fails, we should abort the task. try { writerContainer.executorSideSetup(taskContext) @@ -218,24 +237,21 @@ private[sql] case class InsertIntoHadoopFsRelation( val partitionProj = newProjection(codegenEnabled, partitionOutput, output) val dataProj = newProjection(codegenEnabled, dataOutput, output) - if (needsConversion) { - val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) - while (iterator.hasNext) { - val row = iterator.next() - val partitionPart = partitionProj(row) - val dataPart = dataProj(row) - val convertedDataPart = converter(dataPart).asInstanceOf[Row] - writerContainer.outputWriterForRow(partitionPart).write(convertedDataPart) - } + val dataConverter: InternalRow => Row = if (needsConversion) { + CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] } else { - val partitionSchema = StructType.fromAttributes(partitionOutput) - val converter = CatalystTypeConverters.createToScalaConverter(partitionSchema) - while (iterator.hasNext) { - val row = iterator.next() - val partitionPart = converter(partitionProj(row)).asInstanceOf[Row] - val dataPart = dataProj(row) - writerContainer.outputWriterForRow(partitionPart).write(dataPart) - } + r: InternalRow => r.asInstanceOf[Row] + } + val partitionSchema = StructType.fromAttributes(partitionOutput) + val partConverter: InternalRow => Row = + CatalystTypeConverters.createToScalaConverter(partitionSchema) + .asInstanceOf[InternalRow => Row] + + while (iterator.hasNext) { + val row = iterator.next() + val partitionPart = partConverter(partitionProj(row)) + val dataPart = dataConverter(dataProj(row)) + writerContainer.outputWriterForRow(partitionPart).write(dataPart) } writerContainer.commitTask() @@ -254,7 +270,7 @@ private[sql] case class InsertIntoHadoopFsRelation( inputSchema: Seq[Attribute]): Projection = { log.debug( s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if (codegenEnabled) { + if (codegenEnabled && expressions.forall(_.isThreadSafe)) { GenerateProjection.generate(expressions, inputSchema) } else { new InterpretedProjection(expressions, inputSchema) @@ -264,12 +280,20 @@ private[sql] case class InsertIntoHadoopFsRelation( private[sql] abstract class BaseWriterContainer( @transient val relation: HadoopFsRelation, - @transient job: Job) + @transient job: Job, + isAppend: Boolean) extends SparkHadoopMapReduceUtil with Logging with Serializable { - protected val serializableConf = new SerializableWritable(ContextUtil.getConfiguration(job)) + protected val serializableConf = new SerializableConfiguration(job.getConfiguration) + + // This UUID is used to avoid output file name collision between different appending write jobs. + // These jobs may belong to different SparkContext instances. Concrete data source implementations + // may use this UUID to generate unique file names (e.g., `part-r--.parquet`). + // The reason why this ID is used to identify a job rather than a single task output file is + // that, speculative tasks must generate the same output file name as the original task. + private val uniqueWriteJobId = UUID.randomUUID() // This is only used on driver side. @transient private val jobContext: JobContext = job @@ -297,12 +321,21 @@ private[sql] abstract class BaseWriterContainer( def driverSideSetup(): Unit = { setupIDs(0, 0, 0) setupConf() - taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) - // This preparation must happen before initializing output format and output committer, since - // their initialization involves the job configuration, which can be potentially decorated in - // `relation.prepareJobForWrite`. + // This UUID is sent to executor side together with the serialized `Configuration` object within + // the `Job` instance. `OutputWriters` on the executor side should use this UUID to generate + // unique task output files. + job.getConfiguration.set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) + + // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor + // clones the Configuration object passed in. If we initialize the TaskAttemptContext first, + // configurations made in prepareJobForWrite(job) are not populated into the TaskAttemptContext. + // + // Also, the `prepareJobForWrite` call must happen before initializing output format and output + // committer, since their initialization involve the job configuration, which can be potentially + // decorated in `prepareJobForWrite`. outputWriterFactory = relation.prepareJobForWrite(job) + taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) outputFormatClass = job.getOutputFormatClass outputCommitter = newOutputCommitter(taskAttemptContext) @@ -327,30 +360,47 @@ private[sql] abstract class BaseWriterContainer( } private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = { - val committerClass = context.getConfiguration.getClass( - SQLConf.OUTPUT_COMMITTER_CLASS, null, classOf[OutputCommitter]) - - Option(committerClass).map { clazz => - // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat - // has an associated output committer. To override this output committer, - // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS. - // If a data source needs to override the output committer, it needs to set the - // output committer in prepareForWrite method. - if (classOf[MapReduceFileOutputCommitter].isAssignableFrom(clazz)) { - // The specified output committer is a FileOutputCommitter. - // So, we will use the FileOutputCommitter-specified constructor. - val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) - ctor.newInstance(new Path(outputPath), context) - } else { - // The specified output committer is just a OutputCommitter. - // So, we will use the no-argument constructor. - val ctor = clazz.getDeclaredConstructor() - ctor.newInstance() + val defaultOutputCommitter = outputFormatClass.newInstance().getOutputCommitter(context) + + if (isAppend) { + // If we are appending data to an existing dir, we will only use the output committer + // associated with the file output format since it is not safe to use a custom + // committer for appending. For example, in S3, direct parquet output committer may + // leave partial data in the destination dir when the the appending job fails. + logInfo( + s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName} " + + "for appending.") + defaultOutputCommitter + } else { + val committerClass = context.getConfiguration.getClass( + SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) + + Option(committerClass).map { clazz => + logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}") + + // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat + // has an associated output committer. To override this output committer, + // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS. + // If a data source needs to override the output committer, it needs to set the + // output committer in prepareForWrite method. + if (classOf[MapReduceFileOutputCommitter].isAssignableFrom(clazz)) { + // The specified output committer is a FileOutputCommitter. + // So, we will use the FileOutputCommitter-specified constructor. + val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) + ctor.newInstance(new Path(outputPath), context) + } else { + // The specified output committer is just a OutputCommitter. + // So, we will use the no-argument constructor. + val ctor = clazz.getDeclaredConstructor() + ctor.newInstance() + } + }.getOrElse { + // If output committer class is not set, we will use the one associated with the + // file output format. + logInfo( + s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName}") + defaultOutputCommitter } - }.getOrElse { - // If output committer class is not set, we will use the one associated with the - // file output format. - outputFormatClass.newInstance().getOutputCommitter(context) } } @@ -400,8 +450,9 @@ private[sql] abstract class BaseWriterContainer( private[sql] class DefaultWriterContainer( @transient relation: HadoopFsRelation, - @transient job: Job) - extends BaseWriterContainer(relation, job) { + @transient job: Job, + isAppend: Boolean) + extends BaseWriterContainer(relation, job, isAppend) { @transient private var writer: OutputWriter = _ @@ -417,15 +468,16 @@ private[sql] class DefaultWriterContainer( assert(writer != null, "OutputWriter instance should have been initialized") writer.close() super.commitTask() - } catch { - case cause: Throwable => - super.abortTask() - throw new RuntimeException("Failed to commit task", cause) + } catch { case cause: Throwable => + // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and will + // cause `abortTask()` to be invoked. + throw new RuntimeException("Failed to commit task", cause) } } override def abortTask(): Unit = { try { + // It's possible that the task fails before `writer` gets initialized if (writer != null) { writer.close() } @@ -439,8 +491,9 @@ private[sql] class DynamicPartitionWriterContainer( @transient relation: HadoopFsRelation, @transient job: Job, partitionColumns: Array[String], - defaultPartitionName: String) - extends BaseWriterContainer(relation, job) { + defaultPartitionName: String, + isAppend: Boolean) + extends BaseWriterContainer(relation, job, isAppend) { // All output writers are created on executor side. @transient protected var outputWriters: mutable.Map[String, OutputWriter] = _ @@ -469,21 +522,25 @@ private[sql] class DynamicPartitionWriterContainer( }) } - override def commitTask(): Unit = { - try { + private def clearOutputWriters(): Unit = { + if (outputWriters.nonEmpty) { outputWriters.values.foreach(_.close()) outputWriters.clear() + } + } + + override def commitTask(): Unit = { + try { + clearOutputWriters() super.commitTask() } catch { case cause: Throwable => - super.abortTask() throw new RuntimeException("Failed to commit task", cause) } } override def abortTask(): Unit = { try { - outputWriters.values.foreach(_.close()) - outputWriters.clear() + clearOutputWriters() } finally { super.abortTask() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 20afd60cb7767..b7095c8ead797 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -26,7 +26,7 @@ import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InternalRow} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.types._ @@ -166,10 +166,14 @@ private[sql] class DDLParser( } ) - protected lazy val optionName: Parser[String] = "[_a-zA-Z][a-zA-Z0-9]*".r ^^ { + protected lazy val optionPart: Parser[String] = "[_a-zA-Z][_a-zA-Z0-9]*".r ^^ { case name => name } + protected lazy val optionName: Parser[String] = repsep(optionPart, ".") ^^ { + case parts => parts.mkString(".") + } + protected lazy val pair: Parser[(String, String)] = optionName ~ stringLit ^^ { case k ~ v => (k, v) } @@ -404,7 +408,7 @@ private[sql] case class CreateTempTableUsing( provider: String, options: Map[String, String]) extends RunnableCommand { - def run(sqlContext: SQLContext): Seq[Row] = { + def run(sqlContext: SQLContext): Seq[InternalRow] = { val resolved = ResolvedDataSource( sqlContext, userSpecifiedSchema, Array.empty[String], provider, options) sqlContext.registerDataFrameAsTable( @@ -421,7 +425,7 @@ private[sql] case class CreateTempTableUsingAsSelect( options: Map[String, String], query: LogicalPlan) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { + override def run(sqlContext: SQLContext): Seq[InternalRow] = { val df = DataFrame(sqlContext, query) val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df) sqlContext.registerDataFrameAsTable( @@ -434,7 +438,7 @@ private[sql] case class CreateTempTableUsingAsSelect( private[sql] case class RefreshTable(databaseName: String, tableName: String) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { + override def run(sqlContext: SQLContext): Seq[InternalRow] = { // Refresh the given table's metadata first. sqlContext.catalog.refreshTable(databaseName, tableName) @@ -453,7 +457,7 @@ private[sql] case class RefreshTable(databaseName: String, tableName: String) sqlContext.cacheManager.cacheQuery(df, Some(tableName)) } - Seq.empty[Row] + Seq.empty[InternalRow] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 25887ba9a15b0..7005c7079af91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -27,11 +27,12 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.SerializableWritable -import org.apache.spark.sql.{Row, _} +import org.apache.spark.sql.execution.RDDConversions +import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration /** * ::DeveloperApi:: @@ -195,6 +196,8 @@ abstract class BaseRelation { * java.lang.String -> UTF8String * java.lang.Decimal -> Decimal * + * If `needConversion` is `false`, buildScan() should return an [[RDD]] of [[InternalRow]] + * * Note: The internal representation is not stable across releases and thus data sources outside * of Spark SQL should leave this as true. * @@ -443,7 +446,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio val castedValues = partitionSchema.zip(literals).map { case (field, literal) => Cast(literal, field.dataType).eval() } - p.copy(values = Row.fromSeq(castedValues)) + p.copy(values = InternalRow.fromSeq(castedValues)) } PartitionSpec(partitionSchema, castedPartitions) } @@ -491,9 +494,11 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio } private def discoverPartitions(): PartitionSpec = { + val typeInference = sqlContext.conf.partitionColumnTypeInferenceEnabled() // We use leaf dirs containing data files to discover the schema. val leafDirs = fileStatusCache.leafDirToChildrenFiles.keys.toSeq - PartitioningUtils.parsePartitions(leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME) + PartitioningUtils.parsePartitions(leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME, + typeInference) } /** @@ -513,7 +518,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio requiredColumns: Array[String], filters: Array[Filter], inputPaths: Array[String], - broadcastedConf: Broadcast[SerializableWritable[Configuration]]): RDD[Row] = { + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { val inputStatuses = inputPaths.flatMap { input => val path = new Path(input) @@ -571,21 +576,28 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio // Yeah, to workaround serialization... val dataSchema = this.dataSchema val codegenEnabled = this.codegenEnabled + val needConversion = this.needConversion val requiredOutput = requiredColumns.map { col => val field = dataSchema(col) BoundReference(dataSchema.fieldIndex(col), field.dataType, field.nullable) }.toSeq - buildScan(inputFiles).mapPartitions { rows => - val buildProjection = if (codegenEnabled) { + val rdd = buildScan(inputFiles) + val converted = + if (needConversion) { + RDDConversions.rowToRowRdd(rdd, dataSchema.fields.map(_.dataType)) + } else { + rdd.map(_.asInstanceOf[InternalRow]) + } + converted.mapPartitions { rows => + val buildProjection = if (codegenEnabled && requiredOutput.forall(_.isThreadSafe)) { GenerateMutableProjection.generate(requiredOutput, dataSchema.toAttributes) } else { () => new InterpretedMutableProjection(requiredOutput, dataSchema.toAttributes) } - val mutableProjection = buildProjection() - rows.map(mutableProjection) + rows.map(r => mutableProjection(r).asInstanceOf[Row]) } } @@ -636,7 +648,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio requiredColumns: Array[String], filters: Array[Filter], inputFiles: Array[FileStatus], - broadcastedConf: Broadcast[SerializableWritable[Configuration]]): RDD[Row] = { + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { buildScan(requiredColumns, filters, inputFiles) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala index 356a6100d2cf5..9fa394525d65c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -38,7 +38,7 @@ class LocalSQLContext protected[sql] class SQLSession extends super.SQLSession { protected[sql] override lazy val conf: SQLConf = new SQLConf { /** Fewer partitions to speed up testing. */ - override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt + override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 72e60d9aa75cb..eb3e913322062 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ -import org.apache.spark.storage.{RDDBlockId, StorageLevel} +import org.apache.spark.storage.{StorageLevel, RDDBlockId} case class BigData(s: String) @@ -94,7 +94,7 @@ class CachedTableSuite extends QueryTest { } test("too big for memory") { - val data = "*" * 10000 + val data = "*" * 1000 ctx.sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() .registerTempTable("bigData") ctx.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 4f5484f1368d1..88bb743ab0bc9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -185,12 +185,20 @@ class ColumnExpressionSuite extends QueryTest { checkAnswer( nullStrings.toDF.where($"s".isNull), nullStrings.collect().toSeq.filter(r => r.getString(1) eq null)) + + checkAnswer( + ctx.sql("select isnull(null), isnull(1)"), + Row(true, false)) } test("isNotNull") { checkAnswer( nullStrings.toDF.where($"s".isNotNull), nullStrings.collect().toSeq.filter(r => r.getString(1) ne null)) + + checkAnswer( + ctx.sql("select isnotnull(null), isnotnull('a')"), + Row(false, true)) } test("===") { @@ -288,6 +296,22 @@ class ColumnExpressionSuite extends QueryTest { checkAnswer(testData.filter($"a".between($"b", $"c")), expectAnswer) } + test("in") { + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + checkAnswer(df.filter($"a".in(1, 2)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".in(3, 2)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".in(3, 1)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + checkAnswer(df.filter($"b".in("y", "x")), + df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x")) + checkAnswer(df.filter($"b".in("z", "x")), + df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x")) + checkAnswer(df.filter($"b".in("z", "y")), + df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) + } + val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize( Row(false, false) :: Row(false, true) :: @@ -361,23 +385,6 @@ class ColumnExpressionSuite extends QueryTest { ) } - test("abs") { - checkAnswer( - testData.select(abs('key)).orderBy('key.asc), - (1 to 100).map(n => Row(n)) - ) - - checkAnswer( - negativeData.select(abs('key)).orderBy('key.desc), - (1 to 100).map(n => Row(n)) - ) - - checkAnswer( - testData.select(abs(lit(null))), - (1 to 100).map(_ => Row(null)) - ) - } - test("upper") { checkAnswer( lowerCaseData.select(upper('l)), @@ -393,6 +400,10 @@ class ColumnExpressionSuite extends QueryTest { testData.select(upper(lit(null))), (1 to 100).map(n => Row(null)) ) + + checkAnswer( + ctx.sql("SELECT upper('aB'), ucase('cDe')"), + Row("AB", "CDE")) } test("lower") { @@ -410,6 +421,10 @@ class ColumnExpressionSuite extends QueryTest { testData.select(lower(lit(null))), (1 to 100).map(n => Row(null)) ) + + checkAnswer( + ctx.sql("SELECT lower('aB'), lcase('cDe')"), + Row("ab", "cde")) } test("monotonicallyIncreasingId") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 790b405c72697..b26d3ab253a1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -68,12 +68,12 @@ class DataFrameAggregateSuite extends QueryTest { Seq(Row(1, 3), Row(2, 3), Row(3, 3)) ) - ctx.conf.setConf("spark.sql.retainGroupColumns", "false") + ctx.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false) checkAnswer( testData2.groupBy("a").agg(sum($"b")), Seq(Row(3), Row(3), Row(3)) ) - ctx.conf.setConf("spark.sql.retainGroupColumns", "true") + ctx.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, true) } test("agg without groups") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala new file mode 100644 index 0000000000000..a4719a38de1d4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.sql.{Date, Timestamp} + +class DataFrameDateTimeSuite extends QueryTest { + + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + test("timestamp comparison with date strings") { + val df = Seq( + (1, Timestamp.valueOf("2015-01-01 00:00:00")), + (2, Timestamp.valueOf("2014-01-01 00:00:00"))).toDF("i", "t") + + checkAnswer( + df.select("t").filter($"t" <= "2014-06-01"), + Row(Timestamp.valueOf("2014-01-01 00:00:00")) :: Nil) + + + checkAnswer( + df.select("t").filter($"t" >= "2014-06-01"), + Row(Timestamp.valueOf("2015-01-01 00:00:00")) :: Nil) + } + + test("date comparison with date strings") { + val df = Seq( + (1, Date.valueOf("2015-01-01")), + (2, Date.valueOf("2014-01-01"))).toDF("i", "t") + + checkAnswer( + df.select("t").filter($"t" <= "2014-06-01"), + Row(Date.valueOf("2014-01-01")) :: Nil) + + + checkAnswer( + df.select("t").filter($"t" >= "2015"), + Row(Date.valueOf("2015-01-01")) :: Nil) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 53c2befb73702..abfd47c811ed9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -85,9 +85,109 @@ class DataFrameFunctionsSuite extends QueryTest { } } + test("constant functions") { + checkAnswer( + testData2.select(e()).limit(1), + Row(scala.math.E) + ) + checkAnswer( + testData2.select(pi()).limit(1), + Row(scala.math.Pi) + ) + checkAnswer( + ctx.sql("SELECT E()"), + Row(scala.math.E) + ) + checkAnswer( + ctx.sql("SELECT PI()"), + Row(scala.math.Pi) + ) + } + test("bitwiseNOT") { checkAnswer( testData2.select(bitwiseNOT($"a")), testData2.collect().toSeq.map(r => Row(~r.getInt(0)))) } + + test("bin") { + val df = Seq[(Integer, Integer)]((12, null)).toDF("a", "b") + checkAnswer( + df.select(bin("a"), bin("b")), + Row("1100", null)) + checkAnswer( + df.selectExpr("bin(a)", "bin(b)"), + Row("1100", null)) + } + + test("if function") { + val df = Seq((1, 2)).toDF("a", "b") + checkAnswer( + df.selectExpr("if(a = 1, 'one', 'not_one')", "if(b = 1, 'one', 'not_one')"), + Row("one", "not_one")) + } + + test("nvl function") { + checkAnswer( + ctx.sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"), + Row("x", "y", null)) + } + + test("misc md5 function") { + val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") + checkAnswer( + df.select(md5($"a"), md5("b")), + Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c")) + + checkAnswer( + df.selectExpr("md5(a)", "md5(b)"), + Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c")) + } + + test("misc sha1 function") { + val df = Seq(("ABC", "ABC".getBytes)).toDF("a", "b") + checkAnswer( + df.select(sha1($"a"), sha1("b")), + Row("3c01bdbb26f358bab27f267924aa2c9a03fcfdb8", "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")) + + val dfEmpty = Seq(("", "".getBytes)).toDF("a", "b") + checkAnswer( + dfEmpty.selectExpr("sha1(a)", "sha1(b)"), + Row("da39a3ee5e6b4b0d3255bfef95601890afd80709", "da39a3ee5e6b4b0d3255bfef95601890afd80709")) + } + + test("misc sha2 function") { + val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") + checkAnswer( + df.select(sha2($"a", 256), sha2("b", 256)), + Row("b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78", + "7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89")) + + checkAnswer( + df.selectExpr("sha2(a, 256)", "sha2(b, 256)"), + Row("b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78", + "7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89")) + + intercept[IllegalArgumentException] { + df.select(sha2($"a", 1024)) + } + } + + test("string length function") { + checkAnswer( + nullStrings.select(strlen($"s"), strlen("s")), + nullStrings.collect().toSeq.map { r => + val v = r.getString(1) + val l = if (v == null) null else v.length + Row(l, l) + }) + + checkAnswer( + nullStrings.selectExpr("length(s)"), + nullStrings.collect().toSeq.map { r => + val v = r.getString(1) + val l = if (v == null) null else v.length + Row(l) + }) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 051d13e9a544f..e1c6c706242d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions._ class DataFrameJoinSuite extends QueryTest { @@ -34,6 +35,15 @@ class DataFrameJoinSuite extends QueryTest { Row(1, "1", "2") :: Row(2, "2", "3") :: Row(3, "3", "4") :: Nil) } + test("join - join using multiple columns") { + val df = Seq(1, 2, 3).map(i => (i, i + 1, i.toString)).toDF("int", "int2", "str") + val df2 = Seq(1, 2, 3).map(i => (i, i + 1, (i + 1).toString)).toDF("int", "int2", "str") + + checkAnswer( + df.join(df2, Seq("int", "int2")), + Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil) + } + test("join - join using self join") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") @@ -84,4 +94,20 @@ class DataFrameJoinSuite extends QueryTest { left.join(right, left("key") === right("key")), Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil) } + + test("broadcast join hint") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") + + // equijoin - should be converted into broadcast join + val plan1 = df1.join(broadcast(df2), "key").queryExecution.executedPlan + assert(plan1.collect { case p: BroadcastHashJoin => p }.size === 1) + + // no join key -- should not be a broadcast join + val plan2 = df1.join(broadcast(df2)).queryExecution.executedPlan + assert(plan2.collect { case p: BroadcastHashJoin => p }.size === 0) + + // planner should not crash without a join + broadcast(df1).queryExecution.executedPlan + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 0d3ff899dad72..765094da6bda7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.util.Random + import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite @@ -65,22 +67,22 @@ class DataFrameStatSuite extends SparkFunSuite { } test("crosstab") { - val df = Seq((0, 0), (2, 1), (1, 0), (2, 0), (0, 0), (2, 0)).toDF("a", "b") + val rng = new Random() + val data = Seq.tabulate(25)(i => (rng.nextInt(5), rng.nextInt(10))) + val df = data.toDF("a", "b") val crosstab = df.stat.crosstab("a", "b") val columnNames = crosstab.schema.fieldNames assert(columnNames(0) === "a_b") - assert(columnNames(1) === "0") - assert(columnNames(2) === "1") - val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0)) - assert(rows(0).get(0).toString === "0") - assert(rows(0).getLong(1) === 2L) - assert(rows(0).get(2) === 0L) - assert(rows(1).get(0).toString === "1") - assert(rows(1).getLong(1) === 1L) - assert(rows(1).get(2) === 0L) - assert(rows(2).get(0).toString === "2") - assert(rows(2).getLong(1) === 2L) - assert(rows(2).getLong(2) === 1L) + // reduce by key + val expected = data.map(t => (t, 1)).groupBy(_._1).mapValues(_.length) + val rows = crosstab.collect() + rows.foreach { row => + val i = row.getString(0).toInt + for (col <- 1 until columnNames.length) { + val j = columnNames(col).toInt + assert(row.getLong(col) === expected.getOrElse((i, j), 0).toLong) + } + } } test("Frequent Items") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index bb8621abe64ad..50d324c0686fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -33,7 +33,7 @@ class DataFrameSuite extends QueryTest { test("analysis error should be eagerly reported") { val oldSetting = ctx.conf.dataFrameEagerAnalysis // Eager analysis. - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true) intercept[Exception] { testData.select('nonExistentName) } intercept[Exception] { @@ -47,11 +47,11 @@ class DataFrameSuite extends QueryTest { } // No more eager analysis once the flag is turned off - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false") + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false) testData.select('nonExistentName) // Set the flag back to original value before this test. - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting) } test("dataframe toString") { @@ -70,7 +70,7 @@ class DataFrameSuite extends QueryTest { test("invalid plan toString, debug mode") { val oldSetting = ctx.conf.dataFrameEagerAnalysis - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true) // Turn on debug mode so we can see invalid query plans. import org.apache.spark.sql.execution.debug._ @@ -83,7 +83,7 @@ class DataFrameSuite extends QueryTest { badPlan.toString) // Set the flag back to original value before this test. - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting) } test("access complex data") { @@ -134,6 +134,14 @@ class DataFrameSuite extends QueryTest { ) } + test("explode alias and star") { + val df = Seq((Array("a"), 1)).toDF("a", "b") + + checkAnswer( + df.select(explode($"a").as("a"), $"*"), + Row("a", Seq("a"), 1) :: Nil) + } + test("selectExpr") { checkAnswer( testData.selectExpr("abs(key)", "value"), @@ -152,6 +160,12 @@ class DataFrameSuite extends QueryTest { testData.collect().filter(_.getInt(0) > 90).toSeq) } + test("filterExpr using where") { + checkAnswer( + testData.where("key > 50"), + testData.collect().filter(_.getInt(0) > 50).toSeq) + } + test("repartition") { checkAnswer( testData.select('key).repartition(10).select('key), @@ -293,7 +307,7 @@ class DataFrameSuite extends QueryTest { ) } - test("call udf in SQLContext") { + test("deprecated callUdf in SQLContext") { val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") val sqlctx = df.sqlContext sqlctx.udf.register("simpleUdf", (v: Int) => v * v) @@ -302,6 +316,15 @@ class DataFrameSuite extends QueryTest { Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil) } + test("callUDF in SQLContext") { + val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + val sqlctx = df.sqlContext + sqlctx.udf.register("simpleUDF", (v: Int) => v * v) + checkAnswer( + df.select($"id", callUDF("simpleUDF", $"value")), + Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil) + } + test("withColumn") { val df = testData.toDF().withColumn("newCol", col("key") + 1) checkAnswer( @@ -469,12 +492,84 @@ class DataFrameSuite extends QueryTest { testData.select($"*").show(1000) } + test("showString: truncate = [true, false]") { + val longString = Array.fill(21)("1").mkString + val df = ctx.sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = """+---------------------+ + ||_1 | + |+---------------------+ + ||1 | + ||111111111111111111111| + |+---------------------+ + |""".stripMargin + assert(df.showString(10, false) === expectedAnswerForFalse) + val expectedAnswerForTrue = """+--------------------+ + || _1| + |+--------------------+ + || 1| + ||11111111111111111...| + |+--------------------+ + |""".stripMargin + assert(df.showString(10, true) === expectedAnswerForTrue) + } + + test("showString(negative)") { + val expectedAnswer = """+---+-----+ + ||key|value| + |+---+-----+ + |+---+-----+ + |only showing top 0 rows + |""".stripMargin + assert(testData.select($"*").showString(-1) === expectedAnswer) + } + + test("showString(0)") { + val expectedAnswer = """+---+-----+ + ||key|value| + |+---+-----+ + |+---+-----+ + |only showing top 0 rows + |""".stripMargin + assert(testData.select($"*").showString(0) === expectedAnswer) + } + + test("showString: array") { + val df = Seq( + (Array(1, 2, 3), Array(1, 2, 3)), + (Array(2, 3, 4), Array(2, 3, 4)) + ).toDF() + val expectedAnswer = """+---------+---------+ + || _1| _2| + |+---------+---------+ + ||[1, 2, 3]|[1, 2, 3]| + ||[2, 3, 4]|[2, 3, 4]| + |+---------+---------+ + |""".stripMargin + assert(df.showString(10) === expectedAnswer) + } + + test("showString: minimum column width") { + val df = Seq( + (1, 1), + (2, 2) + ).toDF() + val expectedAnswer = """+---+---+ + || _1| _2| + |+---+---+ + || 1| 1| + || 2| 2| + |+---+---+ + |""".stripMargin + assert(df.showString(10) === expectedAnswer) + } + test("SPARK-7319 showString") { val expectedAnswer = """+---+-----+ ||key|value| |+---+-----+ || 1| 1| |+---+-----+ + |only showing top 1 row |""".stripMargin assert(testData.select($"*").showString(1) === expectedAnswer) } @@ -497,13 +592,13 @@ class DataFrameSuite extends QueryTest { test("SPARK-6899") { val originalValue = ctx.conf.codegenEnabled - ctx.setConf(SQLConf.CODEGEN_ENABLED, "true") + ctx.setConf(SQLConf.CODEGEN_ENABLED, true) try{ checkAnswer( decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2.0))) } finally { - ctx.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + ctx.setConf(SQLConf.CODEGEN_ENABLED, originalValue) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index ffd26c4f5a7c2..20390a5544304 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -95,14 +95,14 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } try { - ctx.conf.setConf("spark.sql.planner.sortMergeJoin", "true") + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) Seq( ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } finally { - ctx.conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } } @@ -118,7 +118,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } try { - ctx.conf.setConf("spark.sql.planner.sortMergeJoin", "true") + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", @@ -127,7 +127,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } finally { - ctx.conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } ctx.sql("UNCACHE TABLE testData") @@ -416,7 +416,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ctx.sql("CACHE TABLE testData") val tmp = ctx.conf.autoBroadcastJoinThreshold - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000") + ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=1000000000") Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[BroadcastLeftSemiJoinHash]) @@ -424,7 +424,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case (query, joinClass) => assertJoin(query, joinClass) } - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") + ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) @@ -432,7 +432,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case (query, joinClass) => assertJoin(query, joinClass) } - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString) + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp) ctx.sql("UNCACHE TABLE testData") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 0a38af2b4c889..d6331aa4ff09e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ +import org.apache.spark.sql.functions.{log => logarithm} private object MathExpressionsTestData { @@ -151,20 +152,31 @@ class MathExpressionsSuite extends QueryTest { testOneToOneMathFunction(tanh, math.tanh) } - test("toDeg") { + test("toDegrees") { testOneToOneMathFunction(toDegrees, math.toDegrees) + checkAnswer( + ctx.sql("SELECT degrees(0), degrees(1), degrees(1.5)"), + Seq((1, 2)).toDF().select(toDegrees(lit(0)), toDegrees(lit(1)), toDegrees(lit(1.5))) + ) } - test("toRad") { + test("toRadians") { testOneToOneMathFunction(toRadians, math.toRadians) + checkAnswer( + ctx.sql("SELECT radians(0), radians(1), radians(1.5)"), + Seq((1, 2)).toDF().select(toRadians(lit(0)), toRadians(lit(1)), toRadians(lit(1.5))) + ) } test("cbrt") { testOneToOneMathFunction(cbrt, math.cbrt) } - test("ceil") { + test("ceil and ceiling") { testOneToOneMathFunction(ceil, math.ceil) + checkAnswer( + ctx.sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"), + Row(0.0, 1.0, 2.0)) } test("floor") { @@ -183,12 +195,34 @@ class MathExpressionsSuite extends QueryTest { testOneToOneMathFunction(expm1, math.expm1) } - test("signum") { + test("signum / sign") { testOneToOneMathFunction[Double](signum, math.signum) + + checkAnswer( + ctx.sql("SELECT sign(10), signum(-11)"), + Row(1, -1)) } - test("pow") { + test("pow / power") { testTwoToOneMathFunction(pow, pow, math.pow) + + checkAnswer( + ctx.sql("SELECT pow(1, 2), power(2, 1)"), + Seq((1, 2)).toDF().select(pow(lit(1), lit(2)), pow(lit(2), lit(1))) + ) + } + + test("hex") { + val data = Seq((28, -28, 100800200404L, "hello")).toDF("a", "b", "c", "d") + checkAnswer(data.select(hex('a)), Seq(Row("1C"))) + checkAnswer(data.select(hex('b)), Seq(Row("FFFFFFFFFFFFFFE4"))) + checkAnswer(data.select(hex('c)), Seq(Row("177828FED4"))) + checkAnswer(data.select(hex('d)), Seq(Row("68656C6C6F"))) + checkAnswer(data.selectExpr("hex(a)"), Seq(Row("1C"))) + checkAnswer(data.selectExpr("hex(b)"), Seq(Row("FFFFFFFFFFFFFFE4"))) + checkAnswer(data.selectExpr("hex(c)"), Seq(Row("177828FED4"))) + checkAnswer(data.selectExpr("hex(d)"), Seq(Row("68656C6C6F"))) + checkAnswer(data.selectExpr("hex(cast(d as binary))"), Seq(Row("68656C6C6F"))) } test("hypot") { @@ -199,8 +233,12 @@ class MathExpressionsSuite extends QueryTest { testTwoToOneMathFunction(atan2, atan2, math.atan2) } - test("log") { + test("log / ln") { testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log) + checkAnswer( + ctx.sql("SELECT ln(0), ln(1), ln(1.5)"), + Seq((1, 2)).toDF().select(logarithm(lit(0)), logarithm(lit(1)), logarithm(lit(1.5))) + ) } test("log10") { @@ -211,4 +249,60 @@ class MathExpressionsSuite extends QueryTest { testOneToOneNonNegativeMathFunction(log1p, math.log1p) } + test("binary log") { + val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b") + checkAnswer( + df.select(org.apache.spark.sql.functions.log("a"), + org.apache.spark.sql.functions.log(2.0, "a"), + org.apache.spark.sql.functions.log("b")), + Row(math.log(123), math.log(123) / math.log(2), null)) + + checkAnswer( + df.selectExpr("log(a)", "log(2.0, a)", "log(b)"), + Row(math.log(123), math.log(123) / math.log(2), null)) + } + + test("abs") { + val input = + Seq[(java.lang.Double, java.lang.Double)]((null, null), (0.0, 0.0), (1.5, 1.5), (-2.5, 2.5)) + checkAnswer( + input.toDF("key", "value").select(abs($"key").alias("a")).sort("a"), + input.map(pair => Row(pair._2))) + + checkAnswer( + input.toDF("key", "value").selectExpr("abs(key) a").sort("a"), + input.map(pair => Row(pair._2))) + } + + test("log2") { + val df = Seq((1, 2)).toDF("a", "b") + checkAnswer( + df.select(log2("b") + log2("a")), + Row(1)) + + checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null)) + } + + test("sqrt") { + val df = Seq((1, 4)).toDF("a", "b") + checkAnswer( + df.select(sqrt("a"), sqrt("b")), + Row(1.0, 2.0)) + + checkAnswer(ctx.sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null)) + checkAnswer(df.selectExpr("sqrt(a)", "sqrt(b)", "sqrt(null)"), Row(1.0, 2.0, null)) + } + + test("negative") { + checkAnswer( + ctx.sql("SELECT negative(1), negative(0), negative(-1)"), + Row(-1, 0, 1)) + } + + test("positive") { + val df = Seq((1, -1, "abc")).toDF("a", "b", "c") + checkAnswer(df.selectExpr("positive(a)"), Row(1)) + checkAnswer(df.selectExpr("positive(b)"), Row(-1)) + checkAnswer(df.selectExpr("positive(c)"), Row("abc")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala new file mode 100644 index 0000000000000..2e33777f14adc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SQLConf._ + +class SQLConfEntrySuite extends SparkFunSuite { + + val conf = new SQLConf + + test("intConf") { + val key = "spark.sql.SQLConfEntrySuite.int" + val confEntry = SQLConfEntry.intConf(key) + assert(conf.getConf(confEntry, 5) === 5) + + conf.setConf(confEntry, 10) + assert(conf.getConf(confEntry, 5) === 10) + + conf.setConfString(key, "20") + assert(conf.getConfString(key, "5") === "20") + assert(conf.getConfString(key) === "20") + assert(conf.getConf(confEntry, 5) === 20) + + val e = intercept[IllegalArgumentException] { + conf.setConfString(key, "abc") + } + assert(e.getMessage === s"$key should be int, but was abc") + } + + test("longConf") { + val key = "spark.sql.SQLConfEntrySuite.long" + val confEntry = SQLConfEntry.longConf(key) + assert(conf.getConf(confEntry, 5L) === 5L) + + conf.setConf(confEntry, 10L) + assert(conf.getConf(confEntry, 5L) === 10L) + + conf.setConfString(key, "20") + assert(conf.getConfString(key, "5") === "20") + assert(conf.getConfString(key) === "20") + assert(conf.getConf(confEntry, 5L) === 20L) + + val e = intercept[IllegalArgumentException] { + conf.setConfString(key, "abc") + } + assert(e.getMessage === s"$key should be long, but was abc") + } + + test("booleanConf") { + val key = "spark.sql.SQLConfEntrySuite.boolean" + val confEntry = SQLConfEntry.booleanConf(key) + assert(conf.getConf(confEntry, false) === false) + + conf.setConf(confEntry, true) + assert(conf.getConf(confEntry, false) === true) + + conf.setConfString(key, "true") + assert(conf.getConfString(key, "false") === "true") + assert(conf.getConfString(key) === "true") + assert(conf.getConf(confEntry, false) === true) + + val e = intercept[IllegalArgumentException] { + conf.setConfString(key, "abc") + } + assert(e.getMessage === s"$key should be boolean, but was abc") + } + + test("doubleConf") { + val key = "spark.sql.SQLConfEntrySuite.double" + val confEntry = SQLConfEntry.doubleConf(key) + assert(conf.getConf(confEntry, 5.0) === 5.0) + + conf.setConf(confEntry, 10.0) + assert(conf.getConf(confEntry, 5.0) === 10.0) + + conf.setConfString(key, "20.0") + assert(conf.getConfString(key, "5.0") === "20.0") + assert(conf.getConfString(key) === "20.0") + assert(conf.getConf(confEntry, 5.0) === 20.0) + + val e = intercept[IllegalArgumentException] { + conf.setConfString(key, "abc") + } + assert(e.getMessage === s"$key should be double, but was abc") + } + + test("stringConf") { + val key = "spark.sql.SQLConfEntrySuite.string" + val confEntry = SQLConfEntry.stringConf(key) + assert(conf.getConf(confEntry, "abc") === "abc") + + conf.setConf(confEntry, "abcd") + assert(conf.getConf(confEntry, "abc") === "abcd") + + conf.setConfString(key, "abcde") + assert(conf.getConfString(key, "abc") === "abcde") + assert(conf.getConfString(key) === "abcde") + assert(conf.getConf(confEntry, "abc") === "abcde") + } + + test("enumConf") { + val key = "spark.sql.SQLConfEntrySuite.enum" + val confEntry = SQLConfEntry.enumConf(key, v => v, Set("a", "b", "c"), defaultValue = Some("a")) + assert(conf.getConf(confEntry) === "a") + + conf.setConf(confEntry, "b") + assert(conf.getConf(confEntry) === "b") + + conf.setConfString(key, "c") + assert(conf.getConfString(key, "a") === "c") + assert(conf.getConfString(key) === "c") + assert(conf.getConf(confEntry) === "c") + + val e = intercept[IllegalArgumentException] { + conf.setConfString(key, "d") + } + assert(e.getMessage === s"The value of $key should be one of a, b, c, but was d") + } + + test("stringSeqConf") { + val key = "spark.sql.SQLConfEntrySuite.stringSeq" + val confEntry = SQLConfEntry.stringSeqConf("spark.sql.SQLConfEntrySuite.stringSeq", + defaultValue = Some(Nil)) + assert(conf.getConf(confEntry, Seq("a", "b", "c")) === Seq("a", "b", "c")) + + conf.setConf(confEntry, Seq("a", "b", "c", "d")) + assert(conf.getConf(confEntry, Seq("a", "b", "c")) === Seq("a", "b", "c", "d")) + + conf.setConfString(key, "a,b,c,d,e") + assert(conf.getConfString(key, "a,b,c") === "a,b,c,d,e") + assert(conf.getConfString(key) === "a,b,c,d,e") + assert(conf.getConf(confEntry, Seq("a", "b", "c")) === Seq("a", "b", "c", "d", "e")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 76d0dd1744a41..75791e9d53c20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -75,6 +75,14 @@ class SQLConfSuite extends QueryTest { test("deprecated property") { ctx.conf.clear() ctx.sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") - assert(ctx.getConf(SQLConf.SHUFFLE_PARTITIONS) === "10") + assert(ctx.conf.numShufflePartitions === 10) + } + + test("invalid conf value") { + ctx.conf.clear() + val e = intercept[IllegalArgumentException] { + ctx.sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") + } + assert(e.getMessage === s"${SQLConf.CASE_SENSITIVE.key} should be boolean, but was 10") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5babc4332cc77..82dc0e9ce5132 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterAll +import java.sql.Timestamp + import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.GeneratedAggregate @@ -38,6 +40,23 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { import sqlContext.implicits._ import sqlContext.sql + test("having clause") { + Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("hav") + checkAnswer( + sql("SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2"), + Row("one", 6) :: Row("three", 3) :: Nil) + } + + test("SPARK-8010: promote numeric to string") { + val df = Seq((1, 1)).toDF("key", "value") + df.registerTempTable("src") + val queryCaseWhen = sql("select case when true then 1.0 else '1' end from src ") + val queryCoalesce = sql("select coalesce(null, 1, '1') from src ") + + checkAnswer(queryCaseWhen, Row("1.0") :: Nil) + checkAnswer(queryCoalesce, Row("1") :: Nil) + } + test("SPARK-6743: no columns from cache") { Seq( (83, 0, 38), @@ -116,6 +135,31 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { ) } + test("SPARK-7158 collect and take return different results") { + import java.util.UUID + + val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index") + // we except the id is materialized once + val idUDF = udf(() => UUID.randomUUID().toString) + + val dfWithId = df.withColumn("id", idUDF()) + // Make a new DataFrame (actually the same reference to the old one) + val cached = dfWithId.cache() + // Trigger the cache + val d0 = dfWithId.collect() + val d1 = cached.collect() + val d2 = cached.collect() + + // Since the ID is only materialized once, then all of the records + // should come from the cache, not by re-computing. Otherwise, the ID + // will be different + assert(d0.map(_(0)) === d2.map(_(0))) + assert(d0.map(_(1)) === d2.map(_(1))) + + assert(d1.map(_(0)) === d2.map(_(0))) + assert(d1.map(_(1)) === d2.map(_(1))) + } + test("grouping on nested fields") { sqlContext.read.json(sqlContext.sparkContext.parallelize( """{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) @@ -145,21 +189,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Seq(Row("1"), Row("2"))) } - test("SPARK-3176 Added Parser of SQL ABS()") { - checkAnswer( - sql("SELECT ABS(-1.3)"), - Row(1.3)) - checkAnswer( - sql("SELECT ABS(0.0)"), - Row(0.0)) - checkAnswer( - sql("SELECT ABS(2.5)"), - Row(2.5)) - } - test("aggregation with codegen") { val originalValue = sqlContext.conf.codegenEnabled - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true") + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) // Prepare a table that we can group some rows. sqlContext.table("testData") .unionAll(sqlContext.table("testData")) @@ -256,7 +288,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(0, null, 0) :: Nil) } finally { sqlContext.dropTempTable("testData3x") - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue) } } @@ -314,6 +346,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-3173 Timestamp support in the parser") { + (0 to 3).map(i => Tuple1(new Timestamp(i))).toDF("time").registerTempTable("timestamps") + checkAnswer(sql( "SELECT time FROM timestamps WHERE time='1969-12-31 16:00:00.0'"), Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00"))) @@ -449,41 +483,41 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("sorting") { val before = sqlContext.conf.externalSortEnabled - sqlContext.setConf(SQLConf.EXTERNAL_SORT, "false") + sqlContext.setConf(SQLConf.EXTERNAL_SORT, false) sortTest() - sqlContext.setConf(SQLConf.EXTERNAL_SORT, before.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, before) } test("external sorting") { val before = sqlContext.conf.externalSortEnabled - sqlContext.setConf(SQLConf.EXTERNAL_SORT, "true") + sqlContext.setConf(SQLConf.EXTERNAL_SORT, true) sortTest() - sqlContext.setConf(SQLConf.EXTERNAL_SORT, before.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, before) } test("SPARK-6927 sorting with codegen on") { val externalbefore = sqlContext.conf.externalSortEnabled val codegenbefore = sqlContext.conf.codegenEnabled - sqlContext.setConf(SQLConf.EXTERNAL_SORT, "false") - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true") + sqlContext.setConf(SQLConf.EXTERNAL_SORT, false) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) try{ sortTest() } finally { - sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore) } } test("SPARK-6927 external sorting with codegen on") { val externalbefore = sqlContext.conf.externalSortEnabled val codegenbefore = sqlContext.conf.codegenEnabled - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true") - sqlContext.setConf(SQLConf.EXTERNAL_SORT, "true") + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, true) try { sortTest() } finally { - sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore) } } @@ -664,7 +698,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } - ignore("cartesian product join") { + test("cartesian product join") { checkAnswer( testData3.join(testData3), Row(1, null, 1, null) :: @@ -877,25 +911,25 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { sql(s"SET $testKey=$testVal") checkAnswer( sql("SET"), - Row(s"$testKey=$testVal") + Row(testKey, testVal) ) sql(s"SET ${testKey + testKey}=${testVal + testVal}") checkAnswer( sql("set"), Seq( - Row(s"$testKey=$testVal"), - Row(s"${testKey + testKey}=${testVal + testVal}")) + Row(testKey, testVal), + Row(testKey + testKey, testVal + testVal)) ) // "set key" checkAnswer( sql(s"SET $testKey"), - Row(s"$testKey=$testVal") + Row(testKey, testVal) ) checkAnswer( sql(s"SET $nonexistentKey"), - Row(s"$nonexistentKey=") + Row(nonexistentKey, "") ) sqlContext.conf.clear() } @@ -1309,12 +1343,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-4699 case sensitivity SQL query") { - sqlContext.setConf(SQLConf.CASE_SENSITIVE, "false") + sqlContext.setConf(SQLConf.CASE_SENSITIVE, false) val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("testTable1") checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1")) - sqlContext.setConf(SQLConf.CASE_SENSITIVE, "true") + sqlContext.setConf(SQLConf.CASE_SENSITIVE, true) } test("SPARK-6145: ORDER BY test for nested fields") { @@ -1332,9 +1366,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("SPARK-6145: special cases") { sqlContext.read.json(sqlContext.sparkContext.makeRDD( - """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t") - checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1)) - checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1)) + """{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t") + checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1)) + checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1)) } test("SPARK-6898: complete support for special chars in column names") { @@ -1345,6 +1379,51 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) } + test("SPARK-6583 order by aggregated function") { + Seq("1" -> 3, "1" -> 4, "2" -> 7, "2" -> 8, "3" -> 5, "3" -> 6, "4" -> 1, "4" -> 2) + .toDF("a", "b").registerTempTable("orderByData") + + checkAnswer( + sql( + """ + |SELECT a + |FROM orderByData + |GROUP BY a + |ORDER BY sum(b) + """.stripMargin), + Row("4") :: Row("1") :: Row("3") :: Row("2") :: Nil) + + checkAnswer( + sql( + """ + |SELECT sum(b) + |FROM orderByData + |GROUP BY a + |ORDER BY sum(b) + """.stripMargin), + Row(3) :: Row(7) :: Row(11) :: Row(15) :: Nil) + + checkAnswer( + sql( + """ + |SELECT a, sum(b) + |FROM orderByData + |GROUP BY a + |ORDER BY sum(b) + """.stripMargin), + Row("4", 3) :: Row("1", 7) :: Row("3", 11) :: Row("2", 15) :: Nil) + + checkAnswer( + sql( + """ + |SELECT a, sum(b) + |FROM orderByData + |GROUP BY a + |ORDER BY sum(b) + 1 + """.stripMargin), + Row("4", 3) :: Row("1", 7) :: Row("3", 11) :: Row("2", 15) :: Nil) + } + test("SPARK-7952: fix the equality check between boolean and numeric types") { withTempTable("t") { // numeric field i, boolean field j, result of i = j, result of i <=> j @@ -1364,4 +1443,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { checkAnswer(sql("select i <=> b from t"), sql("select r2 from t")) } } + + test("SPARK-7067: order by queries for complex ExtractValue chain") { + withTempTable("t") { + sqlContext.read.json(sqlContext.sparkContext.makeRDD( + """{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).registerTempTable("t") + checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1)))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index ece3d6fdf2af5..4cb5ba2f0d5eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions._ case class ReflectData( stringField: String, @@ -128,16 +127,16 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { Seq(data).toDF().registerTempTable("reflectComplexData") assert(ctx.sql("SELECT * FROM reflectComplexData").collect().head === - new GenericRow(Array[Any]( + Row( Seq(1, 2, 3), Seq(1, 2, null), Map(1 -> 10L, 2 -> 20L), Map(1 -> 10L, 2 -> 20L, 3 -> null), - new GenericRow(Array[Any]( + Row( Seq(10, 20, 30), Seq(10, 20, null), Map(10 -> 100L, 20 -> 200L), Map(10 -> 100L, 20 -> 200L, 30 -> null), - new GenericRow(Array[Any](null, "abc"))))))) + Row(null, "abc")))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 725a18bfae3a7..207d7a352c7b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import java.sql.Timestamp -import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.test._ @@ -174,12 +173,6 @@ object TestData { "3, C3, true, null" :: "4, D4, true, 2147483644" :: Nil) - case class TimestampField(time: Timestamp) - val timestamps = TestSQLContext.sparkContext.parallelize((0 to 3).map { i => - TimestampField(new Timestamp(i)) - }) - timestamps.toDF().registerTempTable("timestamps") - case class IntField(i: Int) // An RDD with 4 elements and 8 partitions val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 064c040d2b771..703a34c47ec20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -25,6 +25,48 @@ class UDFSuite extends QueryTest { private lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ + test("built-in fixed arity expressions") { + val df = ctx.emptyDataFrame + df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)") + } + + test("built-in vararg expressions") { + val df = Seq((1, 2)).toDF("a", "b") + df.selectExpr("array(a, b)") + df.selectExpr("struct(a, b)") + } + + test("built-in expressions with multiple constructors") { + val df = Seq(("abcd", 2)).toDF("a", "b") + df.selectExpr("substr(a, 2)", "substr(a, 2, 3)").collect() + } + + test("count") { + val df = Seq(("abcd", 2)).toDF("a", "b") + df.selectExpr("count(a)") + } + + test("count distinct") { + val df = Seq(("abcd", 2)).toDF("a", "b") + df.selectExpr("count(distinct a)") + } + + test("error reporting for incorrect number of arguments") { + val df = ctx.emptyDataFrame + val e = intercept[AnalysisException] { + df.selectExpr("substr('abcd', 2, 3, 4)") + } + assert(e.getMessage.contains("arguments")) + } + + test("error reporting for undefined functions") { + val df = ctx.emptyDataFrame + val e = intercept[AnalysisException] { + df.selectExpr("a_function_that_does_not_exist()") + } + assert(e.getMessage.contains("undefined function")) + } + test("Simple UDF") { ctx.udf.register("strLenScala", (_: String).length) assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 339e719f39f16..9bd7b221e93f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -18,25 +18,29 @@ package org.apache.spark.sql.columnar import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.expressions.InternalRow import org.apache.spark.sql.types._ class ColumnStatsSuite extends SparkFunSuite { - testColumnStats(classOf[ByteColumnStats], BYTE, Row(Byte.MaxValue, Byte.MinValue, 0)) - testColumnStats(classOf[ShortColumnStats], SHORT, Row(Short.MaxValue, Short.MinValue, 0)) - testColumnStats(classOf[IntColumnStats], INT, Row(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[LongColumnStats], LONG, Row(Long.MaxValue, Long.MinValue, 0)) - testColumnStats(classOf[FloatColumnStats], FLOAT, Row(Float.MaxValue, Float.MinValue, 0)) - testColumnStats(classOf[DoubleColumnStats], DOUBLE, Row(Double.MaxValue, Double.MinValue, 0)) - testColumnStats(classOf[FixedDecimalColumnStats], FIXED_DECIMAL(15, 10), Row(null, null, 0)) - testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0)) - testColumnStats(classOf[DateColumnStats], DATE, Row(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0)) + testColumnStats(classOf[BooleanColumnStats], BOOLEAN, InternalRow(true, false, 0)) + testColumnStats(classOf[ByteColumnStats], BYTE, InternalRow(Byte.MaxValue, Byte.MinValue, 0)) + testColumnStats(classOf[ShortColumnStats], SHORT, InternalRow(Short.MaxValue, Short.MinValue, 0)) + testColumnStats(classOf[IntColumnStats], INT, InternalRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[LongColumnStats], LONG, InternalRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, + InternalRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[FloatColumnStats], FLOAT, InternalRow(Float.MaxValue, Float.MinValue, 0)) + testColumnStats(classOf[DoubleColumnStats], DOUBLE, + InternalRow(Double.MaxValue, Double.MinValue, 0)) + testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0)) + testColumnStats(classOf[FixedDecimalColumnStats], + FIXED_DECIMAL(15, 10), InternalRow(null, null, 0)) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], columnType: NativeColumnType[T], - initialStatistics: Row): Unit = { + initialStatistics: InternalRow): Unit = { val columnStatsName = columnStatsClass.getSimpleName diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index a1e76eaa982cc..4d46a657056e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -18,26 +18,27 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.Timestamp -import com.esotericsoftware.kryo.{Serializer, Kryo} import com.esotericsoftware.kryo.io.{Input, Output} -import org.apache.spark.serializer.KryoRegistrator +import com.esotericsoftware.kryo.{Kryo, Serializer} import org.apache.spark.{Logging, SparkConf, SparkFunSuite} +import org.apache.spark.serializer.KryoRegistrator import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + class ColumnTypeSuite extends SparkFunSuite with Logging { val DEFAULT_BUFFER_SIZE = 512 test("defaultSize") { val checks = Map( - INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, - FIXED_DECIMAL(15, 10) -> 8, BOOLEAN -> 1, STRING -> 8, DATE -> 4, TIMESTAMP -> 12, - BINARY -> 16, GENERIC -> 16) + BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, DATE -> 4, + LONG -> 8, TIMESTAMP -> 8, FLOAT -> 4, DOUBLE -> 8, + STRING -> 8, BINARY -> 16, FIXED_DECIMAL(15, 10) -> 8, GENERIC -> 16) checks.foreach { case (columnType, expectedSize) => assertResult(expectedSize, s"Wrong defaultSize for $columnType") { @@ -59,27 +60,24 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { } } - checkActualSize(INT, Int.MaxValue, 4) + checkActualSize(BOOLEAN, true, 1) + checkActualSize(BYTE, Byte.MaxValue, 1) checkActualSize(SHORT, Short.MaxValue, 2) + checkActualSize(INT, Int.MaxValue, 4) + checkActualSize(DATE, Int.MaxValue, 4) checkActualSize(LONG, Long.MaxValue, 8) - checkActualSize(BYTE, Byte.MaxValue, 1) - checkActualSize(DOUBLE, Double.MaxValue, 8) + checkActualSize(TIMESTAMP, Long.MaxValue, 8) checkActualSize(FLOAT, Float.MaxValue, 4) + checkActualSize(DOUBLE, Double.MaxValue, 8) + checkActualSize(STRING, UTF8String.fromString("hello"), 4 + "hello".getBytes("utf-8").length) + checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4) checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8) - checkActualSize(BOOLEAN, true, 1) - checkActualSize(STRING, UTF8String("hello"), 4 + "hello".getBytes("utf-8").length) - checkActualSize(DATE, 0, 4) - checkActualSize(TIMESTAMP, new Timestamp(0L), 12) - - val binary = Array.fill[Byte](4)(0: Byte) - checkActualSize(BINARY, binary, 4 + 4) val generic = Map(1 -> "a") checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8) } - testNativeColumnType[BooleanType.type]( - BOOLEAN, + testNativeColumnType(BOOLEAN)( (buffer: ByteBuffer, v: Boolean) => { buffer.put((if (v) 1 else 0).toByte) }, @@ -87,18 +85,23 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { buffer.get() == 1 }) - testNativeColumnType[IntegerType.type](INT, _.putInt(_), _.getInt) + testNativeColumnType(BYTE)(_.put(_), _.get) + + testNativeColumnType(SHORT)(_.putShort(_), _.getShort) + + testNativeColumnType(INT)(_.putInt(_), _.getInt) + + testNativeColumnType(DATE)(_.putInt(_), _.getInt) - testNativeColumnType[ShortType.type](SHORT, _.putShort(_), _.getShort) + testNativeColumnType(LONG)(_.putLong(_), _.getLong) - testNativeColumnType[LongType.type](LONG, _.putLong(_), _.getLong) + testNativeColumnType(TIMESTAMP)(_.putLong(_), _.getLong) - testNativeColumnType[ByteType.type](BYTE, _.put(_), _.get) + testNativeColumnType(FLOAT)(_.putFloat(_), _.getFloat) - testNativeColumnType[DoubleType.type](DOUBLE, _.putDouble(_), _.getDouble) + testNativeColumnType(DOUBLE)(_.putDouble(_), _.getDouble) - testNativeColumnType[DecimalType]( - FIXED_DECIMAL(15, 10), + testNativeColumnType(FIXED_DECIMAL(15, 10))( (buffer: ByteBuffer, decimal: Decimal) => { buffer.putLong(decimal.toUnscaledLong) }, @@ -106,10 +109,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { Decimal(buffer.getLong(), 15, 10) }) - testNativeColumnType[FloatType.type](FLOAT, _.putFloat(_), _.getFloat) - testNativeColumnType[StringType.type]( - STRING, + testNativeColumnType(STRING)( (buffer: ByteBuffer, string: UTF8String) => { val bytes = string.getBytes buffer.putInt(bytes.length) @@ -119,7 +120,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { val length = buffer.getInt() val bytes = new Array[Byte](length) buffer.get(bytes) - UTF8String(bytes) + UTF8String.fromBytes(bytes) }) testColumnType[BinaryType.type, Array[Byte]]( @@ -196,8 +197,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { } def testNativeColumnType[T <: AtomicType]( - columnType: NativeColumnType[T], - putter: (ByteBuffer, T#InternalType) => Unit, + columnType: NativeColumnType[T]) + (putter: (ByteBuffer, T#InternalType) => Unit, getter: (ByteBuffer) => T#InternalType): Unit = { testColumnType[T, T#InternalType](columnType, putter, getter) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index 75d993e563e06..d9861339739c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql.columnar -import java.sql.Timestamp - import scala.collection.immutable.HashSet import scala.util.Random - -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.types.{UTF8String, DataType, Decimal, AtomicType} +import org.apache.spark.sql.types.{DataType, Decimal, AtomicType} +import org.apache.spark.unsafe.types.UTF8String object ColumnarTestUtils { def makeNullRow(length: Int): GenericMutableRow = { @@ -41,21 +39,18 @@ object ColumnarTestUtils { } (columnType match { + case BOOLEAN => Random.nextBoolean() case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort case INT => Random.nextInt() + case DATE => Random.nextInt() case LONG => Random.nextLong() + case TIMESTAMP => Random.nextLong() case FLOAT => Random.nextFloat() case DOUBLE => Random.nextDouble() - case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) - case STRING => UTF8String(Random.nextString(Random.nextInt(32))) - case BOOLEAN => Random.nextBoolean() + case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32))) case BINARY => randomBytes(Random.nextInt(32)) - case DATE => Random.nextInt() - case TIMESTAMP => - val timestamp = new Timestamp(Random.nextLong()) - timestamp.setNanos(Random.nextInt(999999999)) - timestamp + case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) case _ => // Using a random one-element map instead of an arbitrary object Map(Random.nextInt() -> Random.nextString(Random.nextInt(32))) @@ -81,9 +76,9 @@ object ColumnarTestUtils { def makeRandomRow( head: ColumnType[_ <: DataType, _], - tail: ColumnType[_ <: DataType, _]*): Row = makeRandomRow(Seq(head) ++ tail) + tail: ColumnType[_ <: DataType, _]*): InternalRow = makeRandomRow(Seq(head) ++ tail) - def makeRandomRow(columnTypes: Seq[ColumnType[_ <: DataType, _]]): Row = { + def makeRandomRow(columnTypes: Seq[ColumnType[_ <: DataType, _]]): InternalRow = { val row = new GenericMutableRow(columnTypes.length) makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) => row(index) = value diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index fa3b8144c086e..01bc23277fa88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -20,9 +20,8 @@ package org.apache.spark.sql.columnar import java.sql.{Date, Timestamp} import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.types._ -import org.apache.spark.sql.{QueryTest, TestData} +import org.apache.spark.sql.{QueryTest, Row, TestData} import org.apache.spark.storage.StorageLevel.MEMORY_ONLY class InMemoryColumnarQuerySuite extends QueryTest { @@ -92,15 +91,18 @@ class InMemoryColumnarQuerySuite extends QueryTest { } test("SPARK-2729 regression: timestamp data type") { + val timestamps = (0 to 3).map(i => Tuple1(new Timestamp(i))).toDF("time") + timestamps.registerTempTable("timestamps") + checkAnswer( sql("SELECT time FROM timestamps"), - timestamps.collect().toSeq.map(Row.fromTuple)) + timestamps.collect().toSeq) ctx.cacheTable("timestamps") checkAnswer( sql("SELECT time FROM timestamps"), - timestamps.collect().toSeq.map(Row.fromTuple)) + timestamps.collect().toSeq) } test("SPARK-3320 regression: batched column buffer building should work with empty partitions") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index 2a6e0c376551a..9eaa769846088 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -42,9 +42,9 @@ class NullableColumnAccessorSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( - INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, FIXED_DECIMAL(15, 10), BINARY, GENERIC, - DATE, TIMESTAMP - ).foreach { + BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, + STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC) + .foreach { testNullableColumnAccessor(_) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index cb4e9f1eb7f46..17e9ae464bcc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -38,9 +38,9 @@ class NullableColumnBuilderSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( - INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, FIXED_DECIMAL(15, 10), BINARY, GENERIC, - DATE, TIMESTAMP - ).foreach { + BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, + STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC) + .foreach { testNullableColumnBuilder(_) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 6545c6b314a4c..2c0879927a129 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -32,7 +32,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi override protected def beforeAll(): Unit = { // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch - ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, "10") + ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) val pruningData = ctx.sparkContext.makeRDD((1 to 100).map { key => val string = if (((key - 1) / 10) % 2 == 0) null else key.toString @@ -41,14 +41,14 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi pruningData.registerTempTable("pruningData") // Enable in-memory partition pruning - ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") + ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Enable in-memory table scan accumulators ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") } override protected def afterAll(): Unit = { - ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) - ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) + ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) + ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) } before { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala index 20d65a74e3b7a..f606e2133bedc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.columnar.compression import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar.{NoopColumnStats, BOOLEAN} import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.columnar.{BOOLEAN, NoopColumnStats} class BooleanBitSetSuite extends SparkFunSuite { import BooleanBitSet._ @@ -32,7 +32,7 @@ class BooleanBitSetSuite extends SparkFunSuite { // ------------- val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet) - val rows = Seq.fill[Row](count)(makeRandomRow(BOOLEAN)) + val rows = Seq.fill[InternalRow](count)(makeRandomRow(BOOLEAN)) val values = rows.map(_(0)) rows.foreach(builder.appendFrom(_, 0)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 45a7e8fe68f72..3dd24130af81a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -18,16 +18,15 @@ package org.apache.spark.sql.execution import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{SQLConf, execution} -import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.test.TestSQLContext.planner._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.{Row, SQLConf, execution} class PlannerSuite extends SparkFunSuite { @@ -64,7 +63,7 @@ class PlannerSuite extends SparkFunSuite { test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = { - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold.toString) + setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold) val fields = fieldTypes.zipWithIndex.map { case (dataType, index) => StructField(s"c${index}", dataType, true) } :+ StructField("key", IntegerType, true) @@ -120,12 +119,12 @@ class PlannerSuite extends SparkFunSuite { checkPlan(complexTypes, newThreshold = 901617) - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString) + setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) } test("InMemoryRelation statistics propagation") { val origThreshold = conf.autoBroadcastJoinThreshold - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920.toString) + setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920) testData.limit(3).registerTempTable("tiny") sql("CACHE TABLE tiny") @@ -140,6 +139,12 @@ class PlannerSuite extends SparkFunSuite { assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString) + setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) + } + + test("efficient limit -> project -> sort") { + val query = testData.sort('key).select('value).limit(2).logicalPlan + val planned = planner.TakeOrderedAndProject(query) + assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala new file mode 100644 index 0000000000000..a1e3ca11b1ad9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class SortSuite extends SparkPlanTest { + + // This test was originally added as an example of how to use [[SparkPlanTest]]; + // it's not designed to be a comprehensive test of ExternalSort. + test("basic sorting using ExternalSort") { + + val input = Seq( + ("Hello", 4, 2.0), + ("Hello", 1, 1.0), + ("World", 8, 3.0) + ) + + checkAnswer( + input.toDF("a", "b", "c"), + ExternalSort('a.asc :: 'b.asc :: Nil, global = false, _: SparkPlan), + input.sorted) + + checkAnswer( + input.toDF("a", "b", "c"), + ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan), + input.sortBy(t => (t._2, t._1))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala new file mode 100644 index 0000000000000..13f3be8ca28d6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.TypeTag +import scala.util.control.NonFatal + +import org.apache.spark.SparkFunSuite + +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.BoundReference +import org.apache.spark.sql.catalyst.util._ + +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.{DataFrameHolder, Row, DataFrame} + +/** + * Base class for writing tests for individual physical operators. For an example of how this + * class's test helper methods can be used, see [[SortSuite]]. + */ +class SparkPlanTest extends SparkFunSuite { + + /** + * Creates a DataFrame from a local Seq of Product. + */ + implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = { + TestSQLContext.implicits.localSeqToDataFrameHolder(data) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + protected def checkAnswer( + input: DataFrame, + planFunction: SparkPlan => SparkPlan, + expectedAnswer: Seq[Row]): Unit = { + SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. + */ + protected def checkAnswer[A <: Product : TypeTag]( + input: DataFrame, + planFunction: SparkPlan => SparkPlan, + expectedAnswer: Seq[A]): Unit = { + val expectedRows = expectedAnswer.map(Row.fromTuple) + SparkPlanTest.checkAnswer(input, planFunction, expectedRows) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } +} + +/** + * Helper methods for writing tests of individual physical operators. + */ +object SparkPlanTest { + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + def checkAnswer( + input: DataFrame, + planFunction: SparkPlan => SparkPlan, + expectedAnswer: Seq[Row]): Option[String] = { + + val outputPlan = planFunction(input.queryExecution.sparkPlan) + + // A very simple resolver to make writing tests easier. In contrast to the real resolver + // this is always case sensitive and does not try to handle scoping or complex type resolution. + val resolvedPlan = outputPlan transform { + case plan: SparkPlan => + val inputMap = plan.children.flatMap(_.output).zipWithIndex.map { + case (a, i) => + (a.name, BoundReference(i, a.dataType, a.nullable)) + }.toMap + + plan.transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } + } + + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. + // This function is copied from Catalyst's QueryTest + val converted: Seq[Row] = answer.map { s => + Row.fromSeq(s.toSeq.map { + case d: java.math.BigDecimal => BigDecimal(d) + case b: Array[Byte] => b.toSeq + case o => o + }) + } + converted.sortBy(_.toString()) + } + + val sparkAnswer: Seq[Row] = try { + resolvedPlan.executeCollect().toSeq + } catch { + case NonFatal(e) => + val errorMessage = + s""" + | Exception thrown while executing Spark plan: + | $outputPlan + | == Exception == + | $e + | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } + + if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { + val errorMessage = + s""" + | Results do not match for Spark plan: + | $outputPlan + | == Results == + | ${sideBySide( + s"== Correct Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString()), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} + """.stripMargin + return Some(errorMessage) + } + + None + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 5290c28cfca02..71db6a2159857 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{Projection, Row} +import org.apache.spark.sql.catalyst.expressions.{Projection, InternalRow} import org.apache.spark.util.collection.CompactBuffer @@ -26,37 +26,37 @@ class HashedRelationSuite extends SparkFunSuite { // Key is simply the record itself private val keyProjection = new Projection { - override def apply(row: Row): Row = row + override def apply(row: InternalRow): InternalRow = row } test("GeneralHashedRelation") { - val data = Array(Row(0), Row(1), Row(2), Row(2)) + val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) val hashed = HashedRelation(data.iterator, keyProjection) assert(hashed.isInstanceOf[GeneralHashedRelation]) - assert(hashed.get(data(0)) == CompactBuffer[Row](data(0))) - assert(hashed.get(data(1)) == CompactBuffer[Row](data(1))) - assert(hashed.get(Row(10)) === null) + assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0))) + assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1))) + assert(hashed.get(InternalRow(10)) === null) - val data2 = CompactBuffer[Row](data(2)) + val data2 = CompactBuffer[InternalRow](data(2)) data2 += data(2) assert(hashed.get(data(2)) == data2) } test("UniqueKeyHashedRelation") { - val data = Array(Row(0), Row(1), Row(2)) + val data = Array(InternalRow(0), InternalRow(1), InternalRow(2)) val hashed = HashedRelation(data.iterator, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) - assert(hashed.get(data(0)) == CompactBuffer[Row](data(0))) - assert(hashed.get(data(1)) == CompactBuffer[Row](data(1))) - assert(hashed.get(data(2)) == CompactBuffer[Row](data(2))) - assert(hashed.get(Row(10)) === null) + assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0))) + assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1))) + assert(hashed.get(data(2)) == CompactBuffer[InternalRow](data(2))) + assert(hashed.get(InternalRow(10)) === null) val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation] assert(uniqHashed.getValue(data(0)) == data(0)) assert(uniqHashed.getValue(data(1)) == data(1)) assert(uniqHashed.getValue(data(2)) == data(2)) - assert(uniqHashed.getValue(Row(10)) == null) + assert(uniqHashed.getValue(InternalRow(10)) == null) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 49d348c3ed21b..69ab1c292d221 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -326,7 +326,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { assert(cal.get(Calendar.HOUR) === 11) assert(cal.get(Calendar.MINUTE) === 22) assert(cal.get(Calendar.SECOND) === 33) - assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543543) + assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543500) } test("test DATE types") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index d889c7be17ce7..8204a584179bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -25,7 +25,7 @@ import org.scalactic.Tolerance._ import org.apache.spark.sql.{QueryTest, Row, SQLConf} import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.json.InferSchema.compatibleType import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.types._ @@ -76,22 +76,28 @@ class JsonSuite extends QueryTest with TestJsonData { checkTypePromotion( Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.Unlimited)) - checkTypePromotion(new Timestamp(intNumber), enforceCorrectType(intNumber, TimestampType)) - checkTypePromotion(new Timestamp(intNumber.toLong), + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber)), + enforceCorrectType(intNumber, TimestampType)) + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber.toLong)), enforceCorrectType(intNumber.toLong, TimestampType)) val strTime = "2014-09-30 12:34:56" - checkTypePromotion(Timestamp.valueOf(strTime), enforceCorrectType(strTime, TimestampType)) + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(strTime)), + enforceCorrectType(strTime, TimestampType)) val strDate = "2014-10-15" checkTypePromotion( - DateUtils.fromJavaDate(Date.valueOf(strDate)), enforceCorrectType(strDate, DateType)) + DateTimeUtils.fromJavaDate(Date.valueOf(strDate)), enforceCorrectType(strDate, DateType)) val ISO8601Time1 = "1970-01-01T01:00:01.0Z" - checkTypePromotion(new Timestamp(3601000), enforceCorrectType(ISO8601Time1, TimestampType)) - checkTypePromotion(DateUtils.millisToDays(3601000), enforceCorrectType(ISO8601Time1, DateType)) + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(3601000)), + enforceCorrectType(ISO8601Time1, TimestampType)) + checkTypePromotion(DateTimeUtils.millisToDays(3601000), + enforceCorrectType(ISO8601Time1, DateType)) val ISO8601Time2 = "1970-01-01T02:00:01-01:00" - checkTypePromotion(new Timestamp(10801000), enforceCorrectType(ISO8601Time2, TimestampType)) - checkTypePromotion(DateUtils.millisToDays(10801000), enforceCorrectType(ISO8601Time2, DateType)) + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(10801000)), + enforceCorrectType(ISO8601Time2, TimestampType)) + checkTypePromotion(DateTimeUtils.millisToDays(10801000), + enforceCorrectType(ISO8601Time2, DateType)) } test("Get compatible type") { @@ -1073,14 +1079,14 @@ class JsonSuite extends QueryTest with TestJsonData { } test("SPARK-7565 MapType in JsonRDD") { - val useStreaming = ctx.getConf(SQLConf.USE_JACKSON_STREAMING_API, "true") + val useStreaming = ctx.conf.useJacksonStreamingAPI val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) try{ - for (useStreaming <- List("true", "false")) { + for (useStreaming <- List(true, false)) { ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) val temp = Utils.createTempDir().getPath @@ -1099,4 +1105,8 @@ class JsonSuite extends QueryTest with TestJsonData { } } + test("SPARK-8093 Erase empty structs") { + val emptySchema = InferSchema(emptyRecords, 1.0, "") + assert(StructType(Seq()) === emptySchema) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index b6a6a8dc6a63c..eb62066ac6430 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -189,5 +189,14 @@ trait TestJsonData { """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """]""" :: Nil) + def emptyRecords: RDD[String] = + ctx.sparkContext.parallelize( + """{""" :: + """""" :: + """{"a": {}}""" :: + """{"a": {"b": {}}}""" :: + """{"b": [{"c": {}}]}""" :: + """]""" :: Nil) + def empty: RDD[String] = ctx.sparkContext.parallelize(Seq[String]()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 17f5f9a491e6b..a2763c78b6450 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf} +import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. @@ -51,7 +51,7 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { expected: Seq[Row]): Unit = { val output = predicate.collect { case a: Attribute => a }.distinct - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { val query = df .select(output.map(e => Column(e)): _*) .where(Column(predicate)) @@ -314,17 +314,17 @@ class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeA lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) } override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } test("SPARK-6554: don't push down predicates which reference partition columns") { import sqlContext.implicits._ - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { dir => val path = s"${dir.getCanonicalPath}/part=1" (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) @@ -343,17 +343,17 @@ class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with Before lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) } override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } test("SPARK-6742: don't push down predicates which reference partition columns") { import sqlContext.implicits._ - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { dir => val path = s"${dir.getCanonicalPath}/part=1" (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 2b6a27032e637..7b16eba00d6fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -23,21 +23,22 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} -import org.scalatest.BeforeAndAfterAll +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.parquet.example.data.simple.SimpleGroup import org.apache.parquet.example.data.{Group, GroupWriter} import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.hadoop.api.WriteSupport.WriteContext -import org.apache.parquet.hadoop.metadata.{ParquetMetadata, FileMetaData, CompressionCodecName} -import org.apache.parquet.hadoop.{Footer, ParquetFileWriter, ParquetWriter} +import org.apache.parquet.hadoop.metadata.{CompressionCodecName, FileMetaData, ParquetMetadata} +import org.apache.parquet.hadoop.{Footer, ParquetFileWriter, ParquetOutputCommitter, ParquetWriter} import org.apache.parquet.io.api.RecordConsumer import org.apache.parquet.schema.{MessageType, MessageTypeParser} +import org.scalatest.BeforeAndAfterAll +import org.apache.spark.SparkException +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf, SaveMode} // Write support class for nested groups: ParquetWriter initializes GroupWriteSupport // with an empty configuration (it is after all not intended to be used in this way?) @@ -93,12 +94,11 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { val data = (1 to 4).map(i => Tuple1(i.toString)) // Property spark.sql.parquet.binaryAsString shouldn't affect Parquet files written by Spark SQL // as we store Spark SQL schema in the extra metadata. - withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "false")(checkParquetFile(data)) - withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "true")(checkParquetFile(data)) + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "false")(checkParquetFile(data)) + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true")(checkParquetFile(data)) } test("fixed-length decimals") { - def makeDecimalRDD(decimal: DecimalType): DataFrame = sqlContext.sparkContext .parallelize(0 to 1000) @@ -136,7 +136,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { def makeDateRDD(): DataFrame = sqlContext.sparkContext .parallelize(0 to 1000) - .map(i => Tuple1(DateUtils.toJavaDate(i))) + .map(i => Tuple1(DateTimeUtils.toJavaDate(i))) .toDF() .select($"_1") @@ -157,6 +157,11 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { checkParquetFile(data) } + test("array and double") { + val data = (1 to 4).map(i => (i.toDouble, Seq(i.toDouble, (i + 1).toDouble))) + checkParquetFile(data) + } + test("struct") { val data = (1 to 4).map(i => Tuple1((i, s"val_$i"))) withParquetDataFrame(data) { df => @@ -196,7 +201,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withParquetDataFrame(allNulls :: Nil) { df => val rows = df.collect() - assert(rows.size === 1) + assert(rows.length === 1) assert(rows.head === Row(Seq.fill(5)(null): _*)) } } @@ -209,7 +214,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withParquetDataFrame(allNones :: Nil) { df => val rows = df.collect() - assert(rows.size === 1) + assert(rows.length === 1) assert(rows.head === Row(Seq.fill(3)(null): _*)) } } @@ -230,7 +235,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { val data = (0 until 10).map(i => (i, i.toString)) def checkCompressionCodec(codec: CompressionCodecName): Unit = { - withSQLConf(SQLConf.PARQUET_COMPRESSION -> codec.name()) { + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> codec.name()) { withParquetFile(data) { path => assertResult(sqlContext.conf.parquetCompressionCodec.toUpperCase) { compressionCodecFor(path) @@ -379,6 +384,8 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } test("SPARK-6352 DirectParquetOutputCommitter") { + val clonedConf = new Configuration(configuration) + // Write to a parquet file and let it fail. // _temporary should be missing if direct output committer works. try { @@ -393,23 +400,55 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { val fs = path.getFileSystem(configuration) assert(!fs.exists(path)) } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) } - finally { - configuration.set("spark.sql.parquet.output.committer.class", - "org.apache.parquet.hadoop.ParquetOutputCommitter") + } + + test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overriden") { + withTempPath { dir => + val clonedConf = new Configuration(configuration) + + configuration.set( + SQLConf.OUTPUT_COMMITTER_CLASS.key, classOf[ParquetOutputCommitter].getCanonicalName) + + configuration.set( + "spark.sql.parquet.output.committer.class", + classOf[BogusParquetOutputCommitter].getCanonicalName) + + try { + val message = intercept[SparkException] { + sqlContext.range(0, 1).write.parquet(dir.getCanonicalPath) + }.getCause.getMessage + assert(message === "Intentional exception for testing purposes") + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + } } } } +class BogusParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { + + override def commitJob(jobContext: JobContext): Unit = { + sys.error("Intentional exception for testing purposes") + } +} + class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) } override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API.key, originalConf.toString) } test("SPARK-6330 regression test") { @@ -429,10 +468,10 @@ class ParquetDataSourceOffIOSuite extends ParquetIOSuiteBase with BeforeAndAfter private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) } override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index d9a010a9815a1..d0ebb11b063f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -26,11 +26,13 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.io.Files import org.apache.hadoop.fs.Path +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.sources.PartitioningUtils._ import org.apache.spark.sql.sources.{LogicalRelation, Partition, PartitionSpec} import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Column, QueryTest, Row, SQLContext} +import org.apache.spark.sql._ +import org.apache.spark.unsafe.types.UTF8String // The data where the partitioning key exists only in the directory structure. case class ParquetData(intField: Int, stringField: String) @@ -48,24 +50,24 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { test("column type inference") { def check(raw: String, literal: Literal): Unit = { - assert(inferPartitionColumnValue(raw, defaultPartitionName) === literal) + assert(inferPartitionColumnValue(raw, defaultPartitionName, true) === literal) } check("10", Literal.create(10, IntegerType)) check("1000000000000000", Literal.create(1000000000000000L, LongType)) - check("1.5", Literal.create(1.5f, FloatType)) + check("1.5", Literal.create(1.5, DoubleType)) check("hello", Literal.create("hello", StringType)) check(defaultPartitionName, Literal.create(null, NullType)) } test("parse partition") { def check(path: String, expected: Option[PartitionValues]): Unit = { - assert(expected === parsePartition(new Path(path), defaultPartitionName)) + assert(expected === parsePartition(new Path(path), defaultPartitionName, true)) } def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = { val message = intercept[T] { - parsePartition(new Path(path), defaultPartitionName).get + parsePartition(new Path(path), defaultPartitionName, true).get }.getMessage assert(message.contains(expected)) @@ -83,13 +85,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { ArrayBuffer( Literal.create(10, IntegerType), Literal.create("hello", StringType), - Literal.create(1.5f, FloatType))) + Literal.create(1.5, DoubleType))) }) check("file://path/a=10/b_hello/c=1.5", Some { PartitionValues( ArrayBuffer("c"), - ArrayBuffer(Literal.create(1.5f, FloatType))) + ArrayBuffer(Literal.create(1.5, DoubleType))) }) check("file:///", None) @@ -105,7 +107,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { test("parse partitions") { def check(paths: Seq[String], spec: PartitionSpec): Unit = { - assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName) === spec) + assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName, true) === spec) } check(Seq( @@ -114,18 +116,21 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { StructType(Seq( StructField("a", IntegerType), StructField("b", StringType))), - Seq(Partition(Row(10, "hello"), "hdfs://host:9000/path/a=10/b=hello")))) + Seq(Partition(InternalRow(10, UTF8String.fromString("hello")), + "hdfs://host:9000/path/a=10/b=hello")))) check(Seq( "hdfs://host:9000/path/a=10/b=20", "hdfs://host:9000/path/a=10.5/b=hello"), PartitionSpec( StructType(Seq( - StructField("a", FloatType), + StructField("a", DoubleType), StructField("b", StringType))), Seq( - Partition(Row(10, "20"), "hdfs://host:9000/path/a=10/b=20"), - Partition(Row(10.5, "hello"), "hdfs://host:9000/path/a=10.5/b=hello")))) + Partition(InternalRow(10, UTF8String.fromString("20")), + "hdfs://host:9000/path/a=10/b=20"), + Partition(InternalRow(10.5, UTF8String.fromString("hello")), + "hdfs://host:9000/path/a=10.5/b=hello")))) check(Seq( "hdfs://host:9000/path/_temporary", @@ -140,11 +145,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { "hdfs://host:9000/path/a=10.5/b=world/_temporary/path"), PartitionSpec( StructType(Seq( - StructField("a", FloatType), + StructField("a", DoubleType), StructField("b", StringType))), Seq( - Partition(Row(10, "20"), "hdfs://host:9000/path/a=10/b=20"), - Partition(Row(10.5, "hello"), "hdfs://host:9000/path/a=10.5/b=hello")))) + Partition(InternalRow(10, UTF8String.fromString("20")), + "hdfs://host:9000/path/a=10/b=20"), + Partition(InternalRow(10.5, UTF8String.fromString("hello")), + "hdfs://host:9000/path/a=10.5/b=hello")))) check(Seq( s"hdfs://host:9000/path/a=10/b=20", @@ -154,19 +161,102 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { StructField("a", IntegerType), StructField("b", StringType))), Seq( - Partition(Row(10, "20"), s"hdfs://host:9000/path/a=10/b=20"), - Partition(Row(null, "hello"), s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello")))) + Partition(InternalRow(10, UTF8String.fromString("20")), + s"hdfs://host:9000/path/a=10/b=20"), + Partition(InternalRow(null, UTF8String.fromString("hello")), + s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello")))) check(Seq( s"hdfs://host:9000/path/a=10/b=$defaultPartitionName", s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName"), PartitionSpec( StructType(Seq( - StructField("a", FloatType), + StructField("a", DoubleType), StructField("b", StringType))), Seq( - Partition(Row(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), - Partition(Row(10.5, null), s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName")))) + Partition(InternalRow(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), + Partition(InternalRow(10.5, null), + s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName")))) + + check(Seq( + s"hdfs://host:9000/path1", + s"hdfs://host:9000/path2"), + PartitionSpec.emptySpec) + } + + test("parse partitions with type inference disabled") { + def check(paths: Seq[String], spec: PartitionSpec): Unit = { + assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName, false) === spec) + } + + check(Seq( + "hdfs://host:9000/path/a=10/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq(Partition(InternalRow(UTF8String.fromString("10"), UTF8String.fromString("hello")), + "hdfs://host:9000/path/a=10/b=hello")))) + + check(Seq( + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/a=10.5/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq( + Partition(InternalRow(UTF8String.fromString("10"), UTF8String.fromString("20")), + "hdfs://host:9000/path/a=10/b=20"), + Partition(InternalRow(UTF8String.fromString("10.5"), UTF8String.fromString("hello")), + "hdfs://host:9000/path/a=10.5/b=hello")))) + + check(Seq( + "hdfs://host:9000/path/_temporary", + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/a=10.5/b=hello", + "hdfs://host:9000/path/a=10.5/_temporary", + "hdfs://host:9000/path/a=10.5/_TeMpOrArY", + "hdfs://host:9000/path/a=10.5/b=hello/_temporary", + "hdfs://host:9000/path/a=10.5/b=hello/_TEMPORARY", + "hdfs://host:9000/path/_temporary/path", + "hdfs://host:9000/path/a=11/_temporary/path", + "hdfs://host:9000/path/a=10.5/b=world/_temporary/path"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq( + Partition(InternalRow(UTF8String.fromString("10"), UTF8String.fromString("20")), + "hdfs://host:9000/path/a=10/b=20"), + Partition(InternalRow(UTF8String.fromString("10.5"), UTF8String.fromString("hello")), + "hdfs://host:9000/path/a=10.5/b=hello")))) + + check(Seq( + s"hdfs://host:9000/path/a=10/b=20", + s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq( + Partition(InternalRow(UTF8String.fromString("10"), UTF8String.fromString("20")), + s"hdfs://host:9000/path/a=10/b=20"), + Partition(InternalRow(null, UTF8String.fromString("hello")), + s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello")))) + + check(Seq( + s"hdfs://host:9000/path/a=10/b=$defaultPartitionName", + s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq( + Partition(InternalRow(UTF8String.fromString("10"), null), + s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), + Partition(InternalRow(UTF8String.fromString("10.5"), null), + s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName")))) check(Seq( s"hdfs://host:9000/path1", @@ -448,4 +538,49 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { checkAnswer(sqlContext.read.format("parquet").load(dir.getCanonicalPath), df) } } + + test("listConflictingPartitionColumns") { + def makeExpectedMessage(colNameLists: Seq[String], paths: Seq[String]): String = { + val conflictingColNameLists = colNameLists.zipWithIndex.map { case (list, index) => + s"\tPartition column name list #$index: $list" + }.mkString("\n", "\n", "\n") + + // scalastyle:off + s"""Conflicting partition column names detected: + |$conflictingColNameLists + |For partitioned table directories, data files should only live in leaf directories. + |And directories at the same level should have the same partition column name. + |Please check the following directories for unexpected files or inconsistent partition column names: + |${paths.map("\t" + _).mkString("\n", "\n", "")} + """.stripMargin.trim + // scalastyle:on + } + + assert( + listConflictingPartitionColumns( + Seq( + (new Path("file:/tmp/foo/a=1"), PartitionValues(Seq("a"), Seq(Literal(1)))), + (new Path("file:/tmp/foo/b=1"), PartitionValues(Seq("b"), Seq(Literal(1)))))).trim === + makeExpectedMessage(Seq("a", "b"), Seq("file:/tmp/foo/a=1", "file:/tmp/foo/b=1"))) + + assert( + listConflictingPartitionColumns( + Seq( + (new Path("file:/tmp/foo/a=1/_temporary"), PartitionValues(Seq("a"), Seq(Literal(1)))), + (new Path("file:/tmp/foo/a=1"), PartitionValues(Seq("a"), Seq(Literal(1)))))).trim === + makeExpectedMessage( + Seq("a"), + Seq("file:/tmp/foo/a=1/_temporary", "file:/tmp/foo/a=1"))) + + assert( + listConflictingPartitionColumns( + Seq( + (new Path("file:/tmp/foo/a=1"), + PartitionValues(Seq("a"), Seq(Literal(1)))), + (new Path("file:/tmp/foo/a=1/b=foo"), + PartitionValues(Seq("a", "b"), Seq(Literal(1), Literal("foo")))))).trim === + makeExpectedMessage( + Seq("a", "a, b"), + Seq("file:/tmp/foo/a=1", "file:/tmp/foo/a=1/b=foo"))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index de0107a361815..fafad67dde3a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -20,15 +20,13 @@ package org.apache.spark.sql.parquet import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.types._ -import org.apache.spark.sql.{SQLConf, QueryTest} -import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.{QueryTest, Row, SQLConf} /** * A test suite that tests various Parquet queries. */ class ParquetQuerySuiteBase extends QueryTest with ParquetTest { lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ import sqlContext.sql test("simple select queries") { @@ -130,11 +128,11 @@ class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAnd private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) } override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } } @@ -142,10 +140,10 @@ class ParquetDataSourceOffQuerySuite extends ParquetQuerySuiteBase with BeforeAn private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) } override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index 171a656f0e01e..35d3c33f99a06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -24,26 +24,109 @@ import org.apache.parquet.schema.MessageTypeParser import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ -class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext +abstract class ParquetSchemaTest extends SparkFunSuite with ParquetTest { + val sqlContext = TestSQLContext /** * Checks whether the reflected Parquet message type for product type `T` conforms `messageType`. */ - private def testSchema[T <: Product: ClassTag: TypeTag]( - testName: String, messageType: String, isThriftDerived: Boolean = false): Unit = { - test(testName) { - val actual = ParquetTypesConverter.convertFromAttributes( - ScalaReflection.attributesFor[T], isThriftDerived) - val expected = MessageTypeParser.parseMessageType(messageType) + protected def testSchemaInference[T <: Product: ClassTag: TypeTag]( + testName: String, + messageType: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + testSchema( + testName, + StructType.fromAttributes(ScalaReflection.attributesFor[T]), + messageType, + binaryAsString, + int96AsTimestamp, + followParquetFormatSpec, + isThriftDerived) + } + + protected def testParquetToCatalyst( + testName: String, + sqlSchema: StructType, + parquetSchema: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + val converter = new CatalystSchemaConverter( + assumeBinaryIsString = binaryAsString, + assumeInt96IsTimestamp = int96AsTimestamp, + followParquetFormatSpec = followParquetFormatSpec) + + test(s"sql <= parquet: $testName") { + val actual = converter.convert(MessageTypeParser.parseMessageType(parquetSchema)) + val expected = sqlSchema + assert( + actual === expected, + s"""Schema mismatch. + |Expected schema: ${expected.json} + |Actual schema: ${actual.json} + """.stripMargin) + } + } + + protected def testCatalystToParquet( + testName: String, + sqlSchema: StructType, + parquetSchema: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + val converter = new CatalystSchemaConverter( + assumeBinaryIsString = binaryAsString, + assumeInt96IsTimestamp = int96AsTimestamp, + followParquetFormatSpec = followParquetFormatSpec) + + test(s"sql => parquet: $testName") { + val actual = converter.convert(sqlSchema) + val expected = MessageTypeParser.parseMessageType(parquetSchema) actual.checkContains(expected) expected.checkContains(actual) } } - testSchema[(Boolean, Int, Long, Float, Double, Array[Byte])]( + protected def testSchema( + testName: String, + sqlSchema: StructType, + parquetSchema: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + + testCatalystToParquet( + testName, + sqlSchema, + parquetSchema, + binaryAsString, + int96AsTimestamp, + followParquetFormatSpec, + isThriftDerived) + + testParquetToCatalyst( + testName, + sqlSchema, + parquetSchema, + binaryAsString, + int96AsTimestamp, + followParquetFormatSpec, + isThriftDerived) + } +} + +class ParquetSchemaInferenceSuite extends ParquetSchemaTest { + testSchemaInference[(Boolean, Int, Long, Float, Double, Array[Byte])]( "basic types", """ |message root { @@ -54,9 +137,10 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { | required double _5; | optional binary _6; |} - """.stripMargin) + """.stripMargin, + binaryAsString = false) - testSchema[(Byte, Short, Int, Long, java.sql.Date)]( + testSchemaInference[(Byte, Short, Int, Long, java.sql.Date)]( "logical integral types", """ |message root { @@ -68,27 +152,87 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { |} """.stripMargin) - // Currently String is the only supported logical binary type. - testSchema[Tuple1[String]]( - "binary logical types", + testSchemaInference[Tuple1[String]]( + "string", """ |message root { | optional binary _1 (UTF8); |} + """.stripMargin, + binaryAsString = true) + + testSchemaInference[Tuple1[String]]( + "binary enum as string", + """ + |message root { + | optional binary _1 (ENUM); + |} """.stripMargin) - testSchema[Tuple1[Seq[Int]]]( - "array", + testSchemaInference[Tuple1[Seq[Int]]]( + "non-nullable array - non-standard", """ |message root { | optional group _1 (LIST) { - | repeated int32 array; + | repeated int32 element; | } |} """.stripMargin) - testSchema[Tuple1[Map[Int, String]]]( - "map", + testSchemaInference[Tuple1[Seq[Int]]]( + "non-nullable array - standard", + """ + |message root { + | optional group _1 (LIST) { + | repeated group list { + | required int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Seq[Integer]]]( + "nullable array - non-standard", + """ + |message root { + | optional group _1 (LIST) { + | repeated group bag { + | optional int32 element; + | } + | } + |} + """.stripMargin) + + testSchemaInference[Tuple1[Seq[Integer]]]( + "nullable array - standard", + """ + |message root { + | optional group _1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Map[Int, String]]]( + "map - standard", + """ + |message root { + | optional group _1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Map[Int, String]]]( + "map - non-standard", """ |message root { | optional group _1 (MAP) { @@ -100,7 +244,7 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { |} """.stripMargin) - testSchema[Tuple1[Pair[Int, String]]]( + testSchemaInference[Tuple1[Pair[Int, String]]]( "struct", """ |message root { @@ -109,20 +253,21 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { | optional binary _2 (UTF8); | } |} - """.stripMargin) + """.stripMargin, + followParquetFormatSpec = true) - testSchema[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( - "deeply nested type", + testSchemaInference[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( + "deeply nested type - non-standard", """ |message root { - | optional group _1 (MAP) { - | repeated group map (MAP_KEY_VALUE) { + | optional group _1 (MAP_KEY_VALUE) { + | repeated group map { | required int32 key; | optional group value { | optional binary _1 (UTF8); | optional group _2 (LIST) { | repeated group bag { - | optional group array { + | optional group element { | required int32 _1; | required double _2; | } @@ -134,43 +279,76 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { |} """.stripMargin) - testSchema[(Option[Int], Map[Int, Option[Double]])]( - "optional types", + testSchemaInference[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( + "deeply nested type - standard", """ |message root { - | optional int32 _1; - | optional group _2 (MAP) { - | repeated group map (MAP_KEY_VALUE) { + | optional group _1 (MAP) { + | repeated group key_value { | required int32 key; - | optional double value; + | optional group value { + | optional binary _1 (UTF8); + | optional group _2 (LIST) { + | repeated group list { + | optional group element { + | required int32 _1; + | required double _2; + | } + | } + | } + | } | } | } |} - """.stripMargin) + """.stripMargin, + followParquetFormatSpec = true) - // Test for SPARK-4520 -- ensure that thrift generated parquet schema is generated - // as expected from attributes - testSchema[(Array[Byte], Array[Byte], Array[Byte], Seq[Int], Map[Array[Byte], Seq[Int]])]( - "thrift generated parquet schema", + testSchemaInference[(Option[Int], Map[Int, Option[Double]])]( + "optional types", """ |message root { - | optional binary _1 (UTF8); - | optional binary _2 (UTF8); - | optional binary _3 (UTF8); - | optional group _4 (LIST) { - | repeated int32 _4_tuple; - | } - | optional group _5 (MAP) { - | repeated group map (MAP_KEY_VALUE) { - | required binary key (UTF8); - | optional group value (LIST) { - | repeated int32 value_tuple; - | } + | optional int32 _1; + | optional group _2 (MAP) { + | repeated group key_value { + | required int32 key; + | optional double value; | } | } |} - """.stripMargin, isThriftDerived = true) + """.stripMargin, + followParquetFormatSpec = true) + + // Parquet files generated by parquet-thrift are already handled by the schema converter, but + // let's leave this test here until both read path and write path are all updated. + ignore("thrift generated parquet schema") { + // Test for SPARK-4520 -- ensure that thrift generated parquet schema is generated + // as expected from attributes + testSchemaInference[( + Array[Byte], Array[Byte], Array[Byte], Seq[Int], Map[Array[Byte], Seq[Int]])]( + "thrift generated parquet schema", + """ + |message root { + | optional binary _1 (UTF8); + | optional binary _2 (UTF8); + | optional binary _3 (UTF8); + | optional group _4 (LIST) { + | repeated int32 _4_tuple; + | } + | optional group _5 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required binary key (UTF8); + | optional group value (LIST) { + | repeated int32 value_tuple; + | } + | } + | } + |} + """.stripMargin, + isThriftDerived = true) + } +} +class ParquetSchemaSuite extends ParquetSchemaTest { test("DataType string parser compatibility") { // This is the generated string from previous versions of the Spark SQL, using the following: // val schema = StructType(List( @@ -180,10 +358,7 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { "StructType(List(StructField(c1,IntegerType,false), StructField(c2,BinaryType,true)))" // scalastyle:off - val jsonString = - """ - |{"type":"struct","fields":[{"name":"c1","type":"integer","nullable":false,"metadata":{}},{"name":"c2","type":"binary","nullable":true,"metadata":{}}]} - """.stripMargin + val jsonString = """{"type":"struct","fields":[{"name":"c1","type":"integer","nullable":false,"metadata":{}},{"name":"c2","type":"binary","nullable":true,"metadata":{}}]}""" // scalastyle:on val fromCaseClassString = ParquetTypesConverter.convertFromString(caseClassString) @@ -277,4 +452,465 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { StructField("secondField", StringType, nullable = true)))) }.getMessage.contains("detected conflicting schemas")) } + + // ======================================================= + // Tests for converting Parquet LIST to Catalyst ArrayType + // ======================================================= + + testParquetToCatalyst( + "Backwards-compatibility: LIST with nullable element type - 1 - standard", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with nullable element type - 2", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group element { + | optional int32 num; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 1 - standard", + StructType(Seq( + StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | required int32 element; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 2", + StructType(Seq( + StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group element { + | required int32 num; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 3", + StructType(Seq( + StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated int32 element; + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 4", + StructType(Seq( + StructField( + "f1", + ArrayType( + StructType(Seq( + StructField("str", StringType, nullable = false), + StructField("num", IntegerType, nullable = false))), + containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group element { + | required binary str (UTF8); + | required int32 num; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 5 - parquet-avro style", + StructType(Seq( + StructField( + "f1", + ArrayType( + StructType(Seq( + StructField("str", StringType, nullable = false))), + containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group array { + | required binary str (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 6 - parquet-thrift style", + StructType(Seq( + StructField( + "f1", + ArrayType( + StructType(Seq( + StructField("str", StringType, nullable = false))), + containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group f1_tuple { + | required binary str (UTF8); + | } + | } + |} + """.stripMargin) + + // ======================================================= + // Tests for converting Catalyst ArrayType to Parquet LIST + // ======================================================= + + testCatalystToParquet( + "Backwards-compatibility: LIST with nullable element type - 1 - standard", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: LIST with nullable element type - 2 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group bag { + | optional int32 element; + | } + | } + |} + """.stripMargin) + + testCatalystToParquet( + "Backwards-compatibility: LIST with non-nullable element type - 1 - standard", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | required int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: LIST with non-nullable element type - 2 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated int32 element; + | } + |} + """.stripMargin) + + // ==================================================== + // Tests for converting Parquet Map to Catalyst MapType + // ==================================================== + + testParquetToCatalyst( + "Backwards-compatibility: MAP with non-nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with non-nullable value type - 2", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP_KEY_VALUE) { + | repeated group map { + | required int32 num; + | required binary str (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with non-nullable value type - 3 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with nullable value type - 2", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP_KEY_VALUE) { + | repeated group map { + | required int32 num; + | optional binary str (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with nullable value type - 3 - parquet-avro style", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin) + + // ==================================================== + // Tests for converting Catalyst MapType to Parquet Map + // ==================================================== + + testCatalystToParquet( + "Backwards-compatibility: MAP with non-nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: MAP with non-nullable value type - 2 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testCatalystToParquet( + "Backwards-compatibility: MAP with nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: MAP with nullable value type - 3 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin) + + // ================================= + // Tests for conversion for decimals + // ================================= + + testSchema( + "DECIMAL(1, 0) - standard", + StructType(Seq(StructField("f1", DecimalType(1, 0)))), + """message root { + | optional int32 f1 (DECIMAL(1, 0)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(8, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(8, 3)))), + """message root { + | optional int32 f1 (DECIMAL(8, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(9, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(9, 3)))), + """message root { + | optional int32 f1 (DECIMAL(9, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(18, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(18, 3)))), + """message root { + | optional int64 f1 (DECIMAL(18, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(19, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(19, 3)))), + """message root { + | optional fixed_len_byte_array(9) f1 (DECIMAL(19, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(1, 0) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(1, 0)))), + """message root { + | optional fixed_len_byte_array(1) f1 (DECIMAL(1, 0)); + |} + """.stripMargin) + + testSchema( + "DECIMAL(8, 3) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(8, 3)))), + """message root { + | optional fixed_len_byte_array(4) f1 (DECIMAL(8, 3)); + |} + """.stripMargin) + + testSchema( + "DECIMAL(9, 3) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(9, 3)))), + """message root { + | optional fixed_len_byte_array(5) f1 (DECIMAL(9, 3)); + |} + """.stripMargin) + + testSchema( + "DECIMAL(18, 3) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(18, 3)))), + """message root { + | optional fixed_len_byte_array(8) f1 (DECIMAL(18, 3)); + |} + """.stripMargin) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 51d22b6a1378a..54e1efb6e36e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql.sources import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class DDLScanSource extends RelationProvider { override def createRelation( @@ -56,9 +58,12 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo ) )) + override def needConversion: Boolean = false override def buildScan(): RDD[Row] = { - sqlContext.sparkContext.parallelize(from to to).map(e => Row(s"people$e", e * 2)) + sqlContext.sparkContext.parallelize(from to to).map { e => + InternalRow(UTF8String.fromString(s"people$e"), e * 2): Row + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 3f77960d09246..00cc7d5ea580f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -27,7 +27,7 @@ abstract class DataSourceTest extends QueryTest with BeforeAndAfter { // We want to test some edge cases. protected implicit lazy val caseInsensitiveContext = { val ctx = new SQLContext(TestSQLContext.sparkContext) - ctx.setConf(SQLConf.CASE_SENSITIVE, "false") + ctx.setConf(SQLConf.CASE_SENSITIVE, false) ctx } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 5d4ecd810862c..2c916f3322b6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.sources -import java.sql.{Timestamp, Date} +import java.sql.{Date, Timestamp} import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class DefaultSource extends SimpleScanSource @@ -47,6 +50,10 @@ class AllDataTypesScanSource extends SchemaRelationProvider { sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation = { + // Check that weird parameters are passed correctly. + parameters("option_with_underscores") + parameters("option.with.dots") + AllDataTypesScan(parameters("from").toInt, parameters("TO").toInt, schema)(sqlContext) } } @@ -60,10 +67,12 @@ case class AllDataTypesScan( override def schema: StructType = userSpecifiedSchema + override def needConversion: Boolean = false + override def buildScan(): RDD[Row] = { sqlContext.sparkContext.parallelize(from to to).map { i => - Row( - s"str_$i", + InternalRow( + UTF8String.fromString(s"str_$i"), s"str_$i".getBytes(), i % 2 == 0, i.toByte, @@ -72,17 +81,18 @@ case class AllDataTypesScan( i.toLong, i.toFloat, i.toDouble, - new java.math.BigDecimal(i), - new java.math.BigDecimal(i), - new Date(1970, 1, 1), - new Timestamp(20000 + i), - s"varchar_$i", + Decimal(new java.math.BigDecimal(i)), + Decimal(new java.math.BigDecimal(i)), + DateTimeUtils.fromJavaDate(new Date(1970, 1, 1)), + DateTimeUtils.fromJavaTimestamp(new Timestamp(20000 + i)), + UTF8String.fromString(s"varchar_$i"), Seq(i, i + 1), - Seq(Map(s"str_$i" -> Row(i.toLong))), - Map(i -> i.toString), - Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), - Row(i, i.toString), - Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(1970, 1, i + 1))))) + Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))), + Map(i -> UTF8String.fromString(i.toString)), + Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)), + InternalRow(i, UTF8String.fromString(i.toString)), + InternalRow(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")), + InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1)))))) } } } @@ -121,7 +131,9 @@ class TableScanSuite extends DataSourceTest { |USING org.apache.spark.sql.sources.SimpleScanSource |OPTIONS ( | From '1', - | To '10' + | To '10', + | option_with_underscores 'someval', + | option.with.dots 'someval' |) """.stripMargin) @@ -152,7 +164,9 @@ class TableScanSuite extends DataSourceTest { |USING org.apache.spark.sql.sources.AllDataTypesScanSource |OPTIONS ( | From '1', - | To '10' + | To '10', + | option_with_underscores 'someval', + | option.with.dots 'someval' |) """.stripMargin) } @@ -354,7 +368,9 @@ class TableScanSuite extends DataSourceTest { |USING org.apache.spark.sql.sources.AllDataTypesScanSource |OPTIONS ( | from '1', - | to '10' + | to '10', + | option_with_underscores 'someval', + | option.with.dots 'someval' |) """.stripMargin) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index ac4a00a6f3dac..fa01823e9417c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -37,11 +37,11 @@ trait SQLTestUtils { */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(sqlContext.conf.getConf(key)).toOption) - (keys, values).zipped.foreach(sqlContext.conf.setConf) + val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption) + (keys, values).zipped.foreach(sqlContext.conf.setConfString) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => sqlContext.conf.setConf(key, value) + case (key, Some(value)) => sqlContext.conf.setConfString(key, value) case (key, None) => sqlContext.conf.unsetConf(key) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index c9da25253e13f..700d994bb6a83 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -153,9 +153,9 @@ object HiveThriftServer2 extends Logging { val sessionList = new mutable.LinkedHashMap[String, SessionInfo] val executionList = new mutable.LinkedHashMap[String, ExecutionInfo] val retainedStatements = - conf.getConf(SQLConf.THRIFTSERVER_UI_STATEMENT_LIMIT, "200").toInt + conf.getConf(SQLConf.THRIFTSERVER_UI_STATEMENT_LIMIT) val retainedSessions = - conf.getConf(SQLConf.THRIFTSERVER_UI_SESSION_LIMIT, "200").toInt + conf.getConf(SQLConf.THRIFTSERVER_UI_SESSION_LIMIT) var totalRunning = 0 override def onJobStart(jobStart: SparkListenerJobStart): Unit = { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index e071103df925c..e8758887ff3a2 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -219,7 +219,7 @@ private[hive] class SparkExecuteStatementOperation( result = hiveContext.sql(statement) logDebug(result.queryExecution.toString()) result.queryExecution.logical match { - case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value))), _) => + case SetCommand(Some((SQLConf.THRIFTSERVER_POOL.key, Some(value)))) => sessionToActivePool(parentSession.getSessionHandle) = value logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") case _ => diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 79eda1f5123bf..1d41c46131828 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -38,9 +38,14 @@ private[hive] object SparkSQLEnv extends Logging { val sparkConf = new SparkConf(loadDefaults = true) val maybeSerializer = sparkConf.getOption("spark.serializer") val maybeKryoReferenceTracking = sparkConf.getOption("spark.kryo.referenceTracking") + // If user doesn't specify the appName, we want to get [SparkSQL::localHostName] instead of + // the default appName [SparkSQLCLIDriver] in cli or beeline. + val maybeAppName = sparkConf + .getOption("spark.app.name") + .filterNot(_ == classOf[SparkSQLCLIDriver].getName) sparkConf - .setAppName(s"SparkSQL::${Utils.localHostName()}") + .setAppName(maybeAppName.getOrElse(s"SparkSQL::${Utils.localHostName()}")) .set( "spark.serializer", maybeSerializer.getOrElse("org.apache.spark.serializer.KryoSerializer")) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 178bd1f5cb164..301aa5a6411e2 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -113,8 +113,8 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { withJdbcStatement { statement => val resultSet = statement.executeQuery("SET spark.sql.hive.version") resultSet.next() - assert(resultSet.getString(1) === - s"spark.sql.hive.version=${HiveContext.hiveExecutionVersion}") + assert(resultSet.getString(1) === "spark.sql.hive.version") + assert(resultSet.getString(2) === HiveContext.hiveExecutionVersion) } } @@ -238,7 +238,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { // first session, we get the default value of the session status { statement => - val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS}") + val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}") rs1.next() defaultV1 = rs1.getString(1) assert(defaultV1 != "200") @@ -256,19 +256,21 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { { statement => val queries = Seq( - s"SET ${SQLConf.SHUFFLE_PARTITIONS}=291", + s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}=291", "SET hive.cli.print.header=true" ) queries.map(statement.execute) - val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS}") + val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}") rs1.next() - assert("spark.sql.shuffle.partitions=291" === rs1.getString(1)) + assert("spark.sql.shuffle.partitions" === rs1.getString(1)) + assert("291" === rs1.getString(2)) rs1.close() val rs2 = statement.executeQuery("SET hive.cli.print.header") rs2.next() - assert("hive.cli.print.header=true" === rs2.getString(1)) + assert("hive.cli.print.header" === rs2.getString(1)) + assert("true" === rs2.getString(2)) rs2.close() }, @@ -276,7 +278,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { // default value { statement => - val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS}") + val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}") rs1.next() assert(defaultV1 === rs1.getString(1)) rs1.close() @@ -404,8 +406,8 @@ class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { withJdbcStatement { statement => val resultSet = statement.executeQuery("SET spark.sql.hive.version") resultSet.next() - assert(resultSet.getString(1) === - s"spark.sql.hive.version=${HiveContext.hiveExecutionVersion}") + assert(resultSet.getString(1) === "spark.sql.hive.version") + assert(resultSet.getString(2) === HiveContext.hiveExecutionVersion) } } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala index 4c9fab7ef6136..806240e6de458 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala @@ -22,12 +22,13 @@ import scala.util.Random import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.openqa.selenium.WebDriver import org.openqa.selenium.htmlunit.HtmlUnitDriver +import org.scalatest.{BeforeAndAfterAll, Matchers} import org.scalatest.concurrent.Eventually._ import org.scalatest.selenium.WebBrowser import org.scalatest.time.SpanSugar._ -import org.scalatest.{BeforeAndAfterAll, Matchers} import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.ui.SparkUICssErrorHandler class UISeleniumSuite extends HiveThriftJdbcTest @@ -40,7 +41,9 @@ class UISeleniumSuite override def mode: ServerMode.Value = ServerMode.binary override def beforeAll(): Unit = { - webDriver = new HtmlUnitDriver + webDriver = new HtmlUnitDriver { + getWebClient.setCssErrorHandler(new SparkUICssErrorHandler) + } super.beforeAll() } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 048f78b4daa8d..f88e62763ca70 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -47,17 +47,17 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Add Locale setting Locale.setDefault(Locale.US) // Set a relatively small column batch size for testing purposes - TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, "5") + TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, 5) // Enable in-memory partition pruning for testing purposes - TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") + TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) } override def afterAll() { TestHive.cacheTables = false TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) - TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) - TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) + TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) + TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) } /** A list of tests deemed out of scope currently and thus completely disregarded. */ @@ -252,7 +252,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "load_dyn_part14.*", // These work alone but fail when run with other tests... // the answer is sensitive for jdk version - "udf_java_method" + "udf_java_method", + + // Spark SQL use Long for TimestampType, lose the precision under 100ns + "timestamp_1", + "timestamp_2" ) /** @@ -795,8 +799,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "stats_publisher_error_1", "subq2", "tablename_with_select", - "timestamp_1", - "timestamp_2", "timestamp_3", "timestamp_comparison", "timestamp_lazy", @@ -817,19 +819,19 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf2", "udf5", "udf6", - "udf7", + // "udf7", turn this on after we figure out null vs nan vs infinity "udf8", "udf9", "udf_10_trims", "udf_E", "udf_PI", "udf_abs", - "udf_acos", + // "udf_acos", turn this on after we figure out null vs nan vs infinity "udf_add", "udf_array", "udf_array_contains", "udf_ascii", - "udf_asin", + // "udf_asin", turn this on after we figure out null vs nan vs infinity "udf_atan", "udf_avg", "udf_bigint", @@ -917,7 +919,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_repeat", "udf_rlike", "udf_round", - "udf_round_3", + // "udf_round_3", TODO: FIX THIS failed due to cast exception "udf_rpad", "udf_rtrim", "udf_second", diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index 934452fe579a1..31a49a3683338 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -526,8 +526,14 @@ abstract class HiveWindowFunctionQueryBaseSuite extends HiveComparisonTest with | rows between 2 preceding and 2 following); """.stripMargin, reset = false) + // collect_set() output array in an arbitrary order, hence causes different result + // when running this test suite under Java 7 and 8. + // We change the original sql query a little bit for making the test suite passed + // under different JDK createQueryTest("windowing.q -- 20. testSTATs", """ + |select p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp + |from ( |select p_mfgr,p_name, p_size, |stddev(p_retailprice) over w1 as sdev, |stddev_pop(p_retailprice) over w1 as sdev_pop, @@ -538,6 +544,8 @@ abstract class HiveWindowFunctionQueryBaseSuite extends HiveComparisonTest with |from part |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name | rows between 2 preceding and 2 following) + |) t lateral view explode(uniq_size) d as uniq_data + |order by p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp """.stripMargin, reset = false) createQueryTest("windowing.q -- 21. testDISTs", diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala index 65d070bd3cbde..f458567e5d7ea 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala @@ -26,11 +26,11 @@ import org.apache.spark.sql.hive.test.TestHive class SortMergeCompatibilitySuite extends HiveCompatibilitySuite { override def beforeAll() { super.beforeAll() - TestHive.setConf(SQLConf.SORTMERGE_JOIN, "true") + TestHive.setConf(SQLConf.SORTMERGE_JOIN, true) } override def afterAll() { - TestHive.setConf(SQLConf.SORTMERGE_JOIN, "false") + TestHive.setConf(SQLConf.SORTMERGE_JOIN, false) super.afterAll() } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index b8f294c262af7..b91242af2d155 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -21,15 +21,13 @@ import java.io.File import java.net.{URL, URLClassLoader} import java.sql.Timestamp -import org.apache.hadoop.hive.common.StatsSetupConst -import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.spark.sql.catalyst.ParserDialect - import scala.collection.JavaConversions._ import scala.collection.mutable.HashMap import scala.language.implicitConversions import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.hive.common.StatsSetupConst +import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.parse.VariableSubstitution @@ -39,13 +37,15 @@ import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateSubQueries, OverrideCatalog, OverrideFunctionRegistry} +import org.apache.spark.sql.SQLConf.SQLConfEntry +import org.apache.spark.sql.SQLConf.SQLConfEntry._ +import org.apache.spark.sql.catalyst.ParserDialect +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand} +import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUDFs, SetCommand} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} import org.apache.spark.sql.sources.DataSourceStrategy -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -70,13 +70,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { import HiveContext._ + println("create HiveContext") + /** * When true, enables an experimental feature where metastore tables that use the parquet SerDe * are automatically converted to use the Spark SQL parquet table scan, instead of the Hive * SerDe. */ - protected[sql] def convertMetastoreParquet: Boolean = - getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true" + protected[sql] def convertMetastoreParquet: Boolean = getConf(CONVERT_METASTORE_PARQUET) /** * When true, also tries to merge possibly different but compatible Parquet schemas in different @@ -85,7 +86,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * This configuration is only effective when "spark.sql.hive.convertMetastoreParquet" is true. */ protected[sql] def convertMetastoreParquetWithSchemaMerging: Boolean = - getConf("spark.sql.hive.convertMetastoreParquet.mergeSchema", "false") == "true" + getConf(CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING) /** * When true, a table created by a Hive CTAS statement (no USING clause) will be @@ -99,8 +100,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * - The CTAS statement specifies SequenceFile (STORED AS SEQUENCEFILE) as the file format * and no SerDe is specified (no ROW FORMAT SERDE clause). */ - protected[sql] def convertCTAS: Boolean = - getConf("spark.sql.hive.convertCTAS", "false").toBoolean + protected[sql] def convertCTAS: Boolean = getConf(CONVERT_CTAS) /** * The version of the hive client that will be used to communicate with the metastore. Note that @@ -118,8 +118,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * option is only valid when using the execution version of Hive. * - maven - download the correct version of hive on demand from maven. */ - protected[hive] def hiveMetastoreJars: String = - getConf(HIVE_METASTORE_JARS, "builtin") + protected[hive] def hiveMetastoreJars: String = getConf(HIVE_METASTORE_JARS) /** * A comma separated list of class prefixes that should be loaded using the classloader that @@ -129,11 +128,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * custom appenders that are used by log4j. */ protected[hive] def hiveMetastoreSharedPrefixes: Seq[String] = - getConf("spark.sql.hive.metastore.sharedPrefixes", jdbcPrefixes) - .split(",").filterNot(_ == "") - - private def jdbcPrefixes = Seq( - "com.mysql.jdbc", "org.postgresql", "com.microsoft.sqlserver", "oracle.jdbc").mkString(",") + getConf(HIVE_METASTORE_SHARED_PREFIXES).filterNot(_ == "") /** * A comma separated list of class prefixes that should explicitly be reloaded for each version @@ -141,14 +136,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * prefix that typically would be shared (i.e. org.apache.spark.*) */ protected[hive] def hiveMetastoreBarrierPrefixes: Seq[String] = - getConf("spark.sql.hive.metastore.barrierPrefixes", "") - .split(",").filterNot(_ == "") + getConf(HIVE_METASTORE_BARRIER_PREFIXES).filterNot(_ == "") /* * hive thrift server use background spark sql thread pool to execute sql queries */ - protected[hive] def hiveThriftServerAsync: Boolean = - getConf("spark.sql.hive.thriftServer.async", "true").toBoolean + protected[hive] def hiveThriftServerAsync: Boolean = getConf(HIVE_THRIFT_SERVER_ASYNC) @transient protected[sql] lazy val substitutor = new VariableSubstitution() @@ -165,7 +158,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { logInfo(s"Initializing execution hive, version $hiveExecutionVersion") new ClientWrapper( version = IsolatedClientLoader.hiveVersion(hiveExecutionVersion), - config = newTemporaryConfiguration()) + config = newTemporaryConfiguration(), + initClassLoader = Utils.getContextOrSparkClassLoader) } SessionState.setCurrentSessionState(executionHive.state) @@ -272,13 +266,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * @since 1.3.0 */ def refreshTable(tableName: String): Unit = { - // TODO: Database support... - catalog.refreshTable("default", tableName) + catalog.refreshTable(catalog.client.currentDatabase, tableName) } protected[hive] def invalidateTable(tableName: String): Unit = { - // TODO: Database support... - catalog.invalidateTable("default", tableName) + catalog.invalidateTable(catalog.client.currentDatabase, tableName) } /** @@ -367,17 +359,19 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { hiveconf.set(key, value) } - /* A catalyst metadata catalog that points to the Hive Metastore. */ + private[sql] override def setConf[T](entry: SQLConfEntry[T], value: T): Unit = { + setConf(entry.key, entry.stringConverter(value)) + } + + /* A catalyst metadata catalog that points to the Hive Metastore. */ @transient override protected[sql] lazy val catalog = new HiveMetastoreCatalog(metadataHive, this) with OverrideCatalog // Note that HiveUDFs will be overridden by functions registered in this context. @transient - override protected[sql] lazy val functionRegistry = - new HiveFunctionRegistry with OverrideFunctionRegistry { - override def conf: CatalystConf = currentSession().conf - } + override protected[sql] lazy val functionRegistry: FunctionRegistry = + new OverrideFunctionRegistry(new HiveFunctionRegistry(FunctionRegistry.builtin)) /* An analyzer that uses the Hive metastore. */ @transient @@ -387,7 +381,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.ParquetConversions :: catalog.CreateTables :: catalog.PreInsertionCasts :: - ExtractPythonUdfs :: + ExtractPythonUDFs :: ResolveHiveWindowFunction :: sources.PreInsertCastAndRename :: Nil @@ -407,8 +401,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected[hive] class SQLSession extends super.SQLSession { protected[sql] override lazy val conf: SQLConf = new SQLConf { override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") - override def caseSensitiveAnalysis: Boolean = - getConf(SQLConf.CASE_SENSITIVE, "false").toBoolean + override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) } /** @@ -449,7 +442,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { HiveCommandStrategy(self), HiveDDLStrategy, DDLStrategy, - TakeOrdered, + TakeOrderedAndProject, ParquetOperations, InMemoryScans, ParquetConversion, // Must be before HiveTableScans @@ -524,7 +517,50 @@ private[hive] object HiveContext { val hiveExecutionVersion: String = "0.13.1" val HIVE_METASTORE_VERSION: String = "spark.sql.hive.metastore.version" - val HIVE_METASTORE_JARS: String = "spark.sql.hive.metastore.jars" + val HIVE_METASTORE_JARS = stringConf("spark.sql.hive.metastore.jars", + defaultValue = Some("builtin"), + doc = "Location of the jars that should be used to instantiate the HiveMetastoreClient. This" + + " property can be one of three options: " + + "1. \"builtin\" Use Hive 0.13.1, which is bundled with the Spark assembly jar when " + + "-Phive is enabled. When this option is chosen, " + + "spark.sql.hive.metastore.version must be either 0.13.1 or not defined. " + + "2. \"maven\" Use Hive jars of specified version downloaded from Maven repositories." + + "3. A classpath in the standard format for both Hive and Hadoop.") + + val CONVERT_METASTORE_PARQUET = booleanConf("spark.sql.hive.convertMetastoreParquet", + defaultValue = Some(true), + doc = "When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of " + + "the built in support.") + + val CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING = booleanConf( + "spark.sql.hive.convertMetastoreParquet.mergeSchema", + defaultValue = Some(false), + doc = "TODO") + + val CONVERT_CTAS = booleanConf("spark.sql.hive.convertCTAS", + defaultValue = Some(false), + doc = "TODO") + + val HIVE_METASTORE_SHARED_PREFIXES = stringSeqConf("spark.sql.hive.metastore.sharedPrefixes", + defaultValue = Some(jdbcPrefixes), + doc = "A comma separated list of class prefixes that should be loaded using the classloader " + + "that is shared between Spark SQL and a specific version of Hive. An example of classes " + + "that should be shared is JDBC drivers that are needed to talk to the metastore. Other " + + "classes that need to be shared are those that interact with classes that are already " + + "shared. For example, custom appenders that are used by log4j.") + + private def jdbcPrefixes = Seq( + "com.mysql.jdbc", "org.postgresql", "com.microsoft.sqlserver", "oracle.jdbc") + + val HIVE_METASTORE_BARRIER_PREFIXES = stringSeqConf("spark.sql.hive.metastore.barrierPrefixes", + defaultValue = Some(Seq()), + doc = "A comma separated list of class prefixes that should explicitly be reloaded for each " + + "version of Hive that Spark SQL is communicating with. For example, Hive UDFs that are " + + "declared in a prefix that typically would be shared (i.e. org.apache.spark.*).") + + val HIVE_THRIFT_SERVER_ASYNC = booleanConf("spark.sql.hive.thriftServer.async", + defaultValue = Some(true), + doc = "TODO") /** Constructs a configuration for hive, where the metastore is located in a temp directory. */ def newTemporaryConfiguration(): Map[String, String] = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index c466203cd0220..a6b8ead577fb5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -25,9 +25,10 @@ import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -242,15 +243,16 @@ private[hive] trait HiveInspectors { def unwrap(data: Any, oi: ObjectInspector): Any = oi match { case coi: ConstantObjectInspector if coi.getWritableConstantValue == null => null case poi: WritableConstantStringObjectInspector => - UTF8String(poi.getWritableConstantValue.toString) + UTF8String.fromString(poi.getWritableConstantValue.toString) case poi: WritableConstantHiveVarcharObjectInspector => - UTF8String(poi.getWritableConstantValue.getHiveVarchar.getValue) + UTF8String.fromString(poi.getWritableConstantValue.getHiveVarchar.getValue) case poi: WritableConstantHiveDecimalObjectInspector => HiveShim.toCatalystDecimal( PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector, poi.getWritableConstantValue.getHiveDecimal) case poi: WritableConstantTimestampObjectInspector => - poi.getWritableConstantValue.getTimestamp.clone() + val t = poi.getWritableConstantValue + t.getSeconds * 10000000L + t.getNanos / 100L case poi: WritableConstantIntObjectInspector => poi.getWritableConstantValue.get() case poi: WritableConstantDoubleObjectInspector => @@ -271,7 +273,7 @@ private[hive] trait HiveInspectors { System.arraycopy(writable.getBytes, 0, temp, 0, temp.length) temp case poi: WritableConstantDateObjectInspector => - DateUtils.fromJavaDate(poi.getWritableConstantValue.get()) + DateTimeUtils.fromJavaDate(poi.getWritableConstantValue.get()) case mi: StandardConstantMapObjectInspector => // take the value from the map inspector object, rather than the input data mi.getWritableConstantValue.map { case (k, v) => @@ -287,13 +289,13 @@ private[hive] trait HiveInspectors { case pi: PrimitiveObjectInspector => pi match { // We think HiveVarchar is also a String case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() => - UTF8String(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue) + UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue) case hvoi: HiveVarcharObjectInspector => - UTF8String(hvoi.getPrimitiveJavaObject(data).getValue) + UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue) case x: StringObjectInspector if x.preferWritable() => - UTF8String(x.getPrimitiveWritableObject(data).toString) + UTF8String.fromString(x.getPrimitiveWritableObject(data).toString) case x: StringObjectInspector => - UTF8String(x.getPrimitiveJavaObject(data)) + UTF8String.fromString(x.getPrimitiveJavaObject(data)) case x: IntObjectInspector if x.preferWritable() => x.get(data) case x: BooleanObjectInspector if x.preferWritable() => x.get(data) case x: FloatObjectInspector if x.preferWritable() => x.get(data) @@ -311,13 +313,13 @@ private[hive] trait HiveInspectors { System.arraycopy(bw.getBytes(), 0, result, 0, bw.getLength()) result case x: DateObjectInspector if x.preferWritable() => - DateUtils.fromJavaDate(x.getPrimitiveWritableObject(data).get()) - case x: DateObjectInspector => DateUtils.fromJavaDate(x.getPrimitiveJavaObject(data)) - // org.apache.hadoop.hive.serde2.io.TimestampWritable.set will reset current time object - // if next timestamp is null, so Timestamp object is cloned + DateTimeUtils.fromJavaDate(x.getPrimitiveWritableObject(data).get()) + case x: DateObjectInspector => DateTimeUtils.fromJavaDate(x.getPrimitiveJavaObject(data)) case x: TimestampObjectInspector if x.preferWritable() => - x.getPrimitiveWritableObject(data).getTimestamp.clone() - case ti: TimestampObjectInspector => ti.getPrimitiveJavaObject(data).clone() + val t = x.getPrimitiveWritableObject(data) + t.getSeconds * 10000000L + t.getNanos / 100 + case ti: TimestampObjectInspector => + DateTimeUtils.fromJavaTimestamp(ti.getPrimitiveJavaObject(data)) case _ => pi.getPrimitiveJavaObject(data) } case li: ListObjectInspector => @@ -334,9 +336,8 @@ private[hive] trait HiveInspectors { // currently, hive doesn't provide the ConstantStructObjectInspector case si: StructObjectInspector => val allRefs = si.getAllStructFieldRefs - new GenericRow( - allRefs.map(r => - unwrap(si.getStructFieldData(data, r), r.getFieldObjectInspector)).toArray) + InternalRow.fromSeq( + allRefs.map(r => unwrap(si.getStructFieldData(data, r), r.getFieldObjectInspector))) } @@ -354,14 +355,17 @@ private[hive] trait HiveInspectors { (o: Any) => HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal) case _: JavaDateObjectInspector => - (o: Any) => DateUtils.toJavaDate(o.asInstanceOf[Int]) + (o: Any) => DateTimeUtils.toJavaDate(o.asInstanceOf[Int]) + + case _: JavaTimestampObjectInspector => + (o: Any) => DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long]) case soi: StandardStructObjectInspector => val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector)) (o: Any) => { if (o != null) { val struct = soi.create() - (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row].toSeq).zipped.foreach { + (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[InternalRow].toSeq).zipped.foreach { (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) } struct @@ -463,13 +467,13 @@ private[hive] trait HiveInspectors { case _: BinaryObjectInspector if x.preferWritable() => getBinaryWritable(a) case _: BinaryObjectInspector => a.asInstanceOf[Array[Byte]] case _: DateObjectInspector if x.preferWritable() => getDateWritable(a) - case _: DateObjectInspector => DateUtils.toJavaDate(a.asInstanceOf[Int]) + case _: DateObjectInspector => DateTimeUtils.toJavaDate(a.asInstanceOf[Int]) case _: TimestampObjectInspector if x.preferWritable() => getTimestampWritable(a) - case _: TimestampObjectInspector => a.asInstanceOf[java.sql.Timestamp] + case _: TimestampObjectInspector => DateTimeUtils.toJavaTimestamp(a.asInstanceOf[Long]) } case x: SettableStructObjectInspector => val fieldRefs = x.getAllStructFieldRefs - val row = a.asInstanceOf[Row] + val row = a.asInstanceOf[InternalRow] // 1. create the pojo (most likely) object val result = x.create() var i = 0 @@ -485,7 +489,7 @@ private[hive] trait HiveInspectors { result case x: StructObjectInspector => val fieldRefs = x.getAllStructFieldRefs - val row = a.asInstanceOf[Row] + val row = a.asInstanceOf[InternalRow] val result = new java.util.ArrayList[AnyRef](fieldRefs.length) var i = 0 while (i < fieldRefs.length) { @@ -512,7 +516,7 @@ private[hive] trait HiveInspectors { } def wrap( - row: Row, + row: InternalRow, inspectors: Seq[ObjectInspector], cache: Array[AnyRef]): Array[AnyRef] = { var i = 0 @@ -727,7 +731,7 @@ private[hive] trait HiveInspectors { TypeInfoFactory.voidTypeInfo, null) private def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].getBytes) private def getIntWritable(value: Any): hadoopIo.IntWritable = if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) @@ -776,7 +780,7 @@ private[hive] trait HiveInspectors { if (value == null) { null } else { - new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) + new hiveIo.TimestampWritable(DateTimeUtils.toJavaTimestamp(value.asInstanceOf[Long])) } private def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 5a4651a887b7c..f35ae96ee0b50 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -143,7 +143,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive provider: String, options: Map[String, String], isExternal: Boolean): Unit = { - val (dbName, tblName) = processDatabaseAndTableName("default", tableName) + val (dbName, tblName) = processDatabaseAndTableName(client.currentDatabase, tableName) val tableProperties = new scala.collection.mutable.HashMap[String, String] tableProperties.put("spark.sql.sources.provider", provider) @@ -302,7 +302,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val partitionColumnDataTypes = partitionSchema.map(_.dataType) val partitions = metastoreRelation.hiveQlPartitions.map { p => val location = p.getLocation - val values = Row.fromSeq(p.getValues.zip(partitionColumnDataTypes).map { + val values = InternalRow.fromSeq(p.getValues.zip(partitionColumnDataTypes).map { case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null) }) ParquetPartition(values, location) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 9544d12c9053c..2de7a99c122fd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -415,13 +415,6 @@ private[hive] object HiveQl { throw new NotImplementedError(s"No parse rules for StructField:\n ${dumpTree(a).toString} ") } - protected def nameExpressions(exprs: Seq[Expression]): Seq[NamedExpression] = { - exprs.zipWithIndex.map { - case (ne: NamedExpression, _) => ne - case (e, i) => Alias(e, s"_c$i")() - } - } - protected def extractDbNameTableName(tableNameParts: Node): (Option[String], String) = { val (db, tableName) = tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match { @@ -942,7 +935,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // (if there is a group by) or a script transformation. val withProject: LogicalPlan = transformation.getOrElse { val selectExpressions = - nameExpressions(select.getChildren.flatMap(selExprNodeToExpr).toSeq) + select.getChildren.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_)).toSeq Seq( groupByClause.map(e => e match { case Token("TOK_GROUPBY", children) => @@ -1307,16 +1300,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C HiveParser.DecimalLiteral) /* Case insensitive matches */ - val ARRAY = "(?i)ARRAY".r - val COALESCE = "(?i)COALESCE".r val COUNT = "(?i)COUNT".r - val AVG = "(?i)AVG".r val SUM = "(?i)SUM".r - val MAX = "(?i)MAX".r - val MIN = "(?i)MIN".r - val UPPER = "(?i)UPPER".r - val LOWER = "(?i)LOWER".r - val RAND = "(?i)RAND".r val AND = "(?i)AND".r val OR = "(?i)OR".r val NOT = "(?i)NOT".r @@ -1330,8 +1315,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val BETWEEN = "(?i)BETWEEN".r val WHEN = "(?i)WHEN".r val CASE = "(?i)CASE".r - val SUBSTR = "(?i)SUBSTR(?:ING)?".r - val SQRT = "(?i)SQRT".r protected def nodeToExpr(node: Node): Expression = node match { /* Attribute References */ @@ -1353,18 +1336,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C UnresolvedStar(Some(name)) /* Aggregate Functions */ - case Token("TOK_FUNCTION", Token(AVG(), Nil) :: arg :: Nil) => Average(nodeToExpr(arg)) - case Token("TOK_FUNCTION", Token(COUNT(), Nil) :: arg :: Nil) => Count(nodeToExpr(arg)) case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => Count(Literal(1)) case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => CountDistinct(args.map(nodeToExpr)) - case Token("TOK_FUNCTION", Token(SUM(), Nil) :: arg :: Nil) => Sum(nodeToExpr(arg)) case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => SumDistinct(nodeToExpr(arg)) - case Token("TOK_FUNCTION", Token(MAX(), Nil) :: arg :: Nil) => Max(nodeToExpr(arg)) - case Token("TOK_FUNCTION", Token(MIN(), Nil) :: arg :: Nil) => Min(nodeToExpr(arg)) - - /* System functions about string operations */ - case Token("TOK_FUNCTION", Token(UPPER(), Nil) :: arg :: Nil) => Upper(nodeToExpr(arg)) - case Token("TOK_FUNCTION", Token(LOWER(), Nil) :: arg :: Nil) => Lower(nodeToExpr(arg)) /* Casts */ case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) => @@ -1414,7 +1388,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("&", left :: right:: Nil) => BitwiseAnd(nodeToExpr(left), nodeToExpr(right)) case Token("|", left :: right:: Nil) => BitwiseOr(nodeToExpr(left), nodeToExpr(right)) case Token("^", left :: right:: Nil) => BitwiseXor(nodeToExpr(left), nodeToExpr(right)) - case Token("TOK_FUNCTION", Token(SQRT(), Nil) :: arg :: Nil) => Sqrt(nodeToExpr(arg)) /* Comparisons */ case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) @@ -1469,17 +1442,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("[", child :: ordinal :: Nil) => UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal)) - /* Other functions */ - case Token("TOK_FUNCTION", Token(ARRAY(), Nil) :: children) => - CreateArray(children.map(nodeToExpr)) - case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand() - case Token("TOK_FUNCTION", Token(RAND(), Nil) :: seed :: Nil) => Rand(seed.toString.toLong) - case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) => - Substring(nodeToExpr(string), nodeToExpr(pos), Literal.create(Integer.MAX_VALUE, IntegerType)) - case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) => - Substring(nodeToExpr(string), nodeToExpr(pos), nodeToExpr(length)) - case Token("TOK_FUNCTION", Token(COALESCE(), Nil) :: list) => Coalesce(list.map(nodeToExpr)) - /* Window Functions */ case Token("TOK_FUNCTION", Token(name, Nil) +: args :+ Token("TOK_WINDOWSPEC", spec)) => val function = UnresolvedWindowFunction(name, args.map(nodeToExpr)) @@ -1676,7 +1638,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C sys.error(s"Couldn't find function $functionName")) val functionClassName = functionInfo.getFunctionClass.getName - (HiveGenericUdtf( + (HiveGenericUDTF( new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr)), attributes) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index fa5409f602444..d08c594151654 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -20,6 +20,11 @@ package org.apache.spark.sql.hive import java.io.{InputStream, OutputStream} import java.rmi.server.UID +/* Implicit conversions */ +import scala.collection.JavaConversions._ +import scala.language.implicitConversions +import scala.reflect.ClassTag + import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.io.{Input, Output} import org.apache.hadoop.conf.Configuration @@ -35,10 +40,6 @@ import org.apache.spark.Logging import org.apache.spark.sql.types.Decimal import org.apache.spark.util.Utils -/* Implicit conversions */ -import scala.collection.JavaConversions._ -import scala.reflect.ClassTag - private[hive] object HiveShim { // Precision and scale to pass for unlimited decimals; these are the same as the precision and // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs) @@ -68,10 +69,10 @@ private[hive] object HiveShim { * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null or empty */ def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { - if (ids != null && ids.size > 0) { + if (ids != null && ids.nonEmpty) { ColumnProjectionUtils.appendReadColumns(conf, ids) } - if (names != null && names.size > 0) { + if (names != null && names.nonEmpty) { appendReadColumnNames(conf, names) } } @@ -197,11 +198,11 @@ private[hive] object HiveShim { } /* - * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. - * Fix it through wrapper. - * */ + * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. + * Fix it through wrapper. + */ implicit def wrapperToFileSinkDesc(w: ShimFileSinkDesc): FileSinkDesc = { - var f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed) + val f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed) f.setCompressCodec(w.compressCodec) f.setCompressType(w.compressType) f.setTableInfo(w.tableInfo) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index c6b65106452bf..452b7f0bcc749 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate -import org.apache.spark.sql.catalyst.expressions.{Row, _} +import org.apache.spark.sql.catalyst.expressions.{InternalRow, _} import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -137,7 +137,7 @@ private[hive] trait HiveStrategies { val partitionLocations = partitions.map(_.getLocation) if (partitionLocations.isEmpty) { - PhysicalRDD(plan.output, sparkContext.emptyRDD[Row]) :: Nil + PhysicalRDD(plan.output, sparkContext.emptyRDD[InternalRow]) :: Nil } else { hiveContext .read.parquet(partitionLocations: _*) @@ -165,7 +165,7 @@ private[hive] trait HiveStrategies { // TODO: Remove this hack for Spark 1.3. case iae: java.lang.IllegalArgumentException if iae.getMessage.contains("Can not create a Path from an empty string") => - PhysicalRDD(plan.output, sparkContext.emptyRDD[Row]) :: Nil + PhysicalRDD(plan.output, sparkContext.emptyRDD[InternalRow]) :: Nil } case _ => Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 334bfccc9d200..b251a9523bed6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.hive -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants._ @@ -30,20 +29,21 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} -import org.apache.spark.{Logging, SerializableWritable} +import org.apache.spark.Logging import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.util.Utils +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.{SerializableConfiguration, Utils} /** * A trait for subclasses that handle table scans. */ private[hive] sealed trait TableReader { - def makeRDDForTable(hiveTable: HiveTable): RDD[Row] + def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] - def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[Row] + def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[InternalRow] } @@ -72,9 +72,9 @@ class HadoopTableReader( // TODO: set aws s3 credentials. private val _broadcastedHiveConf = - sc.sparkContext.broadcast(new SerializableWritable(hiveExtraConf)) + sc.sparkContext.broadcast(new SerializableConfiguration(hiveExtraConf)) - override def makeRDDForTable(hiveTable: HiveTable): RDD[Row] = + override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] = makeRDDForTable( hiveTable, Class.forName( @@ -94,7 +94,7 @@ class HadoopTableReader( def makeRDDForTable( hiveTable: HiveTable, deserializerClass: Class[_ <: Deserializer], - filterOpt: Option[PathFilter]): RDD[Row] = { + filterOpt: Option[PathFilter]): RDD[InternalRow] = { assert(!hiveTable.isPartitioned, """makeRDDForTable() cannot be called on a partitioned table, since input formats may differ across partitions. Use makeRDDForTablePartitions() instead.""") @@ -125,7 +125,7 @@ class HadoopTableReader( deserializedHadoopRDD } - override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[Row] = { + override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[InternalRow] = { val partitionToDeserializer = partitions.map(part => (part, part.getDeserializer.getClass.asInstanceOf[Class[Deserializer]])).toMap makeRDDForPartitionedTable(partitionToDeserializer, filterOpt = None) @@ -144,7 +144,7 @@ class HadoopTableReader( def makeRDDForPartitionedTable( partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]], - filterOpt: Option[PathFilter]): RDD[Row] = { + filterOpt: Option[PathFilter]): RDD[InternalRow] = { // SPARK-5068:get FileStatus and do the filtering locally when the path is not exists def verifyPartitionPath( @@ -243,7 +243,7 @@ class HadoopTableReader( // Even if we don't use any partitions, we still need an empty RDD if (hivePartitionRDDs.size == 0) { - new EmptyRDD[Row](sc.sparkContext) + new EmptyRDD[InternalRow](sc.sparkContext) } else { new UnionRDD(hivePartitionRDDs(0).context, hivePartitionRDDs) } @@ -276,7 +276,7 @@ class HadoopTableReader( val rdd = new HadoopRDD( sc.sparkContext, - _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableWritable[Configuration]]], + _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableConfiguration]], Some(initializeJobConfFunc), inputFormatClass, classOf[Writable], @@ -319,7 +319,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { rawDeser: Deserializer, nonPartitionKeyAttrs: Seq[(Attribute, Int)], mutableRow: MutableRow, - tableDeser: Deserializer): Iterator[Row] = { + tableDeser: Deserializer): Iterator[InternalRow] = { val soi = if (rawDeser.getObjectInspector.equals(tableDeser.getObjectInspector)) { rawDeser.getObjectInspector.asInstanceOf[StructObjectInspector] @@ -357,16 +357,16 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) case oi: HiveVarcharObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => - row.setString(ordinal, oi.getPrimitiveJavaObject(value).getValue) + row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) case oi: HiveDecimalObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, HiveShim.toCatalystDecimal(oi, value)) case oi: TimestampObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => - row.update(ordinal, oi.getPrimitiveJavaObject(value).clone()) + row.setLong(ordinal, DateTimeUtils.fromJavaTimestamp(oi.getPrimitiveJavaObject(value))) case oi: DateObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => - row.update(ordinal, DateUtils.fromJavaDate(oi.getPrimitiveJavaObject(value))) + row.setInt(ordinal, DateTimeUtils.fromJavaDate(oi.getPrimitiveJavaObject(value))) case oi: BinaryObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, oi.getPrimitiveJavaObject(value)) @@ -391,7 +391,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { i += 1 } - mutableRow: Row + mutableRow: InternalRow } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 99aa0f1ded3f8..cbd2bf6b5eede 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -20,6 +20,9 @@ package org.apache.spark.sql.hive.client import java.io.{BufferedReader, InputStreamReader, File, PrintStream} import java.net.URI import java.util.{ArrayList => JArrayList, Map => JMap, List => JList, Set => JSet} +import javax.annotation.concurrent.GuardedBy + +import org.apache.spark.util.CircularBuffer import scala.collection.JavaConversions._ import scala.language.reflectiveCalls @@ -27,7 +30,7 @@ import scala.language.reflectiveCalls import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.metastore.api.Database import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.TableType +import org.apache.hadoop.hive.metastore.{TableType => HTableType} import org.apache.hadoop.hive.metastore.api import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.ql.metadata @@ -54,49 +57,43 @@ import org.apache.spark.sql.execution.QueryExecutionException * @param version the version of hive used when pick function calls that are not compatible. * @param config a collection of configuration options that will be added to the hive conf before * opening the hive client. + * @param initClassLoader the classloader used when creating the `state` field of + * this ClientWrapper. */ private[hive] class ClientWrapper( version: HiveVersion, - config: Map[String, String]) + config: Map[String, String], + initClassLoader: ClassLoader) extends ClientInterface - with Logging - with ReflectionMagic { + with Logging { // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. - private val outputBuffer = new java.io.OutputStream { - var pos: Int = 0 - var buffer = new Array[Int](10240) - def write(i: Int): Unit = { - buffer(pos) = i - pos = (pos + 1) % buffer.size - } - - override def toString: String = { - val (end, start) = buffer.splitAt(pos) - val input = new java.io.InputStream { - val iterator = (start ++ end).iterator - - def read(): Int = if (iterator.hasNext) iterator.next() else -1 - } - val reader = new BufferedReader(new InputStreamReader(input)) - val stringBuilder = new StringBuilder - var line = reader.readLine() - while(line != null) { - stringBuilder.append(line) - stringBuilder.append("\n") - line = reader.readLine() - } - stringBuilder.toString() - } + private val outputBuffer = new CircularBuffer() + + private val shim = version match { + case hive.v12 => new Shim_v0_12() + case hive.v13 => new Shim_v0_13() + case hive.v14 => new Shim_v0_14() + case hive.v1_0 => new Shim_v1_0() + case hive.v1_1 => new Shim_v1_1() + case hive.v1_2 => new Shim_v1_2() } + // Create an internal session state for this ClientWrapper. val state = { val original = Thread.currentThread().getContextClassLoader - Thread.currentThread().setContextClassLoader(getClass.getClassLoader) + // Switch to the initClassLoader. + Thread.currentThread().setContextClassLoader(initClassLoader) val ret = try { val oldState = SessionState.get() if (oldState == null) { val initialConf = new HiveConf(classOf[SessionState]) + // HiveConf is a Hadoop Configuration, which has a field of classLoader and + // the initial value will be the current thread's context class loader + // (i.e. initClassLoader at here). + // We call initialConf.setClassLoader(initClassLoader) at here to make + // this action explicit. + initialConf.setClassLoader(initClassLoader) config.foreach { case (k, v) => logDebug(s"Hive Config: $k=$v") initialConf.set(k, v) @@ -119,23 +116,70 @@ private[hive] class ClientWrapper( def conf: HiveConf = SessionState.get().getConf // TODO: should be a def?s - private val client = Hive.get(conf) + // When we create this val client, the HiveConf of it (conf) is the one associated with state. + @GuardedBy("this") + private var client = Hive.get(conf) + + // We use hive's conf for compatibility. + private val retryLimit = conf.getIntVar(HiveConf.ConfVars.METASTORETHRIFTFAILURERETRIES) + private val retryDelayMillis = shim.getMetastoreClientConnectRetryDelayMillis(conf) + + /** + * Runs `f` with multiple retries in case the hive metastore is temporarily unreachable. + */ + private def retryLocked[A](f: => A): A = synchronized { + // Hive sometimes retries internally, so set a deadline to avoid compounding delays. + val deadline = System.nanoTime + (retryLimit * retryDelayMillis * 1e6).toLong + var numTries = 0 + var caughtException: Exception = null + do { + numTries += 1 + try { + return f + } catch { + case e: Exception if causedByThrift(e) => + caughtException = e + logWarning( + "HiveClientWrapper got thrift exception, destroying client and retrying " + + s"(${retryLimit - numTries} tries remaining)", e) + Thread.sleep(retryDelayMillis) + try { + client = Hive.get(state.getConf, true) + } catch { + case e: Exception if causedByThrift(e) => + logWarning("Failed to refresh hive client, will retry.", e) + } + } + } while (numTries <= retryLimit && System.nanoTime < deadline) + if (System.nanoTime > deadline) { + logWarning("Deadline exceeded") + } + throw caughtException + } + + private def causedByThrift(e: Throwable): Boolean = { + var target = e + while (target != null) { + val msg = target.getMessage() + if (msg != null && msg.matches("(?s).*(TApplication|TProtocol|TTransport)Exception.*")) { + return true + } + target = target.getCause() + } + false + } /** * Runs `f` with ThreadLocal session state and classloaders configured for this version of hive. */ - private def withHiveState[A](f: => A): A = synchronized { + private def withHiveState[A](f: => A): A = retryLocked { val original = Thread.currentThread().getContextClassLoader - Thread.currentThread().setContextClassLoader(getClass.getClassLoader) + // Set the thread local metastore client to the client associated with this ClientWrapper. Hive.set(client) - version match { - case hive.v12 => - classOf[SessionState] - .callStatic[SessionState, SessionState]("start", state) - case hive.v13 => - classOf[SessionState] - .callStatic[SessionState, SessionState]("setCurrentSessionState", state) - } + // setCurrentSessionState will use the classLoader associated + // with the HiveConf in `state` to override the context class loader of the current + // thread. + shim.setCurrentSessionState(state) val ret = try f finally { Thread.currentThread().setContextClassLoader(original) } @@ -193,15 +237,12 @@ private[hive] class ClientWrapper( properties = h.getParameters.toMap, serdeProperties = h.getTTable.getSd.getSerdeInfo.getParameters.toMap, tableType = h.getTableType match { - case TableType.MANAGED_TABLE => ManagedTable - case TableType.EXTERNAL_TABLE => ExternalTable - case TableType.VIRTUAL_VIEW => VirtualView - case TableType.INDEX_TABLE => IndexTable - }, - location = version match { - case hive.v12 => Option(h.call[URI]("getDataLocation")).map(_.toString) - case hive.v13 => Option(h.call[Path]("getDataLocation")).map(_.toString) + case HTableType.MANAGED_TABLE => ManagedTable + case HTableType.EXTERNAL_TABLE => ExternalTable + case HTableType.VIRTUAL_VIEW => VirtualView + case HTableType.INDEX_TABLE => IndexTable }, + location = shim.getDataLocation(h), inputFormat = Option(h.getInputFormatClass).map(_.getName), outputFormat = Option(h.getOutputFormatClass).map(_.getName), serde = Option(h.getSerializationLib), @@ -231,14 +272,7 @@ private[hive] class ClientWrapper( // set create time qlTable.setCreateTime((System.currentTimeMillis() / 1000).asInstanceOf[Int]) - version match { - case hive.v12 => - table.location.map(new URI(_)).foreach(u => qlTable.call[URI, Unit]("setDataLocation", u)) - case hive.v13 => - table.location - .map(new org.apache.hadoop.fs.Path(_)) - .foreach(qlTable.call[Path, Unit]("setDataLocation", _)) - } + table.location.foreach { loc => shim.setDataLocation(qlTable, loc) } table.inputFormat.map(toInputFormat).foreach(qlTable.setInputFormatClass) table.outputFormat.map(toOutputFormat).foreach(qlTable.setOutputFormatClass) table.serde.foreach(qlTable.setSerializationLib) @@ -279,13 +313,7 @@ private[hive] class ClientWrapper( override def getAllPartitions(hTable: HiveTable): Seq[HivePartition] = withHiveState { val qlTable = toQlTable(hTable) - val qlPartitions = version match { - case hive.v12 => - client.call[metadata.Table, JSet[metadata.Partition]]("getAllPartitionsForPruner", qlTable) - case hive.v13 => - client.call[metadata.Table, JSet[metadata.Partition]]("getAllPartitionsOf", qlTable) - } - qlPartitions.toSeq.map(toHivePartition) + shim.getAllPartitions(client, qlTable).map(toHivePartition) } override def listTables(dbName: String): Seq[String] = withHiveState { @@ -315,15 +343,7 @@ private[hive] class ClientWrapper( val tokens: Array[String] = cmd_trimmed.split("\\s+") // The remainder of the command. val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() - val proc: CommandProcessor = version match { - case hive.v12 => - classOf[CommandProcessorFactory] - .callStatic[String, HiveConf, CommandProcessor]("get", tokens(0), conf) - case hive.v13 => - classOf[CommandProcessorFactory] - .callStatic[Array[String], HiveConf, CommandProcessor]("get", Array(tokens(0)), conf) - } - + val proc = shim.getCommandProcessor(tokens(0), conf) proc match { case driver: Driver => val response: CommandProcessorResponse = driver.run(cmd) @@ -334,21 +354,7 @@ private[hive] class ClientWrapper( } driver.setMaxRows(maxRows) - val results = version match { - case hive.v12 => - val res = new JArrayList[String] - driver.call[JArrayList[String], Boolean]("getResults", res) - res.toSeq - case hive.v13 => - val res = new JArrayList[Object] - driver.call[JList[Object], Boolean]("getResults", res) - res.map { r => - r match { - case s: String => s - case a: Array[Object] => a(0).asInstanceOf[String] - } - } - } + val results = shim.getDriverResults(driver) driver.close() results @@ -382,8 +388,8 @@ private[hive] class ClientWrapper( holdDDLTime: Boolean, inheritTableSpecs: Boolean, isSkewedStoreAsSubdir: Boolean): Unit = withHiveState { - - client.loadPartition( + shim.loadPartition( + client, new Path(loadPath), // TODO: Use URI tableName, partSpec, @@ -398,7 +404,8 @@ private[hive] class ClientWrapper( tableName: String, replace: Boolean, holdDDLTime: Boolean): Unit = withHiveState { - client.loadTable( + shim.loadTable( + client, new Path(loadPath), tableName, replace, @@ -413,7 +420,8 @@ private[hive] class ClientWrapper( numDP: Int, holdDDLTime: Boolean, listBucketingEnabled: Boolean): Unit = withHiveState { - client.loadDynamicPartitions( + shim.loadDynamicPartitions( + client, new Path(loadPath), tableName, partSpec, @@ -428,7 +436,7 @@ private[hive] class ClientWrapper( logDebug(s"Deleting table $t") val table = client.getTable("default", t) client.getIndexes("default", t, 255).foreach { index => - client.dropIndex("default", t, index.getIndexName, true) + shim.dropIndex(client, "default", t, index.getIndexName) } if (!table.isIndexTable) { client.dropTable("default", t) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala new file mode 100644 index 0000000000000..1fa9d278e2a57 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -0,0 +1,449 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.client + +import java.lang.{Boolean => JBoolean, Integer => JInteger, Long => JLong} +import java.lang.reflect.{Method, Modifier} +import java.net.URI +import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, Set => JSet} +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} +import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.session.SessionState + +/** + * A shim that defines the interface between ClientWrapper and the underlying Hive library used to + * talk to the metastore. Each Hive version has its own implementation of this class, defining + * version-specific version of needed functions. + * + * The guideline for writing shims is: + * - always extend from the previous version unless really not possible + * - initialize methods in lazy vals, both for quicker access for multiple invocations, and to + * avoid runtime errors due to the above guideline. + */ +private[client] sealed abstract class Shim { + + /** + * Set the current SessionState to the given SessionState. Also, set the context classloader of + * the current thread to the one set in the HiveConf of this given `state`. + * @param state + */ + def setCurrentSessionState(state: SessionState): Unit + + /** + * This shim is necessary because the return type is different on different versions of Hive. + * All parameters are the same, though. + */ + def getDataLocation(table: Table): Option[String] + + def setDataLocation(table: Table, loc: String): Unit + + def getAllPartitions(hive: Hive, table: Table): Seq[Partition] + + def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor + + def getDriverResults(driver: Driver): Seq[String] + + def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long + + def loadPartition( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + holdDDLTime: Boolean, + inheritTableSpecs: Boolean, + isSkewedStoreAsSubdir: Boolean): Unit + + def loadTable( + hive: Hive, + loadPath: Path, + tableName: String, + replace: Boolean, + holdDDLTime: Boolean): Unit + + def loadDynamicPartitions( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + numDP: Int, + holdDDLTime: Boolean, + listBucketingEnabled: Boolean): Unit + + def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit + + protected def findStaticMethod(klass: Class[_], name: String, args: Class[_]*): Method = { + val method = findMethod(klass, name, args: _*) + require(Modifier.isStatic(method.getModifiers()), + s"Method $name of class $klass is not static.") + method + } + + protected def findMethod(klass: Class[_], name: String, args: Class[_]*): Method = { + klass.getMethod(name, args: _*) + } + +} + +private[client] class Shim_v0_12 extends Shim { + + private lazy val startMethod = + findStaticMethod( + classOf[SessionState], + "start", + classOf[SessionState]) + private lazy val getDataLocationMethod = findMethod(classOf[Table], "getDataLocation") + private lazy val setDataLocationMethod = + findMethod( + classOf[Table], + "setDataLocation", + classOf[URI]) + private lazy val getAllPartitionsMethod = + findMethod( + classOf[Hive], + "getAllPartitionsForPruner", + classOf[Table]) + private lazy val getCommandProcessorMethod = + findStaticMethod( + classOf[CommandProcessorFactory], + "get", + classOf[String], + classOf[HiveConf]) + private lazy val getDriverResultsMethod = + findMethod( + classOf[Driver], + "getResults", + classOf[JArrayList[String]]) + private lazy val loadPartitionMethod = + findMethod( + classOf[Hive], + "loadPartition", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val loadTableMethod = + findMethod( + classOf[Hive], + "loadTable", + classOf[Path], + classOf[String], + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val loadDynamicPartitionsMethod = + findMethod( + classOf[Hive], + "loadDynamicPartitions", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JInteger.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val dropIndexMethod = + findMethod( + classOf[Hive], + "dropIndex", + classOf[String], + classOf[String], + classOf[String], + JBoolean.TYPE) + + override def setCurrentSessionState(state: SessionState): Unit = { + // Starting from Hive 0.13, setCurrentSessionState will internally override + // the context class loader of the current thread by the class loader set in + // the conf of the SessionState. So, for this Hive 0.12 shim, we add the same + // behavior and make shim.setCurrentSessionState of all Hive versions have the + // consistent behavior. + Thread.currentThread().setContextClassLoader(state.getConf.getClassLoader) + startMethod.invoke(null, state) + } + + override def getDataLocation(table: Table): Option[String] = + Option(getDataLocationMethod.invoke(table)).map(_.toString()) + + override def setDataLocation(table: Table, loc: String): Unit = + setDataLocationMethod.invoke(table, new URI(loc)) + + override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] = + getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq + + override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor = + getCommandProcessorMethod.invoke(null, token, conf).asInstanceOf[CommandProcessor] + + override def getDriverResults(driver: Driver): Seq[String] = { + val res = new JArrayList[String]() + getDriverResultsMethod.invoke(driver, res) + res.toSeq + } + + override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { + conf.getIntVar(HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY) * 1000 + } + + override def loadPartition( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + holdDDLTime: Boolean, + inheritTableSpecs: Boolean, + isSkewedStoreAsSubdir: Boolean): Unit = { + loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + holdDDLTime: JBoolean, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean) + } + + override def loadTable( + hive: Hive, + loadPath: Path, + tableName: String, + replace: Boolean, + holdDDLTime: Boolean): Unit = { + loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime: JBoolean) + } + + override def loadDynamicPartitions( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + numDP: Int, + holdDDLTime: Boolean, + listBucketingEnabled: Boolean): Unit = { + loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean) + } + + override def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit = { + dropIndexMethod.invoke(hive, dbName, tableName, indexName, true: JBoolean) + } + +} + +private[client] class Shim_v0_13 extends Shim_v0_12 { + + private lazy val setCurrentSessionStateMethod = + findStaticMethod( + classOf[SessionState], + "setCurrentSessionState", + classOf[SessionState]) + private lazy val setDataLocationMethod = + findMethod( + classOf[Table], + "setDataLocation", + classOf[Path]) + private lazy val getAllPartitionsMethod = + findMethod( + classOf[Hive], + "getAllPartitionsOf", + classOf[Table]) + private lazy val getCommandProcessorMethod = + findStaticMethod( + classOf[CommandProcessorFactory], + "get", + classOf[Array[String]], + classOf[HiveConf]) + private lazy val getDriverResultsMethod = + findMethod( + classOf[Driver], + "getResults", + classOf[JList[Object]]) + + override def setCurrentSessionState(state: SessionState): Unit = + setCurrentSessionStateMethod.invoke(null, state) + + override def setDataLocation(table: Table, loc: String): Unit = + setDataLocationMethod.invoke(table, new Path(loc)) + + override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] = + getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq + + override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor = + getCommandProcessorMethod.invoke(null, Array(token), conf).asInstanceOf[CommandProcessor] + + override def getDriverResults(driver: Driver): Seq[String] = { + val res = new JArrayList[Object]() + getDriverResultsMethod.invoke(driver, res) + res.map { r => + r match { + case s: String => s + case a: Array[Object] => a(0).asInstanceOf[String] + } + } + } + +} + +private[client] class Shim_v0_14 extends Shim_v0_13 { + + private lazy val loadPartitionMethod = + findMethod( + classOf[Hive], + "loadPartition", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val loadTableMethod = + findMethod( + classOf[Hive], + "loadTable", + classOf[Path], + classOf[String], + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val loadDynamicPartitionsMethod = + findMethod( + classOf[Hive], + "loadDynamicPartitions", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JInteger.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val getTimeVarMethod = + findMethod( + classOf[HiveConf], + "getTimeVar", + classOf[HiveConf.ConfVars], + classOf[TimeUnit]) + + override def loadPartition( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + holdDDLTime: Boolean, + inheritTableSpecs: Boolean, + isSkewedStoreAsSubdir: Boolean): Unit = { + loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + holdDDLTime: JBoolean, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, + JBoolean.TRUE, JBoolean.FALSE) + } + + override def loadTable( + hive: Hive, + loadPath: Path, + tableName: String, + replace: Boolean, + holdDDLTime: Boolean): Unit = { + loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime: JBoolean, + JBoolean.TRUE, JBoolean.FALSE, JBoolean.FALSE) + } + + override def loadDynamicPartitions( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + numDP: Int, + holdDDLTime: Boolean, + listBucketingEnabled: Boolean): Unit = { + loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean, JBoolean.FALSE) + } + + override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { + getTimeVarMethod.invoke( + conf, + HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY, + TimeUnit.MILLISECONDS).asInstanceOf[Long] + } +} + +private[client] class Shim_v1_0 extends Shim_v0_14 { + +} + +private[client] class Shim_v1_1 extends Shim_v1_0 { + + private lazy val dropIndexMethod = + findMethod( + classOf[Hive], + "dropIndex", + classOf[String], + classOf[String], + classOf[String], + JBoolean.TYPE, + JBoolean.TYPE) + + override def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit = { + dropIndexMethod.invoke(hive, dbName, tableName, indexName, true: JBoolean, true: JBoolean) + } + +} + +private[client] class Shim_v1_2 extends Shim_v1_1 { + + private lazy val loadDynamicPartitionsMethod = + findMethod( + classOf[Hive], + "loadDynamicPartitions", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JInteger.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JLong.TYPE) + + override def loadDynamicPartitions( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + numDP: Int, + holdDDLTime: Boolean, + listBucketingEnabled: Boolean): Unit = { + loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean, JBoolean.FALSE, + 0: JLong) + } + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 16851fdd71a98..3d609a66f3664 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.client import java.io.File +import java.lang.reflect.InvocationTargetException import java.net.{URL, URLClassLoader} import java.util @@ -28,6 +29,7 @@ import org.apache.commons.io.{FileUtils, IOUtils} import org.apache.spark.Logging import org.apache.spark.deploy.SparkSubmitUtils +import org.apache.spark.util.Utils import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.hive.HiveContext @@ -39,38 +41,41 @@ private[hive] object IsolatedClientLoader { */ def forVersion( version: String, - config: Map[String, String] = Map.empty): IsolatedClientLoader = synchronized { + config: Map[String, String] = Map.empty, + ivyPath: Option[String] = None): IsolatedClientLoader = synchronized { val resolvedVersion = hiveVersion(version) - val files = resolvedVersions.getOrElseUpdate(resolvedVersion, downloadVersion(resolvedVersion)) + val files = resolvedVersions.getOrElseUpdate(resolvedVersion, + downloadVersion(resolvedVersion, ivyPath)) new IsolatedClientLoader(hiveVersion(version), files, config) } def hiveVersion(version: String): HiveVersion = version match { case "12" | "0.12" | "0.12.0" => hive.v12 case "13" | "0.13" | "0.13.0" | "0.13.1" => hive.v13 + case "14" | "0.14" | "0.14.0" => hive.v14 + case "1.0" | "1.0.0" => hive.v1_0 + case "1.1" | "1.1.0" => hive.v1_1 + case "1.2" | "1.2.0" => hive.v1_2 } - private def downloadVersion(version: HiveVersion): Seq[URL] = { - val hiveArtifacts = - (Seq("hive-metastore", "hive-exec", "hive-common", "hive-serde") ++ - (if (version.hasBuiltinsJar) "hive-builtins" :: Nil else Nil)) - .map(a => s"org.apache.hive:$a:${version.fullVersion}") :+ - "com.google.guava:guava:14.0.1" :+ - "org.apache.hadoop:hadoop-client:2.4.0" + private def downloadVersion(version: HiveVersion, ivyPath: Option[String]): Seq[URL] = { + val hiveArtifacts = version.extraDeps ++ + Seq("hive-metastore", "hive-exec", "hive-common", "hive-serde") + .map(a => s"org.apache.hive:$a:${version.fullVersion}") ++ + Seq("com.google.guava:guava:14.0.1", + "org.apache.hadoop:hadoop-client:2.4.0") val classpath = quietly { SparkSubmitUtils.resolveMavenCoordinates( hiveArtifacts.mkString(","), Some("http://www.datanucleus.org/downloads/maven2"), - None) + ivyPath, + exclusions = version.exclusions) } val allFiles = classpath.split(",").map(new File(_)).toSet // TODO: Remove copy logic. - val tempDir = File.createTempFile("hive", "v" + version.toString) - tempDir.delete() - tempDir.mkdir() - + val tempDir = Utils.createTempDir(namePrefix = s"hive-${version}") allFiles.foreach(f => FileUtils.copyFileToDirectory(f, tempDir)) tempDir.listFiles().map(_.toURL) } @@ -95,9 +100,8 @@ private[hive] object IsolatedClientLoader { * @param config A set of options that will be added to the HiveConf of the constructed client. * @param isolationOn When true, custom versions of barrier classes will be constructed. Must be * true unless loading the version of hive that is on Sparks classloader. - * @param rootClassLoader The system root classloader. - * @param baseClassLoader The spark classloader that is used to load shared classes. Must not know - * about Hive classes. + * @param rootClassLoader The system root classloader. Must not know about Hive classes. + * @param baseClassLoader The spark classloader that is used to load shared classes. */ private[hive] class IsolatedClientLoader( val version: HiveVersion, @@ -110,8 +114,8 @@ private[hive] class IsolatedClientLoader( val barrierPrefixes: Seq[String] = Seq.empty) extends Logging { - // Check to make sure that the base classloader does not know about Hive. - assert(Try(baseClassLoader.loadClass("org.apache.hive.HiveConf")).isFailure) + // Check to make sure that the root classloader does not know about Hive. + assert(Try(rootClassLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf")).isFailure) /** All jars used by the hive specific classloader. */ protected def allJars = execJars.toArray @@ -129,7 +133,7 @@ private[hive] class IsolatedClientLoader( /** True if `name` refers to a spark class that must see specific version of Hive. */ protected def isBarrierClass(name: String): Boolean = name.startsWith(classOf[ClientWrapper].getName) || - name.startsWith(classOf[ReflectionMagic].getName) || + name.startsWith(classOf[Shim].getName) || barrierPrefixes.exists(name.startsWith) protected def classToPath(name: String): String = @@ -145,6 +149,7 @@ private[hive] class IsolatedClientLoader( def doLoadClass(name: String, resolve: Boolean): Class[_] = { val classFileName = name.replaceAll("\\.", "/") + ".class" if (isBarrierClass(name) && isolationOn) { + // For barrier classes, we construct a new copy of the class. val bytes = IOUtils.toByteArray(baseClassLoader.getResourceAsStream(classFileName)) logDebug(s"custom defining: $name - ${util.Arrays.hashCode(bytes)}") defineClass(name, bytes, 0, bytes.length) @@ -152,6 +157,7 @@ private[hive] class IsolatedClientLoader( logDebug(s"hive class: $name - ${getResource(classToPath(name))}") super.loadClass(name, resolve) } else { + // For shared classes, we delegate to baseClassLoader. logDebug(s"shared class: $name") baseClassLoader.loadClass(name) } @@ -167,14 +173,19 @@ private[hive] class IsolatedClientLoader( classLoader .loadClass(classOf[ClientWrapper].getName) .getConstructors.head - .newInstance(version, config) + .newInstance(version, config, classLoader) .asInstanceOf[ClientInterface] } catch { - case ReflectionException(cnf: NoClassDefFoundError) => - throw new ClassNotFoundException( - s"$cnf when creating Hive client using classpath: ${execJars.mkString(", ")}\n" + - "Please make sure that jars for your version of hive and hadoop are included in the " + - s"paths passed to ${HiveContext.HIVE_METASTORE_JARS}.") + case e: InvocationTargetException => + if (e.getCause().isInstanceOf[NoClassDefFoundError]) { + val cnf = e.getCause().asInstanceOf[NoClassDefFoundError] + throw new ClassNotFoundException( + s"$cnf when creating Hive client using classpath: ${execJars.mkString(", ")}\n" + + "Please make sure that jars for your version of hive and hadoop are included in the " + + s"paths passed to ${HiveContext.HIVE_METASTORE_JARS}.") + } else { + throw e + } } finally { Thread.currentThread.setContextClassLoader(baseClassLoader) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala deleted file mode 100644 index 4d053ae42c2ea..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala +++ /dev/null @@ -1,208 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.client - -import scala.reflect._ - -/** Unwraps reflection exceptions. */ -private[client] object ReflectionException { - def unapply(a: Throwable): Option[Throwable] = a match { - case ite: java.lang.reflect.InvocationTargetException => Option(ite.getCause) - case _ => None - } -} - -/** - * Provides implicit functions on any object for calling methods reflectively. - */ -private[client] trait ReflectionMagic { - /** code for InstanceMagic - println( - (1 to 22).map { n => - def repeat(str: String => String) = (1 to n).map(i => str(i.toString)).mkString(", ") - val types = repeat(n => s"A$n <: AnyRef : ClassTag") - val inArgs = repeat(n => s"a$n: A$n") - val erasure = repeat(n => s"classTag[A$n].erasure") - val outArgs = repeat(n => s"a$n") - s"""|def call[$types, R](name: String, $inArgs): R = { - | clazz.getMethod(name, $erasure).invoke(a, $outArgs).asInstanceOf[R] - |}""".stripMargin - }.mkString("\n") - ) - */ - - // scalastyle:off - protected implicit class InstanceMagic(a: Any) { - private val clazz = a.getClass - - def call[R](name: String): R = { - clazz.getMethod(name).invoke(a).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, R](name: String, a1: A1): R = { - clazz.getMethod(name, classTag[A1].erasure).invoke(a, a1).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure).invoke(a, a1, a2).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure).invoke(a, a1, a2, a3).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure).invoke(a, a1, a2, a3, a4).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure).invoke(a, a1, a2, a3, a4, a5).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure).invoke(a, a1, a2, a3, a4, a5, a6).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, A21 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20, a21: A21): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure, classTag[A21].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, A21 <: AnyRef : ClassTag, A22 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20, a21: A21, a22: A22): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure, classTag[A21].erasure, classTag[A22].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21, a22).asInstanceOf[R] - } - } - - /** code for StaticMagic - println( - (1 to 22).map { n => - def repeat(str: String => String) = (1 to n).map(i => str(i.toString)).mkString(", ") - val types = repeat(n => s"A$n <: AnyRef : ClassTag") - val inArgs = repeat(n => s"a$n: A$n") - val erasure = repeat(n => s"classTag[A$n].erasure") - val outArgs = repeat(n => s"a$n") - s"""|def callStatic[$types, R](name: String, $inArgs): R = { - | c.getDeclaredMethod(name, $erasure).invoke(c, $outArgs).asInstanceOf[R] - |}""".stripMargin - }.mkString("\n") - ) - */ - - protected implicit class StaticMagic(c: Class[_]) { - def callStatic[A1 <: AnyRef : ClassTag, R](name: String, a1: A1): R = { - c.getDeclaredMethod(name, classTag[A1].erasure).invoke(c, a1).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure).invoke(c, a1, a2).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure).invoke(c, a1, a2, a3).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure).invoke(c, a1, a2, a3, a4).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure).invoke(c, a1, a2, a3, a4, a5).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure).invoke(c, a1, a2, a3, a4, a5, a6).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, A21 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20, a21: A21): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure, classTag[A21].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, A21 <: AnyRef : ClassTag, A22 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20, a21: A21, a22: A22): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure, classTag[A21].erasure, classTag[A22].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21, a22).asInstanceOf[R] - } - } - // scalastyle:on -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 410d9881ac214..b48082fe4b363 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -19,15 +19,50 @@ package org.apache.spark.sql.hive /** Support for interacting with different versions of the HiveMetastoreClient */ package object client { - private[client] abstract class HiveVersion(val fullVersion: String, val hasBuiltinsJar: Boolean) + private[client] abstract class HiveVersion( + val fullVersion: String, + val extraDeps: Seq[String] = Nil, + val exclusions: Seq[String] = Nil) // scalastyle:off private[client] object hive { - case object v10 extends HiveVersion("0.10.0", true) - case object v11 extends HiveVersion("0.11.0", false) - case object v12 extends HiveVersion("0.12.0", false) - case object v13 extends HiveVersion("0.13.1", false) + case object v12 extends HiveVersion("0.12.0") + case object v13 extends HiveVersion("0.13.1") + + // Hive 0.14 depends on calcite 0.9.2-incubating-SNAPSHOT which does not exist in + // maven central anymore, so override those with a version that exists. + // + // The other excluded dependencies are also nowhere to be found, so exclude them explicitly. If + // they're needed by the metastore client, users will have to dig them out of somewhere and use + // configuration to point Spark at the correct jars. + case object v14 extends HiveVersion("0.14.0", + extraDeps = Seq("org.apache.calcite:calcite-core:1.3.0-incubating", + "org.apache.calcite:calcite-avatica:1.3.0-incubating"), + exclusions = Seq("org.pentaho:pentaho-aggdesigner-algorithm")) + + case object v1_0 extends HiveVersion("1.0.0", + exclusions = Seq("eigenbase:eigenbase-properties", + "org.pentaho:pentaho-aggdesigner-algorithm", + "net.hydromatic:linq4j", + "net.hydromatic:quidem")) + + // The curator dependency was added to the exclusions here because it seems to confuse the ivy + // library. org.apache.curator:curator is a pom dependency but ivy tries to find the jar for it, + // and fails. + case object v1_1 extends HiveVersion("1.1.0", + exclusions = Seq("eigenbase:eigenbase-properties", + "org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm", + "net.hydromatic:linq4j", + "net.hydromatic:quidem")) + + case object v1_2 extends HiveVersion("1.2.0", + exclusions = Seq("eigenbase:eigenbase-properties", + "org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm", + "net.hydromatic:linq4j", + "net.hydromatic:quidem")) } // scalastyle:on -} \ No newline at end of file +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 7d3ec12c4eb05..84358cb73c9e3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.{AnalysisException, SQLContext} -import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.hive.client.{HiveTable, HiveColumn} -import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation, HiveMetastoreTypes} +import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes, MetastoreRelation} +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} /** * Create table and insert the query result into it. @@ -45,22 +43,30 @@ case class CreateTableAsSelect( override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] lazy val metastoreRelation: MetastoreRelation = { - import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat + import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.TextInputFormat - val withSchema = + val withFormat = tableDesc.copy( - schema = - query.output.map(c => - HiveColumn(c.name, HiveMetastoreTypes.toMetastoreType(c.dataType), null)), inputFormat = tableDesc.inputFormat.orElse(Some(classOf[TextInputFormat].getName)), outputFormat = tableDesc.outputFormat .orElse(Some(classOf[HiveIgnoreKeyTextOutputFormat[Text, Text]].getName)), serde = tableDesc.serde.orElse(Some(classOf[LazySimpleSerDe].getName()))) + + val withSchema = if (withFormat.schema.isEmpty) { + // Hive doesn't support specifying the column list for target table in CTAS + // However we don't think SparkSQL should follow that. + tableDesc.copy(schema = + query.output.map(c => + HiveColumn(c.name, HiveMetastoreTypes.toMetastoreType(c.dataType), null))) + } else { + withFormat + } + hiveContext.catalog.client.createTable(withSchema) // Get the Metastore Relation diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala index 6fce69b58b85e..5f0ed5393d191 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala @@ -21,12 +21,10 @@ import scala.collection.JavaConversions._ import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} -import org.apache.spark.sql.execution.{SparkPlan, RunnableCommand} -import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation} -import org.apache.spark.sql.hive.HiveShim -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.hive.MetastoreRelation +import org.apache.spark.sql.{Row, SQLContext} /** * Implementation for "describe [extended] table". diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala index 60a9bb630d0d9..41b645b2c9c93 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala @@ -1,34 +1,34 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.execution - -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row} -import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.types.StringType - -private[hive] -case class HiveNativeCommand(sql: String) extends RunnableCommand { - - override def output: Seq[AttributeReference] = - Seq(AttributeReference("result", StringType, nullable = false)()) - - override def run(sqlContext: SQLContext): Seq[Row] = - sqlContext.asInstanceOf[HiveContext].runSqlHive(sql).map(Row(_)) -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.{Row, SQLContext} + +private[hive] +case class HiveNativeCommand(sql: String) extends RunnableCommand { + + override def output: Seq[AttributeReference] = + Seq(AttributeReference("result", StringType, nullable = false)()) + + override def run(sqlContext: SQLContext): Seq[Row] = + sqlContext.asInstanceOf[HiveContext].runSqlHive(sql).map(Row(_)) +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 11ee5503146b9..f4c8c9a7e8a68 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -123,13 +123,13 @@ case class HiveTableScan( // Only partitioned values are needed here, since the predicate has already been bound to // partition key attribute references. - val row = new GenericRow(castedValues.toArray) + val row = InternalRow.fromSeq(castedValues) shouldKeep.eval(row).asInstanceOf[Boolean] } } } - protected override def doExecute(): RDD[Row] = if (!relation.hiveQlTable.isPartitioned) { + protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) { hadoopReader.makeRDDForTable(relation.hiveQlTable) } else { hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index eeb472602be3c..05f425f2b65f3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -30,13 +30,15 @@ import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.{Attribute, InternalRow} import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive._ -import org.apache.spark.{SerializableWritable, SparkException, TaskContext} +import org.apache.spark.{SparkException, TaskContext} import scala.collection.JavaConversions._ +import org.apache.spark.util.SerializableJobConf private[hive] case class InsertIntoHiveTable( @@ -60,10 +62,10 @@ case class InsertIntoHiveTable( def output: Seq[Attribute] = child.output def saveAsHiveFile( - rdd: RDD[Row], + rdd: RDD[InternalRow], valueClass: Class[_], fileSinkConf: FileSinkDesc, - conf: SerializableWritable[JobConf], + conf: SerializableJobConf, writerContainer: SparkHiveWriterContainer): Unit = { assert(valueClass != null, "Output value class not set") conf.value.setOutputValueClass(valueClass) @@ -82,7 +84,7 @@ case class InsertIntoHiveTable( writerContainer.commitJob() // Note that this function is executed on executor side - def writeToFile(context: TaskContext, iterator: Iterator[Row]): Unit = { + def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = { val serializer = newSerializer(fileSinkConf.getTableInfo) val standardOI = ObjectInspectorUtils .getStandardObjectInspector( @@ -119,7 +121,7 @@ case class InsertIntoHiveTable( * * Note: this is run once and then kept to avoid double insertions. */ - protected[sql] lazy val sideEffectResult: Seq[Row] = { + protected[sql] lazy val sideEffectResult: Seq[InternalRow] = { // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer // instances within the closure, since Serializer is not serializable while TableDesc is. val tableDesc = table.tableDesc @@ -171,7 +173,7 @@ case class InsertIntoHiveTable( } val jobConf = new JobConf(sc.hiveconf) - val jobConfSer = new SerializableWritable(jobConf) + val jobConfSer = new SerializableJobConf(jobConf) val writerContainer = if (numDynamicPartitions > 0) { val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions) @@ -196,7 +198,6 @@ case class InsertIntoHiveTable( table.hiveQlTable.getPartCols().foreach { entry => orderedPartitionSpec.put(entry.getName, partitionSpec.get(entry.getName).getOrElse("")) } - val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) // inheritTableSpecs is set to true. It should be set to false for a IMPORT query // which is currently considered as a Hive native command. @@ -250,12 +251,13 @@ case class InsertIntoHiveTable( // however for now we return an empty list to simplify compatibility checks with hive, which // does not return anything for insert operations. // TODO: implement hive compatibility as rules. - Seq.empty[Row] + Seq.empty[InternalRow] } - override def executeCollect(): Array[Row] = sideEffectResult.toArray + override def executeCollect(): Array[Row] = + sideEffectResult.toArray - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { sqlContext.sparkContext.parallelize(sideEffectResult, 1) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index fd623370cc407..b967e191c5855 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.execution import java.io.{BufferedReader, DataInputStream, DataOutputStream, EOFException, InputStreamReader} +import java.lang.ProcessBuilder.Redirect import java.util.Properties import scala.collection.JavaConversions._ @@ -34,7 +35,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} import org.apache.spark.sql.types.DataType -import org.apache.spark.util.Utils +import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils} /** * Transforms the input by forking and running the specified script. @@ -54,19 +55,23 @@ case class ScriptTransformation( override def otherCopyArgs: Seq[HiveContext] = sc :: Nil - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { child.execute().mapPartitions { iter => val cmd = List("/bin/bash", "-c", script) val builder = new ProcessBuilder(cmd) + // We need to start threads connected to the process pipeline: + // 1) The error msg generated by the script process would be hidden. + // 2) If the error msg is too big to chock up the buffer, the input logic would be hung val proc = builder.start() val inputStream = proc.getInputStream val outputStream = proc.getOutputStream + val errorStream = proc.getErrorStream val reader = new BufferedReader(new InputStreamReader(inputStream)) val (outputSerde, outputSoi) = ioschema.initOutputSerDe(output) - val iterator: Iterator[Row] = new Iterator[Row] with HiveInspectors { - var cacheRow: Row = null + val iterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { + var cacheRow: InternalRow = null var curLine: String = null var eof: Boolean = false @@ -83,7 +88,7 @@ case class ScriptTransformation( } } - def deserialize(): Row = { + def deserialize(): InternalRow = { if (cacheRow != null) return cacheRow val mutableRow = new SpecificMutableRow(output.map(_.dataType)) @@ -113,7 +118,7 @@ case class ScriptTransformation( } } - override def next(): Row = { + override def next(): InternalRow = { if (!hasNext) { throw new NoSuchElementException } @@ -122,11 +127,11 @@ case class ScriptTransformation( val prevLine = curLine curLine = reader.readLine() if (!ioschema.schemaLess) { - new GenericRow(CatalystTypeConverters.convertToCatalyst( + new GenericInternalRow(CatalystTypeConverters.convertToCatalyst( prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))) .asInstanceOf[Array[Any]]) } else { - new GenericRow(CatalystTypeConverters.convertToCatalyst( + new GenericInternalRow(CatalystTypeConverters.convertToCatalyst( prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)) .asInstanceOf[Array[Any]]) } @@ -145,28 +150,43 @@ case class ScriptTransformation( val dataOutputStream = new DataOutputStream(outputStream) val outputProjection = new InterpretedProjection(input, child.output) + // TODO make the 2048 configurable? + val stderrBuffer = new CircularBuffer(2048) + // Consume the error stream from the pipeline, otherwise it will be blocked if + // the pipeline is full. + new RedirectThread(errorStream, // input stream from the pipeline + stderrBuffer, // output to a circular buffer + "Thread-ScriptTransformation-STDERR-Consumer").start() + // Put the write(output to the pipeline) into a single thread // and keep the collector as remain in the main thread. // otherwise it will causes deadlock if the data size greater than // the pipeline / buffer capacity. new Thread(new Runnable() { override def run(): Unit = { - iter - .map(outputProjection) - .foreach { row => - if (inputSerde == null) { - val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), - ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") - - outputStream.write(data) - } else { - val writable = inputSerde.serialize(row.asInstanceOf[GenericRow].values, inputSoi) - prepareWritable(writable).write(dataOutputStream) + Utils.tryWithSafeFinally { + iter + .map(outputProjection) + .foreach { row => + if (inputSerde == null) { + val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), + ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") + + outputStream.write(data) + } else { + val writable = inputSerde.serialize( + row.asInstanceOf[GenericInternalRow].values, inputSoi) + prepareWritable(writable).write(dataOutputStream) + } + } + outputStream.close() + } { + if (proc.waitFor() != 0) { + logError(stderrBuffer.toString) // log the stderr circular buffer } } - outputStream.close() } - }).start() + }, "Thread-ScriptTransformation-Feed").start() iterator } @@ -270,3 +290,4 @@ case class HiveScriptIOSchema ( } } } + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 0ba94d7b7c649..71fa3e9c33ad9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -91,9 +90,15 @@ case class AddJar(path: String) extends RunnableCommand { val jarURL = new java.io.File(path).toURL val newClassLoader = new java.net.URLClassLoader(Array(jarURL), currentClassLoader) Thread.currentThread.setContextClassLoader(newClassLoader) - org.apache.hadoop.hive.ql.metadata.Hive.get().getConf().setClassLoader(newClassLoader) - - // Add jar to isolated hive classloader + // We need to explicitly set the class loader associated with the conf in executionHive's + // state because this class loader will be used as the context class loader of the current + // thread to execute any Hive command. + // We cannot use `org.apache.hadoop.hive.ql.metadata.Hive.get().getConf()` because Hive.get() + // returns the value of a thread local variable and its HiveConf may not be the HiveConf + // associated with `executionHive.state` (for example, HiveContext is created in one thread + // and then add jar is called from another thread). + hiveContext.executionHive.state.getConf.setClassLoader(newClassLoader) + // Add jar to isolated hive (metadataHive) class loader. hiveContext.runSqlHive(s"ADD JAR $path") // Add jar to executors @@ -230,7 +235,7 @@ case class CreateMetastoreDataSourceAsSelect( val data = DataFrame(hiveContext, query) val df = existingSchema match { // If we are inserting into an existing table, just use the existing schema. - case Some(schema) => sqlContext.createDataFrame(data.queryExecution.toRdd, schema) + case Some(schema) => sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, schema) case None => data } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala similarity index 86% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 01f47352b2313..d7827d56ca8c5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -17,11 +17,9 @@ package org.apache.spark.sql.hive -import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper -import org.apache.spark.sql.AnalysisException - import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConversions._ +import scala.util.Try import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions @@ -30,9 +28,13 @@ import org.apache.hadoop.hive.ql.exec._ import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper import org.apache.spark.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -40,41 +42,44 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.types._ -/* Implicit conversions */ -import scala.collection.JavaConversions._ -private[hive] abstract class HiveFunctionRegistry +private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) extends analysis.FunctionRegistry with HiveInspectors { def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name) - def lookupFunction(name: String, children: Seq[Expression]): Expression = { - // We only look it up to see if it exists, but do not include it in the HiveUDF since it is - // not always serializable. - val functionInfo: FunctionInfo = - Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse( - sys.error(s"Couldn't find function $name")) - - val functionClassName = functionInfo.getFunctionClass.getName - - if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveSimpleUdf(new HiveFunctionWrapper(functionClassName), children) - } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdf(new HiveFunctionWrapper(functionClassName), children) - } else if ( - classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children) - } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUdaf(new HiveFunctionWrapper(functionClassName), children) - } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), children) - } else { - sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + Try(underlying.lookupFunction(name, children)).getOrElse { + // We only look it up to see if it exists, but do not include it in the HiveUDF since it is + // not always serializable. + val functionInfo: FunctionInfo = + Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse( + throw new AnalysisException(s"undefined function $name")) + + val functionClassName = functionInfo.getFunctionClass.getName + + if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveSimpleUDF(new HiveFunctionWrapper(functionClassName), children) + } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children) + } else if ( + classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveGenericUDAF(new HiveFunctionWrapper(functionClassName), children) + } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveUDAF(new HiveFunctionWrapper(functionClassName), children) + } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children) + } else { + sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") + } } } + + override def registerFunction(name: String, builder: FunctionBuilder): Unit = + throw new UnsupportedOperationException } -private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) +private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with Logging { type UDFType = UDF @@ -115,8 +120,10 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre @transient protected lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) + override def isThreadSafe: Boolean = false + // TODO: Finish input output types. - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { unwrap( FunctionRegistry.invoke(method, function, conversionHelper .convertIfNecessary(wrap(children.map(c => c.eval(input)), arguments, cached): _*): _*), @@ -139,7 +146,7 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector) override def get(): AnyRef = wrap(func(), oi) } -private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) +private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with Logging { type UDFType = GenericUDF @@ -173,7 +180,9 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr lazy val dataType: DataType = inspectorToDataType(returnInspector) - override def eval(input: Row): Any = { + override def isThreadSafe: Boolean = false + + override def eval(input: InternalRow): Any = { returnInspector // Make sure initialized. var i = 0 @@ -340,7 +349,7 @@ private[hive] case class HiveWindowFunction( def nullable: Boolean = true - override def eval(input: Row): Any = + override def eval(input: InternalRow): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") @transient @@ -364,7 +373,7 @@ private[hive] case class HiveWindowFunction( evaluator.reset(hiveEvaluatorBuffer) } - override def prepareInputParameters(input: Row): AnyRef = { + override def prepareInputParameters(input: InternalRow): AnyRef = { wrap(inputProjection(input), inputInspectors, new Array[AnyRef](children.length)) } // Add input parameters for a single row. @@ -404,7 +413,7 @@ private[hive] case class HiveWindowFunction( new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children) } -private[hive] case class HiveGenericUdaf( +private[hive] case class HiveGenericUDAF( funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends AggregateExpression with HiveInspectors { @@ -432,11 +441,11 @@ private[hive] case class HiveGenericUdaf( s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this) + def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this) } /** It is used as a wrapper for the hive functions which uses UDAF interface */ -private[hive] case class HiveUdaf( +private[hive] case class HiveUDAF( funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends AggregateExpression with HiveInspectors { @@ -465,7 +474,7 @@ private[hive] case class HiveUdaf( s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this, true) + def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this, true) } /** @@ -479,7 +488,7 @@ private[hive] case class HiveUdaf( * Operators that require maintaining state in between input rows should instead be implemented as * user defined aggregations, which have clean semantics even in a partitioned execution. */ -private[hive] case class HiveGenericUdtf( +private[hive] case class HiveGenericUDTF( funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Generator with HiveInspectors { @@ -507,7 +516,7 @@ private[hive] case class HiveGenericUdtf( field => (inspectorToDataType(field.getFieldObjectInspector), true) } - override def eval(input: Row): TraversableOnce[Row] = { + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { outputInspector // Make sure initialized. val inputProjection = new InterpretedProjection(children) @@ -517,23 +526,23 @@ private[hive] case class HiveGenericUdtf( } protected class UDTFCollector extends Collector { - var collected = new ArrayBuffer[Row] + var collected = new ArrayBuffer[InternalRow] override def collect(input: java.lang.Object) { // We need to clone the input here because implementations of // GenericUDTF reuse the same object. Luckily they are always an array, so // it is easy to clone. - collected += unwrap(input, outputInspector).asInstanceOf[Row] + collected += unwrap(input, outputInspector).asInstanceOf[InternalRow] } - def collectRows(): Seq[Row] = { + def collectRows(): Seq[InternalRow] = { val toCollect = collected - collected = new ArrayBuffer[Row] + collected = new ArrayBuffer[InternalRow] toCollect } } - override def terminate(): TraversableOnce[Row] = { + override def terminate(): TraversableOnce[InternalRow] = { outputInspector // Make sure initialized. function.close() collector.collectRows() @@ -544,7 +553,7 @@ private[hive] case class HiveGenericUdtf( } } -private[hive] case class HiveUdafFunction( +private[hive] case class HiveUDAFFunction( funcWrapper: HiveFunctionWrapper, exprs: Seq[Expression], base: AggregateExpression, @@ -573,7 +582,7 @@ private[hive] case class HiveUdafFunction( private val buffer = function.getNewAggregationBuffer - override def eval(input: Row): Any = unwrap(function.evaluate(buffer), returnInspector) + override def eval(input: InternalRow): Any = unwrap(function.evaluate(buffer), returnInspector) @transient val inputProjection = new InterpretedProjection(exprs) @@ -581,7 +590,7 @@ private[hive] case class HiveUdafFunction( @transient protected lazy val cached = new Array[AnyRef](exprs.length) - def update(input: Row): Unit = { + def update(input: InternalRow): Unit = { val inputs = inputProjection(input) function.iterate(buffer, wrap(inputs, inspectors, cached)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index ee440e304ec19..ecc78a5f8d321 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -34,9 +34,10 @@ import org.apache.hadoop.hive.common.FileUtils import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.Row import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} -import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.types._ +import org.apache.spark.util.SerializableJobConf /** * Internal helper class that saves an RDD using a Hive OutputFormat. @@ -57,7 +58,7 @@ private[hive] class SparkHiveWriterContainer( PlanUtils.configureOutputJobPropertiesForStorageHandler(tableDesc) Utilities.copyTableJobPropertiesToConf(tableDesc, jobConf) } - protected val conf = new SerializableWritable(jobConf) + protected val conf = new SerializableJobConf(jobConf) private var jobID = 0 private var splitID = 0 @@ -200,7 +201,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( def convertToHiveRawString(col: String, value: Any): String = { val raw = String.valueOf(value) schema(col).dataType match { - case DateType => DateUtils.toString(raw.toInt) + case DateType => DateTimeUtils.dateToString(raw.toInt) case _: DecimalType => BigDecimal(raw).toString() case _ => raw } @@ -227,12 +228,11 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( newFileSinkDesc.setCompressCodec(fileSinkConf.getCompressCodec) newFileSinkDesc.setCompressType(fileSinkConf.getCompressType) - val path = { - val outputPath = FileOutputFormat.getOutputPath(conf.value) - assert(outputPath != null, "Undefined job output-path") - val workPath = new Path(outputPath, dynamicPartPath.stripPrefix("/")) - new Path(workPath, getOutputName) - } + // use the path like ${hive_tmp}/_temporary/${attemptId}/ + // to avoid write to the same file when `spark.speculation=true` + val path = FileOutputFormat.getTaskOutputPath( + conf.value, + dynamicPartPath.stripPrefix("/") + "/" + getOutputName) HiveFileFormatUtils.getHiveRecordWriter( conf.value, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index 1e51173a19882..e3ab9442b4821 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -27,13 +27,13 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.hive.HiveMetastoreTypes import org.apache.spark.sql.types.StructType -private[orc] object OrcFileOperator extends Logging{ +private[orc] object OrcFileOperator extends Logging { def getFileReader(pathStr: String, config: Option[Configuration] = None ): Reader = { val conf = config.getOrElse(new Configuration) val fspath = new Path(pathStr) val fs = fspath.getFileSystem(conf) val orcFiles = listOrcFiles(pathStr, conf) - + logDebug(s"Creating ORC Reader from ${orcFiles.head}") // TODO Need to consider all files when schema evolution is taken into account. OrcFile.createReader(fs, orcFiles.head) } @@ -42,6 +42,7 @@ private[orc] object OrcFileOperator extends Logging{ val reader = getFileReader(path, conf) val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] val schema = readerInspector.getTypeName + logDebug(s"Reading schema from file $path, got Hive schema string: $schema") HiveMetastoreTypes.toDataType(schema).asInstanceOf[StructType] } @@ -52,14 +53,14 @@ private[orc] object OrcFileOperator extends Logging{ def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { val origPath = new Path(pathStr) val fs = origPath.getFileSystem(conf) - val path = origPath.makeQualified(fs) + val path = origPath.makeQualified(fs.getUri, fs.getWorkingDirectory) val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath) .filterNot(_.isDir) .map(_.getPath) .filterNot(_.getName.startsWith("_")) .filterNot(_.getName.startsWith(".")) - if (paths == null || paths.size == 0) { + if (paths == null || paths.isEmpty) { throw new IllegalArgumentException( s"orcFileOperator: path $path does not have valid orc files matching the pattern") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index f03c4cd54e7e6..300f83d914ea4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -27,10 +27,11 @@ import org.apache.hadoop.hive.ql.io.orc.{OrcInputFormat, OrcOutputFormat, OrcSer import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils import org.apache.hadoop.io.{NullWritable, Writable} -import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, RecordWriter, Reporter} +import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} +import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.{HadoopRDD, RDD} @@ -39,7 +40,7 @@ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreType import org.apache.spark.sql.sources.{Filter, _} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.{Logging, SerializableWritable} +import org.apache.spark.util.SerializableConfiguration /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -104,13 +105,14 @@ private[orc] class OrcOutputWriter( recordWriterInstantiated = true val conf = context.getConfiguration + val uniqueWriteJobId = conf.get("spark.sql.sources.writeJobUUID") val partition = context.getTaskAttemptID.getTaskID.getId - val filename = f"part-r-$partition%05d-${System.currentTimeMillis}%015d.orc" + val filename = f"part-r-$partition%05d-$uniqueWriteJobId.orc" new OrcOutputFormat().getRecordWriter( new Path(path, filename).getFileSystem(conf), conf.asInstanceOf[JobConf], - new Path(path, filename).toUri.getPath, + new Path(path, filename).toString, Reporter.NULL ).asInstanceOf[RecordWriter[NullWritable, Writable]] } @@ -188,10 +190,20 @@ private[sql] class OrcRelation( filters: Array[Filter], inputPaths: Array[FileStatus]): RDD[Row] = { val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes - OrcTableScan(output, this, filters, inputPaths).execute() + OrcTableScan(output, this, filters, inputPaths).execute().map(_.asInstanceOf[Row]) } override def prepareJobForWrite(job: Job): OutputWriterFactory = { + job.getConfiguration match { + case conf: JobConf => + conf.setOutputFormat(classOf[OrcOutputFormat]) + case conf => + conf.setClass( + "mapred.output.format.class", + classOf[OrcOutputFormat], + classOf[MapRedOutputFormat[_, _]]) + } + new OutputWriterFactory { override def newInstance( path: String, @@ -222,13 +234,13 @@ private[orc] case class OrcTableScan( HiveShim.appendReadColumns(conf, sortedIds, sortedNames) } - // Transform all given raw `Writable`s into `Row`s. + // Transform all given raw `Writable`s into `InternalRow`s. private def fillObject( path: String, conf: Configuration, iterator: Iterator[Writable], nonPartitionKeyAttrs: Seq[(Attribute, Int)], - mutableRow: MutableRow): Iterator[Row] = { + mutableRow: MutableRow): Iterator[InternalRow] = { val deserializer = new OrcSerde val soi = OrcFileOperator.getObjectInspector(path, Some(conf)) val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { @@ -249,11 +261,11 @@ private[orc] case class OrcTableScan( } i += 1 } - mutableRow: Row + mutableRow: InternalRow } } - def execute(): RDD[Row] = { + def execute(): RDD[InternalRow] = { val job = new Job(sqlContext.sparkContext.hadoopConfiguration) val conf = job.getConfiguration @@ -283,7 +295,7 @@ private[orc] case class OrcTableScan( classOf[Writable] ).asInstanceOf[HadoopRDD[NullWritable, Writable]] - val wrappedConf = new SerializableWritable(conf) + val wrappedConf = new SerializableConfiguration(conf) rdd.mapPartitionsWithInputSplit { case (split: OrcSplit, iterator) => val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 7c7afc824d7a6..7978fdacaedba 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -49,7 +49,7 @@ import scala.collection.JavaConversions._ object TestHive extends TestHiveContext( new SparkContext( - "local[2]", + System.getProperty("spark.sql.test.master", "local[32]"), "TestSQLContext", new SparkConf() .set("spark.sql.test", "") @@ -112,12 +112,11 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { protected[hive] class SQLSession extends super.SQLSession { /** Fewer partitions to speed up testing. */ protected[sql] override lazy val conf: SQLConf = new SQLConf { - override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt + override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, 5) // TODO as in unit test, conf.clear() probably be called, all of the value will be cleared. // The super.getConf(SQLConf.DIALECT) is "sql" by default, we need to set it as "hiveql" override def dialect: String = super.getConf(SQLConf.DIALECT, "hiveql") - override def caseSensitiveAnalysis: Boolean = - getConf(SQLConf.CASE_SENSITIVE, "false").toBoolean + override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) } } @@ -392,7 +391,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { * Records the UDFs present when the server starts, so we can delete ones that are created by * tests. */ - protected val originalUdfs: JavaSet[String] = FunctionRegistry.getFunctionNames + protected val originalUDFs: JavaSet[String] = FunctionRegistry.getFunctionNames /** * Resets the test instance by deleting any tables that have been created. @@ -411,7 +410,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { catalog.client.reset() catalog.unregisterAllTables() - FunctionRegistry.getFunctionNames.filterNot(originalUdfs.contains(_)).foreach { udfName => + FunctionRegistry.getFunctionNames.filterNot(originalUDFs.contains(_)).foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } diff --git a/sql/hive/src/test/resources/data/files/testUdf/part-00000 b/sql/hive/src/test/resources/data/files/testUDF/part-00000 similarity index 100% rename from sql/hive/src/test/resources/data/files/testUdf/part-00000 rename to sql/hive/src/test/resources/data/files/testUDF/part-00000 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 b/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 index 27de46fdf22ac..84a31a5a6970b 100644 --- a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 +++ b/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 @@ -1 +1 @@ --0.0010000000000000009 +-0.001 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-6dfcd7925fb267699c4bf82737d4609 b/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-6dfcd7925fb267699c4bf82737d4609 new file mode 100644 index 0000000000000..7e5fceeddeeeb --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-6dfcd7925fb267699c4bf82737d4609 @@ -0,0 +1,97 @@ +Manufacturer#1 almond antique burnished rose metallic 2 258.10677784349235 258.10677784349235 2 66619.10876874991 0.811328754177887 2801.7074999999995 +Manufacturer#1 almond antique burnished rose metallic 2 258.10677784349235 258.10677784349235 6 66619.10876874991 0.811328754177887 2801.7074999999995 +Manufacturer#1 almond antique burnished rose metallic 2 258.10677784349235 258.10677784349235 34 66619.10876874991 0.811328754177887 2801.7074999999995 +Manufacturer#1 almond antique burnished rose metallic 2 273.70217881648074 273.70217881648074 2 74912.8826888888 1.0 4128.782222222221 +Manufacturer#1 almond antique burnished rose metallic 2 273.70217881648074 273.70217881648074 34 74912.8826888888 1.0 4128.782222222221 +Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 2 53315.51002399992 0.695639377397664 2210.7864 +Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 6 53315.51002399992 0.695639377397664 2210.7864 +Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 28 53315.51002399992 0.695639377397664 2210.7864 +Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 34 53315.51002399992 0.695639377397664 2210.7864 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 2 41099.896184 0.630785977101214 2009.9536000000007 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 6 41099.896184 0.630785977101214 2009.9536000000007 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 28 41099.896184 0.630785977101214 2009.9536000000007 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 34 41099.896184 0.630785977101214 2009.9536000000007 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 42 41099.896184 0.630785977101214 2009.9536000000007 +Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 6 14788.129118750014 0.2036684720435979 331.1337500000004 +Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 28 14788.129118750014 0.2036684720435979 331.1337500000004 +Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 34 14788.129118750014 0.2036684720435979 331.1337500000004 +Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 42 14788.129118750014 0.2036684720435979 331.1337500000004 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 96.5751586416853 96.5751586416853 6 9326.761266666683 -1.4442181184933883E-4 -0.20666666666708502 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 96.5751586416853 96.5751586416853 28 9326.761266666683 -1.4442181184933883E-4 -0.20666666666708502 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 96.5751586416853 96.5751586416853 42 9326.761266666683 -1.4442181184933883E-4 -0.20666666666708502 +Manufacturer#2 almond antique violet chocolate turquoise 14 142.2363169751898 142.2363169751898 2 20231.169866666663 -0.49369526554523185 -1113.7466666666658 +Manufacturer#2 almond antique violet chocolate turquoise 14 142.2363169751898 142.2363169751898 14 20231.169866666663 -0.49369526554523185 -1113.7466666666658 +Manufacturer#2 almond antique violet chocolate turquoise 14 142.2363169751898 142.2363169751898 40 20231.169866666663 -0.49369526554523185 -1113.7466666666658 +Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 2 18978.662075 -0.5205630897335946 -1004.4812499999995 +Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 14 18978.662075 -0.5205630897335946 -1004.4812499999995 +Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 25 18978.662075 -0.5205630897335946 -1004.4812499999995 +Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 40 18978.662075 -0.5205630897335946 -1004.4812499999995 +Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 2 16910.329504000005 -0.46908967495720255 -766.1791999999995 +Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 14 16910.329504000005 -0.46908967495720255 -766.1791999999995 +Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 18 16910.329504000005 -0.46908967495720255 -766.1791999999995 +Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 25 16910.329504000005 -0.46908967495720255 -766.1791999999995 +Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 40 16910.329504000005 -0.46908967495720255 -766.1791999999995 +Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 2 18374.07627499999 -0.6091405874714462 -1128.1787499999987 +Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 18 18374.07627499999 -0.6091405874714462 -1128.1787499999987 +Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 25 18374.07627499999 -0.6091405874714462 -1128.1787499999987 +Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 40 18374.07627499999 -0.6091405874714462 -1128.1787499999987 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 156.44019460768044 156.44019460768044 2 24473.534488888927 -0.9571686373491608 -1441.4466666666676 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 156.44019460768044 156.44019460768044 18 24473.534488888927 -0.9571686373491608 -1441.4466666666676 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 156.44019460768044 156.44019460768044 25 24473.534488888927 -0.9571686373491608 -1441.4466666666676 +Manufacturer#3 almond antique chartreuse khaki white 17 196.7742266885805 196.7742266885805 14 38720.09628888887 0.5557168646224995 224.6944444444446 +Manufacturer#3 almond antique chartreuse khaki white 17 196.7742266885805 196.7742266885805 17 38720.09628888887 0.5557168646224995 224.6944444444446 +Manufacturer#3 almond antique chartreuse khaki white 17 196.7742266885805 196.7742266885805 19 38720.09628888887 0.5557168646224995 224.6944444444446 +Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 1 75702.81305 -0.6720833036576083 -1296.9000000000003 +Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 14 75702.81305 -0.6720833036576083 -1296.9000000000003 +Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 17 75702.81305 -0.6720833036576083 -1296.9000000000003 +Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 19 75702.81305 -0.6720833036576083 -1296.9000000000003 +Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 1 67722.117896 -0.5703526513979519 -2129.0664 +Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 14 67722.117896 -0.5703526513979519 -2129.0664 +Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 17 67722.117896 -0.5703526513979519 -2129.0664 +Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 19 67722.117896 -0.5703526513979519 -2129.0664 +Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 45 67722.117896 -0.5703526513979519 -2129.0664 +Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 1 76128.53331875012 -0.577476899644802 -2547.7868749999993 +Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 14 76128.53331875012 -0.577476899644802 -2547.7868749999993 +Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 19 76128.53331875012 -0.577476899644802 -2547.7868749999993 +Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 45 76128.53331875012 -0.577476899644802 -2547.7868749999993 +Manufacturer#3 almond antique olive coral navajo 45 260.5815918713796 260.5815918713796 1 67902.76602222225 -0.8710736366736884 -4099.731111111111 +Manufacturer#3 almond antique olive coral navajo 45 260.5815918713796 260.5815918713796 19 67902.76602222225 -0.8710736366736884 -4099.731111111111 +Manufacturer#3 almond antique olive coral navajo 45 260.5815918713796 260.5815918713796 45 67902.76602222225 -0.8710736366736884 -4099.731111111111 +Manufacturer#4 almond antique gainsboro frosted violet 10 170.13011889596618 170.13011889596618 10 28944.25735555559 -0.6656975320098423 -1347.4777777777779 +Manufacturer#4 almond antique gainsboro frosted violet 10 170.13011889596618 170.13011889596618 27 28944.25735555559 -0.6656975320098423 -1347.4777777777779 +Manufacturer#4 almond antique gainsboro frosted violet 10 170.13011889596618 170.13011889596618 39 28944.25735555559 -0.6656975320098423 -1347.4777777777779 +Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 7 58693.95151875002 -0.8051852719193339 -2537.328125 +Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 10 58693.95151875002 -0.8051852719193339 -2537.328125 +Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 27 58693.95151875002 -0.8051852719193339 -2537.328125 +Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 39 58693.95151875002 -0.8051852719193339 -2537.328125 +Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 7 54802.817784000035 -0.6046935574240581 -1719.8079999999995 +Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 10 54802.817784000035 -0.6046935574240581 -1719.8079999999995 +Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 12 54802.817784000035 -0.6046935574240581 -1719.8079999999995 +Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 27 54802.817784000035 -0.6046935574240581 -1719.8079999999995 +Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 39 54802.817784000035 -0.6046935574240581 -1719.8079999999995 +Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 7 61174.24181875003 -0.5508665654707869 -1719.0368749999975 +Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 12 61174.24181875003 -0.5508665654707869 -1719.0368749999975 +Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 27 61174.24181875003 -0.5508665654707869 -1719.0368749999975 +Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 39 61174.24181875003 -0.5508665654707869 -1719.0368749999975 +Manufacturer#4 almond azure aquamarine papaya violet 12 283.3344330566893 283.3344330566893 7 80278.40095555557 -0.7755740084632333 -1867.4888888888881 +Manufacturer#4 almond azure aquamarine papaya violet 12 283.3344330566893 283.3344330566893 12 80278.40095555557 -0.7755740084632333 -1867.4888888888881 +Manufacturer#4 almond azure aquamarine papaya violet 12 283.3344330566893 283.3344330566893 27 80278.40095555557 -0.7755740084632333 -1867.4888888888881 +Manufacturer#5 almond antique blue firebrick mint 31 83.69879024746363 83.69879024746363 2 7005.487488888913 0.39004303087285047 418.9233333333353 +Manufacturer#5 almond antique blue firebrick mint 31 83.69879024746363 83.69879024746363 6 7005.487488888913 0.39004303087285047 418.9233333333353 +Manufacturer#5 almond antique blue firebrick mint 31 83.69879024746363 83.69879024746363 31 7005.487488888913 0.39004303087285047 418.9233333333353 +Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 2 100286.53662500004 -0.713612911776183 -4090.853749999999 +Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 6 100286.53662500004 -0.713612911776183 -4090.853749999999 +Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 31 100286.53662500004 -0.713612911776183 -4090.853749999999 +Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 46 100286.53662500004 -0.713612911776183 -4090.853749999999 +Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 2 81456.04997600002 -0.712858514567818 -3297.2011999999986 +Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 6 81456.04997600002 -0.712858514567818 -3297.2011999999986 +Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 23 81456.04997600002 -0.712858514567818 -3297.2011999999986 +Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 31 81456.04997600002 -0.712858514567818 -3297.2011999999986 +Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 46 81456.04997600002 -0.712858514567818 -3297.2011999999986 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 285.43749038756283 285.43749038756283 2 81474.56091875004 -0.984128787153391 -4871.028125000002 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 285.43749038756283 285.43749038756283 6 81474.56091875004 -0.984128787153391 -4871.028125000002 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 285.43749038756283 285.43749038756283 23 81474.56091875004 -0.984128787153391 -4871.028125000002 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 285.43749038756283 285.43749038756283 46 81474.56091875004 -0.984128787153391 -4871.028125000002 +Manufacturer#5 almond azure blanched chiffon midnight 23 315.9225931564038 315.9225931564038 2 99807.08486666664 -0.9978877469246936 -5664.856666666666 +Manufacturer#5 almond azure blanched chiffon midnight 23 315.9225931564038 315.9225931564038 23 99807.08486666664 -0.9978877469246936 -5664.856666666666 +Manufacturer#5 almond azure blanched chiffon midnight 23 315.9225931564038 315.9225931564038 46 99807.08486666664 -0.9978877469246936 -5664.856666666666 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-da0e0cca69e42118a96b8609b8fa5838 b/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-da0e0cca69e42118a96b8609b8fa5838 deleted file mode 100644 index 1f7e8a5d67036..0000000000000 --- a/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-da0e0cca69e42118a96b8609b8fa5838 +++ /dev/null @@ -1,26 +0,0 @@ -Manufacturer#1 almond antique burnished rose metallic 2 273.70217881648074 273.70217881648074 [34,2] 74912.8826888888 1.0 4128.782222222221 -Manufacturer#1 almond antique burnished rose metallic 2 258.10677784349235 258.10677784349235 [34,2,6] 66619.10876874991 0.811328754177887 2801.7074999999995 -Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 [34,2,6,28] 53315.51002399992 0.695639377397664 2210.7864 -Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 [34,2,6,42,28] 41099.896184 0.630785977101214 2009.9536000000007 -Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 [34,6,42,28] 14788.129118750014 0.2036684720435979 331.1337500000004 -Manufacturer#1 almond aquamarine pink moccasin thistle 42 96.5751586416853 96.5751586416853 [6,42,28] 9326.761266666683 -1.4442181184933883E-4 -0.20666666666708502 -Manufacturer#2 almond antique violet chocolate turquoise 14 142.2363169751898 142.2363169751898 [2,40,14] 20231.169866666663 -0.49369526554523185 -1113.7466666666658 -Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 [2,25,40,14] 18978.662075 -0.5205630897335946 -1004.4812499999995 -Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 [2,18,25,40,14] 16910.329504000005 -0.46908967495720255 -766.1791999999995 -Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 [2,18,25,40] 18374.07627499999 -0.6091405874714462 -1128.1787499999987 -Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 156.44019460768044 156.44019460768044 [2,18,25] 24473.534488888927 -0.9571686373491608 -1441.4466666666676 -Manufacturer#3 almond antique chartreuse khaki white 17 196.7742266885805 196.7742266885805 [17,19,14] 38720.09628888887 0.5557168646224995 224.6944444444446 -Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 [17,1,19,14] 75702.81305 -0.6720833036576083 -1296.9000000000003 -Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 [17,1,19,14,45] 67722.117896 -0.5703526513979519 -2129.0664 -Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 [1,19,14,45] 76128.53331875012 -0.577476899644802 -2547.7868749999993 -Manufacturer#3 almond antique olive coral navajo 45 260.5815918713796 260.5815918713796 [1,19,45] 67902.76602222225 -0.8710736366736884 -4099.731111111111 -Manufacturer#4 almond antique gainsboro frosted violet 10 170.13011889596618 170.13011889596618 [39,27,10] 28944.25735555559 -0.6656975320098423 -1347.4777777777779 -Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 [39,7,27,10] 58693.95151875002 -0.8051852719193339 -2537.328125 -Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 [39,7,27,10,12] 54802.817784000035 -0.6046935574240581 -1719.8079999999995 -Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 [39,7,27,12] 61174.24181875003 -0.5508665654707869 -1719.0368749999975 -Manufacturer#4 almond azure aquamarine papaya violet 12 283.3344330566893 283.3344330566893 [7,27,12] 80278.40095555557 -0.7755740084632333 -1867.4888888888881 -Manufacturer#5 almond antique blue firebrick mint 31 83.69879024746363 83.69879024746363 [2,6,31] 7005.487488888913 0.39004303087285047 418.9233333333353 -Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 [2,6,46,31] 100286.53662500004 -0.713612911776183 -4090.853749999999 -Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 [2,23,6,46,31] 81456.04997600002 -0.712858514567818 -3297.2011999999986 -Manufacturer#5 almond aquamarine dodger light gainsboro 46 285.43749038756283 285.43749038756283 [2,23,6,46] 81474.56091875004 -0.984128787153391 -4871.028125000002 -Manufacturer#5 almond azure blanched chiffon midnight 23 315.9225931564038 315.9225931564038 [2,23,46] 99807.08486666664 -0.9978877469246936 -5664.856666666666 diff --git a/sql/hive/src/test/resources/hive-contrib-0.13.1.jar b/sql/hive/src/test/resources/hive-contrib-0.13.1.jar new file mode 100644 index 0000000000000..ce0740d9245a7 Binary files /dev/null and b/sql/hive/src/test/resources/hive-contrib-0.13.1.jar differ diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala b/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala new file mode 100644 index 0000000000000..0e428ba1d7456 --- /dev/null +++ b/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.hive.HiveContext + +/** + * Entry point in test application for SPARK-8489. + * + * This file is not meant to be compiled during tests. It is already included + * in a pre-built "test.jar" located in the same directory as this file. + * This is included here for reference only and should NOT be modified without + * rebuilding the test jar itself. + * + * This is used in org.apache.spark.sql.hive.HiveSparkSubmitSuite. + */ +object Main { + def main(args: Array[String]) { + println("Running regression test for SPARK-8489.") + val sc = new SparkContext("local", "testing") + val hc = new HiveContext(sc) + // This line should not throw scala.reflect.internal.MissingRequirementError. + // See SPARK-8470 for more detail. + val df = hc.createDataFrame(Seq(MyCoolClass("1", "2", "3"))) + df.collect() + println("Regression test for SPARK-8489 success!") + sc.stop() + } +} + diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/MyCoolClass.scala b/sql/hive/src/test/resources/regression-test-SPARK-8489/MyCoolClass.scala new file mode 100644 index 0000000000000..b1681745c2ef7 --- /dev/null +++ b/sql/hive/src/test/resources/regression-test-SPARK-8489/MyCoolClass.scala @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Dummy class used in regression test SPARK-8489. */ +case class MyCoolClass(past: String, present: String, future: String) + diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar b/sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar new file mode 100644 index 0000000000000..5944aa6076a5f Binary files /dev/null and b/sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar differ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index df137e7b2b333..a93acb938d5fa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -28,8 +28,9 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn import org.apache.hadoop.io.LongWritable import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{Literal, Row} +import org.apache.spark.sql.catalyst.expressions.{Literal, InternalRow} import org.apache.spark.sql.types._ +import org.apache.spark.sql.Row class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { test("Test wrap SettableStructObjectInspector") { @@ -45,7 +46,7 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { classOf[UDAFPercentile.State], ObjectInspectorOptions.JAVA).asInstanceOf[StructObjectInspector] - val a = unwrap(state, soi).asInstanceOf[Row] + val a = unwrap(state, soi).asInstanceOf[InternalRow] val b = wrap(a, soi).asInstanceOf[UDAFPercentile.State] val sfCounts = soi.getStructFieldRef("counts") @@ -127,7 +128,7 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { } } - def checkValues(row1: Seq[Any], row2: Row): Unit = { + def checkValues(row1: Seq[Any], row2: InternalRow): Unit = { row1.zip(row2.toSeq).foreach { case (r1, r2) => checkValue(r1, r2) } @@ -201,9 +202,9 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { val dt = StructType(dataTypes.zipWithIndex.map { case (t, idx) => StructField(s"c_$idx", t) }) - + val inspector = toInspector(dt) checkValues(row, - unwrap(wrap(Row.fromSeq(row), toInspector(dt)), toInspector(dt)).asInstanceOf[Row]) + unwrap(wrap(InternalRow.fromSeq(row), inspector), inspector).asInstanceOf[InternalRow]) checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index 5a5ea10e3c82e..af68615e8e9d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.parquet.ParquetTest -import org.apache.spark.sql.{QueryTest, SQLConf} +import org.apache.spark.sql.{QueryTest, Row, SQLConf} case class Cases(lower: String, UPPER: String) @@ -82,11 +81,11 @@ class HiveParquetSuite extends QueryTest with ParquetTest { } } - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { + withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API.key -> "true") { run("Parquet data source enabled") } - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "false") { + withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API.key -> "false") { run("Parquet data source disabled") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala new file mode 100644 index 0000000000000..a38ed23b5cf9a --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File + +import scala.sys.process.{ProcessLogger, Process} + +import org.apache.spark._ +import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} +import org.apache.spark.util.{ResetSystemProperties, Utils} +import org.scalatest.Matchers +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ + +/** + * This suite tests spark-submit with applications using HiveContext. + */ +class HiveSparkSubmitSuite + extends SparkFunSuite + with Matchers + with ResetSystemProperties + with Timeouts { + + // TODO: rewrite these or mark them as slow tests to be run sparingly + + def beforeAll() { + System.setProperty("spark.testing", "true") + } + + test("SPARK-8368: includes jars passed in through --jars") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) + val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) + val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath() + val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath() + val jarsString = Seq(jar1, jar2, jar3, jar4).map(j => j.toString).mkString(",") + val args = Seq( + "--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"), + "--name", "SparkSubmitClassLoaderTest", + "--master", "local-cluster[2,1,512]", + "--jars", jarsString, + unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") + runSparkSubmit(args) + } + + test("SPARK-8020: set sql conf in spark conf") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SparkSQLConfTest.getClass.getName.stripSuffix("$"), + "--name", "SparkSQLConfTest", + "--master", "local-cluster[2,1,512]", + unusedJar.toString) + runSparkSubmit(args) + } + + test("SPARK-8489: MissingRequirementError during reflection") { + // This test uses a pre-built jar to test SPARK-8489. In a nutshell, this test creates + // a HiveContext and uses it to create a data frame from an RDD using reflection. + // Before the fix in SPARK-8470, this results in a MissingRequirementError because + // the HiveContext code mistakenly overrides the class loader that contains user classes. + // For more detail, see sql/hive/src/test/resources/regression-test-SPARK-8489/*scala. + val testJar = "sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar" + val args = Seq("--class", "Main", testJar) + runSparkSubmit(args) + } + + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. + // This is copied from org.apache.spark.deploy.SparkSubmitSuite + private def runSparkSubmit(args: Seq[String]): Unit = { + val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) + val process = Process( + Seq("./bin/spark-submit") ++ args, + new File(sparkHome), + "SPARK_TESTING" -> "1", + "SPARK_HOME" -> sparkHome + ).run(ProcessLogger( + (line: String) => { println(s"out> $line") }, + (line: String) => { println(s"err> $line") } + )) + + try { + val exitCode = failAfter(180 seconds) { process.exitValue() } + if (exitCode != 0) { + fail(s"Process returned with exit code $exitCode. See the log4j logs for more detail.") + } + } finally { + // Ensure we still kill the process in case it timed out + process.destroy() + } + } +} + +// This object is used for testing SPARK-8368: https://issues.apache.org/jira/browse/SPARK-8368. +// We test if we can load user jars in both driver and executors when HiveContext is used. +object SparkSubmitClassLoaderTest extends Logging { + def main(args: Array[String]) { + Utils.configTestLog4j("INFO") + val conf = new SparkConf() + val sc = new SparkContext(conf) + val hiveContext = new TestHiveContext(sc) + val df = hiveContext.createDataFrame((1 to 100).map(i => (i, i))).toDF("i", "j") + logInfo("Testing load classes at the driver side.") + // First, we load classes at driver side. + try { + Class.forName(args(0), true, Thread.currentThread().getContextClassLoader) + Class.forName(args(1), true, Thread.currentThread().getContextClassLoader) + } catch { + case t: Throwable => + throw new Exception("Could not load user class from jar:\n", t) + } + // Second, we load classes at the executor side. + logInfo("Testing load classes at the executor side.") + val result = df.mapPartitions { x => + var exception: String = null + try { + Class.forName(args(0), true, Thread.currentThread().getContextClassLoader) + Class.forName(args(1), true, Thread.currentThread().getContextClassLoader) + } catch { + case t: Throwable => + exception = t + "\n" + t.getStackTraceString + exception = exception.replaceAll("\n", "\n\t") + } + Option(exception).toSeq.iterator + }.collect() + if (result.nonEmpty) { + throw new Exception("Could not load user class from jar:\n" + result(0)) + } + + // Load a Hive UDF from the jar. + logInfo("Registering temporary Hive UDF provided in a jar.") + hiveContext.sql( + """ + |CREATE TEMPORARY FUNCTION example_max + |AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax' + """.stripMargin) + val source = + hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val") + source.registerTempTable("sourceTable") + // Load a Hive SerDe from the jar. + logInfo("Creating a Hive table with a SerDe provided in a jar.") + hiveContext.sql( + """ + |CREATE TABLE t1(key int, val string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe' + """.stripMargin) + // Actually use the loaded UDF and SerDe. + logInfo("Writing data into the table.") + hiveContext.sql( + "INSERT INTO TABLE t1 SELECT example_max(key) as key, val FROM sourceTable GROUP BY val") + logInfo("Running a simple query on the table.") + val count = hiveContext.table("t1").orderBy("key", "val").count() + if (count != 10) { + throw new Exception(s"table t1 should have 10 rows instead of $count rows") + } + logInfo("Test finishes.") + sc.stop() + } +} + +// This object is used for testing SPARK-8020: https://issues.apache.org/jira/browse/SPARK-8020. +// We test if we can correctly set spark sql configurations when HiveContext is used. +object SparkSQLConfTest extends Logging { + def main(args: Array[String]) { + Utils.configTestLog4j("INFO") + // We override the SparkConf to add spark.sql.hive.metastore.version and + // spark.sql.hive.metastore.jars to the beginning of the conf entry array. + // So, if metadataHive get initialized after we set spark.sql.hive.metastore.version but + // before spark.sql.hive.metastore.jars get set, we will see the following exception: + // Exception in thread "main" java.lang.IllegalArgumentException: Builtin jars can only + // be used when hive execution version == hive metastore version. + // Execution: 0.13.1 != Metastore: 0.12. Specify a vaild path to the correct hive jars + // using $HIVE_METASTORE_JARS or change spark.sql.hive.metastore.version to 0.13.1. + val conf = new SparkConf() { + override def getAll: Array[(String, String)] = { + def isMetastoreSetting(conf: String): Boolean = { + conf == "spark.sql.hive.metastore.version" || conf == "spark.sql.hive.metastore.jars" + } + // If there is any metastore settings, remove them. + val filteredSettings = super.getAll.filterNot(e => isMetastoreSetting(e._1)) + + // Always add these two metastore settings at the beginning. + ("spark.sql.hive.metastore.version" -> "0.12") +: + ("spark.sql.hive.metastore.jars" -> "maven") +: + filteredSettings + } + + // For this simple test, we do not really clone this object. + override def clone: SparkConf = this + } + val sc = new SparkContext(conf) + val hiveContext = new TestHiveContext(sc) + // Run a simple command to make sure all lazy vals in hiveContext get instantiated. + hiveContext.tables().collect() + sc.stop() + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index af586712e3235..cc294bc3e8bc3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -456,7 +456,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA withTable("savedJsonTable") { val df = (1 to 10).map(i => i -> s"str$i").toDF("a", "b") - withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "json") { + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "json") { // Save the df as a managed table (by not specifying the path). df.write.saveAsTable("savedJsonTable") @@ -484,7 +484,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA } // Create an external table by specifying the path. - withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "not a source name") { + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "not a source name") { df.write .format("org.apache.spark.sql.json") .mode(SaveMode.Append) @@ -508,7 +508,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA s"""{ "a": $i, "b": "str$i" }""" })) - withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "not a source name") { + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "not a source name") { df.write .format("json") .mode(SaveMode.Append) @@ -516,7 +516,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA .saveAsTable("savedJsonTable") } - withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "json") { + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "json") { createExternalTable("createdJsonTable", tempPath.toString) assert(table("createdJsonTable").schema === df.schema) checkAnswer(sql("SELECT * FROM createdJsonTable"), df) @@ -533,7 +533,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA checkAnswer(read.json(tempPath.toString), df) // Try to specify the schema. - withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "not a source name") { + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "not a source name") { val schema = StructType(StructField("b", StringType, true) :: Nil) createExternalTable( "createdJsonTable", @@ -563,8 +563,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA test("scan a parquet table created through a CTAS statement") { withSQLConf( - "spark.sql.hive.convertMetastoreParquet" -> "true", - SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { + HiveContext.CONVERT_METASTORE_PARQUET.key -> "true", + SQLConf.PARQUET_USE_DATA_SOURCE_API.key -> "true") { withTempTable("jt") { (1 to 10).map(i => i -> s"str$i").toDF("a", "b").registerTempTable("jt") @@ -706,7 +706,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA } test("SPARK-6024 wide schema support") { - withSQLConf(SQLConf.SCHEMA_STRING_LENGTH_THRESHOLD -> "4000") { + withSQLConf(SQLConf.SCHEMA_STRING_LENGTH_THRESHOLD.key -> "4000") { withTable("wide_schema") { // We will need 80 splits for this schema if the threshold is 4000. val schema = StructType((1 to 5000).map(i => StructField(s"c_$i", StringType, true))) @@ -833,4 +833,21 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA (70 to 79).map(i => Row(i, s"str$i"))) } } + + test("SPARK-8156:create table to specific database by 'use dbname' ") { + + val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c") + sqlContext.sql("""create database if not exists testdb8156""") + sqlContext.sql("""use testdb8156""") + df.write + .format("parquet") + .mode(SaveMode.Overwrite) + .saveAsTable("ttt3") + + checkAnswer( + sqlContext.sql("show TABLES in testdb8156").filter("tableName = 'ttt3'"), + Row("ttt3", false)) + sqlContext.sql("""use default""") + sqlContext.sql("""drop database if exists testdb8156 CASCADE""") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 78c94e6490e36..f067ea0d4fc75 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -167,7 +167,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { ctx.conf.settings.synchronized { val tmp = ctx.conf.autoBroadcastJoinThreshold - sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""") + sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1""") df = sql(query) bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") @@ -176,7 +176,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { assert(shj.size === 1, "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off") - sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""") + sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp""") } after() @@ -225,7 +225,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { ctx.conf.settings.synchronized { val tmp = ctx.conf.autoBroadcastJoinThreshold - sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") df = sql(leftSemiJoinQuery) bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastLeftSemiJoinHash => j @@ -238,7 +238,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { assert(shj.size === 1, "LeftSemiJoinHash should be planned when BroadcastHashJoin is turned off") - sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp") + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 7eb4842726665..d52e162acbd04 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.client +import java.io.File + import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.util.Utils @@ -28,6 +30,12 @@ import org.apache.spark.util.Utils * is not fully tested. */ class VersionsSuite extends SparkFunSuite with Logging { + + // Do not use a temp path here to speed up subsequent executions of the unit test during + // development. + private val ivyPath = Some( + new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath()) + private def buildConf() = { lazy val warehousePath = Utils.createTempDir() lazy val metastorePath = Utils.createTempDir() @@ -38,7 +46,7 @@ class VersionsSuite extends SparkFunSuite with Logging { } test("success sanity check") { - val badClient = IsolatedClientLoader.forVersion("13", buildConf()).client + val badClient = IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).client val db = new HiveDatabase("default", "") badClient.createDatabase(db) } @@ -67,19 +75,21 @@ class VersionsSuite extends SparkFunSuite with Logging { // TODO: currently only works on mysql where we manually create the schema... ignore("failure sanity check") { val e = intercept[Throwable] { - val badClient = quietly { IsolatedClientLoader.forVersion("13", buildConf()).client } + val badClient = quietly { + IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).client + } } assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") } - private val versions = Seq("12", "13") + private val versions = Seq("12", "13", "14", "1.0.0", "1.1.0", "1.2.0") private var client: ClientInterface = null versions.foreach { version => test(s"$version: create client") { client = null - client = IsolatedClientLoader.forVersion(version, buildConf()).client + client = IsolatedClientLoader.forVersion(version, buildConf(), ivyPath).client } test(s"$version: createDatabase") { @@ -170,5 +180,12 @@ class VersionsSuite extends SparkFunSuite with Logging { false, false) } + + test(s"$version: create index and reset") { + client.runSqlHive("CREATE TABLE indexed_table (key INT)") + client.runSqlHive("CREATE INDEX index_1 ON TABLE indexed_table(key) " + + "as 'COMPACT' WITH DEFERRED REBUILD") + client.reset() + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 6d8d99ebc8164..51dabc67fa7c1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -1084,14 +1084,16 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { val testKey = "spark.sql.key.usedfortestonly" val testVal = "test.val.0" val nonexistentKey = "nonexistent" - val KV = "([^=]+)=([^=]*)".r - def collectResults(df: DataFrame): Set[(String, String)] = + def collectResults(df: DataFrame): Set[Any] = df.collect().map { case Row(key: String, value: String) => key -> value - case Row(KV(key, value)) => key -> value + case Row(key: String, defaultValue: String, doc: String) => (key, defaultValue, doc) }.toSet conf.clear() + val expectedConfs = conf.getAllDefinedConfs.toSet + assertResult(expectedConfs)(collectResults(sql("SET -v"))) + // "SET" itself returns all config variables currently specified in SQLConf. // TODO: Should we be listing the default here always? probably... assert(sql("SET").collect().size == 0) @@ -1102,16 +1104,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assert(hiveconf.get(testKey, "") == testVal) assertResult(Set(testKey -> testVal))(collectResults(sql("SET"))) - assertResult(Set(testKey -> testVal))(collectResults(sql("SET -v"))) sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { collectResults(sql("SET")) } - assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(sql("SET -v")) - } // "SET key" assertResult(Set(testKey -> testVal)) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index f0f04f8c73fb4..197e9bfb02c4e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -59,10 +59,4 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { } assert(numEquals === 1) } - - test("COALESCE with different types") { - intercept[RuntimeException] { - TestHive.sql("""SELECT COALESCE(1, true, "abc") FROM src limit 1""").collect() - } - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala similarity index 93% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index ce5985888f540..56b0bef1d0571 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -46,7 +46,7 @@ case class ListStringCaseClass(l: Seq[String]) /** * A test suite for Hive custom UDFs. */ -class HiveUdfSuite extends QueryTest { +class HiveUDFSuite extends QueryTest { import TestHive.{udf, sql} import TestHive.implicits._ @@ -73,7 +73,7 @@ class HiveUdfSuite extends QueryTest { test("hive struct udf") { sql( """ - |CREATE EXTERNAL TABLE hiveUdfTestTable ( + |CREATE EXTERNAL TABLE hiveUDFTestTable ( | pair STRUCT |) |PARTITIONED BY (partition STRING) @@ -82,15 +82,15 @@ class HiveUdfSuite extends QueryTest { """. stripMargin.format(classOf[PairSerDe].getName)) - val location = Utils.getSparkClassLoader.getResource("data/files/testUdf").getFile + val location = Utils.getSparkClassLoader.getResource("data/files/testUDF").getFile sql(s""" - ALTER TABLE hiveUdfTestTable - ADD IF NOT EXISTS PARTITION(partition='testUdf') + ALTER TABLE hiveUDFTestTable + ADD IF NOT EXISTS PARTITION(partition='testUDF') LOCATION '$location'""") - sql(s"CREATE TEMPORARY FUNCTION testUdf AS '${classOf[PairUdf].getName}'") - sql("SELECT testUdf(pair) FROM hiveUdfTestTable") - sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") + sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[PairUDF].getName}'") + sql("SELECT testUDF(pair) FROM hiveUDFTestTable") + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF") } test("SPARK-6409 UDAFAverage test") { @@ -169,11 +169,11 @@ class HiveUdfSuite extends QueryTest { StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF() testData.registerTempTable("stringTable") - sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'") + sql(s"CREATE TEMPORARY FUNCTION testStringStringUDF AS '${classOf[UDFStringString].getName}'") checkAnswer( - sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), + sql("SELECT testStringStringUDF(\"hello\", s) FROM stringTable"), Seq(Row("hello world"), Row("hello goodbye"))) - sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf") + sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUDF") TestHive.reset() } @@ -244,7 +244,7 @@ class PairSerDe extends AbstractSerDe { } } -class PairUdf extends GenericUDF { +class PairUDF extends GenericUDF { override def initialize(p1: Array[ObjectInspector]): ObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector( Seq("id", "value"), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 40a35674e4cb8..9f7e58f890241 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.hive.{HiveQLDialect, MetastoreRelation} +import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.types._ @@ -191,9 +191,9 @@ class SQLQuerySuite extends QueryTest { } } - val originalConf = getConf("spark.sql.hive.convertCTAS", "false") + val originalConf = convertCTAS - setConf("spark.sql.hive.convertCTAS", "true") + setConf(HiveContext.CONVERT_CTAS, true) sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") @@ -235,7 +235,7 @@ class SQLQuerySuite extends QueryTest { checkRelation("ctas1", false) sql("DROP TABLE ctas1") - setConf("spark.sql.hive.convertCTAS", originalConf) + setConf(HiveContext.CONVERT_CTAS, originalConf) } test("SQL Dialect Switching") { @@ -332,7 +332,7 @@ class SQLQuerySuite extends QueryTest { val origUseParquetDataSource = conf.parquetUseDataSourceApi try { - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) sql( """CREATE TABLE ctas5 | STORED AS parquet AS @@ -348,7 +348,7 @@ class SQLQuerySuite extends QueryTest { "MANAGED_TABLE" ) - val default = getConf("spark.sql.hive.convertMetastoreParquet", "true") + val default = convertMetastoreParquet // use the Hive SerDe for parquet tables sql("set spark.sql.hive.convertMetastoreParquet = false") checkAnswer( @@ -356,10 +356,28 @@ class SQLQuerySuite extends QueryTest { sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) sql(s"set spark.sql.hive.convertMetastoreParquet = $default") } finally { - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource.toString) + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource) } } + test("specifying the column list for CTAS") { + Seq((1, "111111"), (2, "222222")).toDF("key", "value").registerTempTable("mytable1") + + sql("create table gen__tmp(a int, b string) as select key, value from mytable1") + checkAnswer( + sql("SELECT a, b from gen__tmp"), + sql("select key, value from mytable1").collect()) + sql("DROP TABLE gen__tmp") + + sql("create table gen__tmp(a double, b double) as select key, value from mytable1") + checkAnswer( + sql("SELECT a, b from gen__tmp"), + sql("select cast(key as double), cast(value as double) from mytable1").collect()) + sql("DROP TABLE gen__tmp") + + sql("drop table mytable1") + } + test("command substitution") { sql("set tbl=src") checkAnswer( @@ -585,8 +603,8 @@ class SQLQuerySuite extends QueryTest { // generates an invalid query plan. val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) read.json(rdd).registerTempTable("data") - val originalConf = getConf("spark.sql.hive.convertCTAS", "false") - setConf("spark.sql.hive.convertCTAS", "false") + val originalConf = convertCTAS + setConf(HiveContext.CONVERT_CTAS, false) sql("CREATE TABLE explodeTest (key bigInt)") table("explodeTest").queryExecution.analyzed match { @@ -603,7 +621,7 @@ class SQLQuerySuite extends QueryTest { sql("DROP TABLE explodeTest") dropTempTable("data") - setConf("spark.sql.hive.convertCTAS", originalConf) + setConf(HiveContext.CONVERT_CTAS, originalConf) } test("sanity test for SPARK-6618") { @@ -620,19 +638,27 @@ class SQLQuerySuite extends QueryTest { test("SPARK-5203 union with different decimal precision") { Seq.empty[(Decimal, Decimal)] .toDF("d1", "d2") - .select($"d1".cast(DecimalType(10, 15)).as("d")) + .select($"d1".cast(DecimalType(10, 5)).as("d")) .registerTempTable("dn") sql("select d from dn union all select d * 2 from dn") .queryExecution.analyzed } - test("test script transform") { + test("test script transform for stdout") { val data = (1 to 100000).map { i => (i, i, i) } data.toDF("d1", "d2", "d3").registerTempTable("script_trans") assert(100000 === sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans") - .queryExecution.toRdd.count()) + .queryExecution.toRdd.count()) + } + + test("test script transform for stderr") { + val data = (1 to 100000).map { i => (i, i, i) } + data.toDF("d1", "d2", "d3").registerTempTable("script_trans") + assert(0 === + sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat 1>&2' AS (a,b,c) FROM script_trans") + .queryExecution.toRdd.count()) } test("window function: udaf with aggregate expressin") { @@ -849,6 +875,10 @@ class SQLQuerySuite extends QueryTest { } } + test("Cast STRING to BIGINT") { + checkAnswer(sql("SELECT CAST('775983671874188101' as BIGINT)"), Row(775983671874188101L)) + } + // `Math.exp(1.0)` has different result for different jdk version, so not use createQueryTest test("udf_java_method") { checkAnswer(sql( @@ -904,4 +934,32 @@ class SQLQuerySuite extends QueryTest { sql("set hive.exec.dynamic.partition.mode=strict") } } + + test("Call add jar in a different thread (SPARK-8306)") { + @volatile var error: Option[Throwable] = None + val thread = new Thread { + override def run() { + // To make sure this test works, this jar should not be loaded in another place. + TestHive.sql( + s"ADD JAR ${TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath()}") + try { + TestHive.sql( + """ + |CREATE TEMPORARY FUNCTION example_max + |AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax' + """.stripMargin) + } catch { + case throwable: Throwable => + error = Some(throwable) + } + } + } + thread.start() + thread.join() + error match { + case Some(throwable) => + fail("CREATE TEMPORARY FUNCTION should not fail.", throwable) + case None => // OK + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index 0e63d84e9824a..8707f9f936be6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -21,7 +21,7 @@ import java.io.File import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.expressions.InternalRow import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index b384fb39f3d66..267d22c6b5f1e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -25,7 +25,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.expressions.InternalRow import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 82e08caf46457..a0cdd0db42d65 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -43,8 +43,14 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { orcTableDir.mkdir() import org.apache.spark.sql.hive.test.TestHive.implicits._ + // Originally we were using a 10-row RDD for testing. However, when default parallelism is + // greater than 10 (e.g., running on a node with 32 cores), this RDD contains empty partitions, + // which result in empty ORC files. Unfortunately, ORC doesn't handle empty files properly and + // causes build failure on Jenkins, which happens to have 32 cores. Please refer to SPARK-8501 + // for more details. To workaround this issue before fixing SPARK-8501, we simply increase row + // number in this RDD to avoid empty partitions. sparkContext - .makeRDD(1 to 10) + .makeRDD(1 to 100) .map(i => OrcData(i, s"part-$i")) .toDF() .registerTempTable(s"orc_temp_table") @@ -70,35 +76,35 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { } test("create temporary orc table") { - checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_source"), Row(10)) + checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_source"), Row(100)) checkAnswer( sql("SELECT * FROM normal_orc_source"), - (1 to 10).map(i => Row(i, s"part-$i"))) + (1 to 100).map(i => Row(i, s"part-$i"))) checkAnswer( sql("SELECT * FROM normal_orc_source where intField > 5"), - (6 to 10).map(i => Row(i, s"part-$i"))) + (6 to 100).map(i => Row(i, s"part-$i"))) checkAnswer( sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), - (1 to 10).map(i => Row(1, s"part-$i"))) + (1 to 100).map(i => Row(1, s"part-$i"))) } test("create temporary orc table as") { - checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_as_source"), Row(10)) + checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_as_source"), Row(100)) checkAnswer( sql("SELECT * FROM normal_orc_source"), - (1 to 10).map(i => Row(i, s"part-$i"))) + (1 to 100).map(i => Row(i, s"part-$i"))) checkAnswer( sql("SELECT * FROM normal_orc_source WHERE intField > 5"), - (6 to 10).map(i => Row(i, s"part-$i"))) + (6 to 100).map(i => Row(i, s"part-$i"))) checkAnswer( sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), - (1 to 10).map(i => Row(1, s"part-$i"))) + (1 to 100).map(i => Row(1, s"part-$i"))) } test("appending insert") { @@ -106,7 +112,7 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { checkAnswer( sql("SELECT * FROM normal_orc_source"), - (1 to 5).map(i => Row(i, s"part-$i")) ++ (6 to 10).flatMap { i => + (1 to 5).map(i => Row(i, s"part-$i")) ++ (6 to 100).flatMap { i => Seq.fill(2)(Row(i, s"part-$i")) }) } @@ -119,7 +125,7 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { checkAnswer( sql("SELECT * FROM normal_orc_as_source"), - (6 to 10).map(i => Row(i, s"part-$i"))) + (6 to 100).map(i => Row(i, s"part-$i"))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index e62ac909cbd0c..c2e09800933b5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -21,8 +21,6 @@ import java.io.File import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.sql.hive.test.TestHive._ @@ -30,7 +28,7 @@ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.parquet.{ParquetRelation2, ParquetTableScan} import org.apache.spark.sql.sources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf, SaveMode} +import org.apache.spark.sql.{DataFrame, QueryTest, Row, SQLConf, SaveMode} import org.apache.spark.util.Utils // The data where the partitioning key exists only in the directory structure. @@ -155,7 +153,7 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":[$i, null]}""")) read.json(rdd2).registerTempTable("jt_array") - setConf("spark.sql.hive.convertMetastoreParquet", "true") + setConf(HiveContext.CONVERT_METASTORE_PARQUET, true) } override def afterAll(): Unit = { @@ -166,7 +164,7 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { sql("DROP TABLE normal_parquet") sql("DROP TABLE IF EXISTS jt") sql("DROP TABLE IF EXISTS jt_array") - setConf("spark.sql.hive.convertMetastoreParquet", "false") + setConf(HiveContext.CONVERT_METASTORE_PARQUET, false) } test(s"conversion is working") { @@ -201,14 +199,14 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin) - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) } override def afterAll(): Unit = { super.afterAll() sql("DROP TABLE IF EXISTS test_parquet") - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } test("scan an empty parquet table") { @@ -548,12 +546,12 @@ class ParquetDataSourceOffMetastoreSuite extends ParquetMetastoreSuiteBase { override def beforeAll(): Unit = { super.beforeAll() - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) } override def afterAll(): Unit = { super.afterAll() - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } test("MetastoreRelation in InsertIntoTable will not be converted") { @@ -694,12 +692,12 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { override def beforeAll(): Unit = { super.beforeAll() - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) } override def afterAll(): Unit = { super.afterAll() - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } test("values in arrays and maps stored in parquet are always nullable") { @@ -752,12 +750,12 @@ class ParquetDataSourceOffSourceSuite extends ParquetSourceSuiteBase { override def beforeAll(): Unit = { super.beforeAll() - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) } override def afterAll(): Unit = { super.afterAll() - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 0f959b3d0b86d..e8141923a9b5c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -53,9 +53,10 @@ class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullW numberFormat.setGroupingUsed(false) override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") val split = context.getTaskAttemptID.getTaskID.getId val name = FileOutputFormat.getOutputName(context) - new Path(outputFile, s"$name-${numberFormat.format(split)}-${UUID.randomUUID()}") + new Path(outputFile, s"$name-${numberFormat.format(split)}-$uniqueWriteJobId") } } @@ -118,6 +119,8 @@ class SimpleTextRelation( } override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory { + job.setOutputFormatClass(classOf[TextOutputFormat[_, _]]) + override def newInstance( path: String, dataSchema: StructType, @@ -156,6 +159,7 @@ class CommitFailureTestRelation( context: TaskAttemptContext): OutputWriter = { new SimpleTextOutputWriter(path, context) { override def close(): Unit = { + super.close() sys.error("Intentional task commitment failure for testing purpose.") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 76469d7a3d6a5..afecf9675e11f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -17,10 +17,16 @@ package org.apache.spark.sql.sources +import scala.collection.JavaConversions._ + import java.io.File import com.google.common.io.Files +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter +import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil @@ -35,7 +41,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { import sqlContext.sql import sqlContext.implicits._ - val dataSourceName = classOf[SimpleTextSource].getCanonicalName + val dataSourceName: String val dataSchema = StructType( @@ -470,6 +476,108 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { checkAnswer(sqlContext.table("t"), df.select('b, 'c, 'a).collect()) } } + + // NOTE: This test suite is not super deterministic. On nodes with only relatively few cores + // (4 or even 1), it's hard to reproduce the data loss issue. But on nodes with for example 8 or + // more cores, the issue can be reproduced steadily. Fortunately our Jenkins builder meets this + // requirement. We probably want to move this test case to spark-integration-tests or spark-perf + // later. + test("SPARK-8406: Avoids name collision while writing files") { + withTempPath { dir => + val path = dir.getCanonicalPath + sqlContext + .range(10000) + .repartition(250) + .write + .mode(SaveMode.Overwrite) + .format(dataSourceName) + .save(path) + + assertResult(10000) { + sqlContext + .read + .format(dataSourceName) + .option("dataSchema", StructType(StructField("id", LongType) :: Nil).json) + .load(path) + .count() + } + } + } + + test("SPARK-8578 specified custom output committer will not be used to append data") { + val clonedConf = new Configuration(configuration) + try { + val df = sqlContext.range(1, 10).toDF("i") + withTempPath { dir => + df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) + configuration.set( + SQLConf.OUTPUT_COMMITTER_CLASS.key, + classOf[AlwaysFailOutputCommitter].getName) + // Since Parquet has its own output committer setting, also set it + // to AlwaysFailParquetOutputCommitter at here. + configuration.set("spark.sql.parquet.output.committer.class", + classOf[AlwaysFailParquetOutputCommitter].getName) + // Because there data already exists, + // this append should succeed because we will use the output committer associated + // with file format and AlwaysFailOutputCommitter will not be used. + df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) + checkAnswer( + sqlContext.read + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .load(dir.getCanonicalPath), + df.unionAll(df)) + + // This will fail because AlwaysFailOutputCommitter is used when we do append. + intercept[Exception] { + df.write.mode("overwrite").format(dataSourceName).save(dir.getCanonicalPath) + } + } + withTempPath { dir => + configuration.set( + SQLConf.OUTPUT_COMMITTER_CLASS.key, + classOf[AlwaysFailOutputCommitter].getName) + // Since Parquet has its own output committer setting, also set it + // to AlwaysFailParquetOutputCommitter at here. + configuration.set("spark.sql.parquet.output.committer.class", + classOf[AlwaysFailParquetOutputCommitter].getName) + // Because there is no existing data, + // this append will fail because AlwaysFailOutputCommitter is used when we do append + // and there is no existing data. + intercept[Exception] { + df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) + } + } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + } + } +} + +// This class is used to test SPARK-8578. We should not use any custom output committer when +// we actually append data to an existing dir. +class AlwaysFailOutputCommitter( + outputPath: Path, + context: TaskAttemptContext) + extends FileOutputCommitter(outputPath, context) { + + override def commitJob(context: JobContext): Unit = { + sys.error("Intentional job commitment failure for testing purpose.") + } +} + +// This class is used to test SPARK-8578. We should not use any custom output committer when +// we actually append data to an existing dir. +class AlwaysFailParquetOutputCommitter( + outputPath: Path, + context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { + + override def commitJob(context: JobContext): Unit = { + sys.error("Intentional job commitment failure for testing purpose.") + } } class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { @@ -502,15 +610,17 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { } class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { - import TestHive.implicits._ - override val sqlContext = TestHive + // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName test("SPARK-7684: commitTask() failure should fallback to abortTask()") { withTempPath { file => - val df = (1 to 3).map(i => i -> s"val_$i").toDF("a", "b") + // Here we coalesce partition number to 1 to ensure that only a single task is issued. This + // prevents race condition happened when FileOutputCommitter tries to remove the `_temporary` + // directory while committing/aborting the job. See SPARK-8513 for more details. + val df = sqlContext.range(0, 10).coalesce(1) intercept[SparkException] { df.write.format(dataSourceName).save(file.getCanonicalPath) } @@ -609,4 +719,25 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { } } } + + test("SPARK-8604: Parquet data source should write summary file while doing appending") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(0, 5) + df.write.mode(SaveMode.Overwrite).parquet(path) + + val summaryPath = new Path(path, "_metadata") + val commonSummaryPath = new Path(path, "_common_metadata") + + val fs = summaryPath.getFileSystem(configuration) + fs.delete(summaryPath, true) + fs.delete(commonSummaryPath, true) + + df.write.mode(SaveMode.Append).parquet(path) + checkAnswer(sqlContext.read.parquet(path), df.unionAll(df)) + + assert(fs.exists(summaryPath)) + assert(fs.exists(commonSummaryPath)) + } + } } diff --git a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.css b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.css index b22c884bfebdb..ec12616b58d87 100644 --- a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.css +++ b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.css @@ -31,7 +31,7 @@ } .tooltip-inner { - max-width: 500px !important; // Make sure we only have one line tooltip + max-width: 500px !important; /* Make sure we only have one line tooltip */ } .line { diff --git a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js index 75251f493ad22..4886b68eeaf76 100644 --- a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js +++ b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js @@ -31,6 +31,8 @@ var maxXForHistogram = 0; var histogramBinCount = 10; var yValueFormat = d3.format(",.2f"); +var unitLabelYOffset = -10; + // Show a tooltip "text" for "node" function showBootstrapTooltip(node, text) { $(node).tooltip({title: text, trigger: "manual", container: "body"}); @@ -133,7 +135,7 @@ function drawTimeline(id, data, minX, maxX, minY, maxY, unitY, batchInterval) { .attr("class", "y axis") .call(yAxis) .append("text") - .attr("transform", "translate(0," + (-3) + ")") + .attr("transform", "translate(0," + unitLabelYOffset + ")") .text(unitY); @@ -223,10 +225,10 @@ function drawHistogram(id, values, minY, maxY, unitY, batchInterval) { .style("border-left", "0px solid white"); var margin = {top: 20, right: 30, bottom: 30, left: 10}; - var width = 300 - margin.left - margin.right; + var width = 350 - margin.left - margin.right; var height = 150 - margin.top - margin.bottom; - var x = d3.scale.linear().domain([0, maxXForHistogram]).range([0, width]); + var x = d3.scale.linear().domain([0, maxXForHistogram]).range([0, width - 50]); var y = d3.scale.linear().domain([minY, maxY]).range([height, 0]); var xAxis = d3.svg.axis().scale(x).orient("top").ticks(5); @@ -248,7 +250,7 @@ function drawHistogram(id, values, minY, maxY, unitY, batchInterval) { .attr("class", "x axis") .call(xAxis) .append("text") - .attr("transform", "translate(" + (margin.left + width - 40) + ", 15)") + .attr("transform", "translate(" + (margin.left + width - 45) + ", " + unitLabelYOffset + ")") .text("#batches"); svg.append("g") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 9cd9684d36404..1708f309fc002 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -549,8 +549,8 @@ class StreamingContext private[streaming] ( case e: NotSerializableException => throw new NotSerializableException( "DStream checkpointing has been enabled but the DStreams with their functions " + - "are not serializable\nSerialization stack:\n" + - SerializationDebugger.find(checkpoint).map("\t- " + _).mkString("\n") + "are not serializable\n" + + SerializationDebugger.improveException(checkpoint, e).getMessage() ) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 6c1fab56740ee..86a8e2beff57c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -26,10 +26,9 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import org.apache.spark.{SparkConf, SerializableWritable} import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.streaming._ -import org.apache.spark.util.{TimeStampedHashMap, Utils} +import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Utils} /** * This class represents an input stream that monitors a Hadoop-compatible filesystem for new @@ -78,7 +77,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]) extends InputDStream[(K, V)](ssc_) { - private val serializableConfOpt = conf.map(new SerializableWritable(_)) + private val serializableConfOpt = conf.map(new SerializableConfiguration(_)) /** * Minimum duration of remembering the information of selected files. Defaults to 60 seconds. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 358e4c66df7ba..71bec96d46c8d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -24,10 +24,11 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} -import org.apache.spark.{HashPartitioner, Partitioner, SerializableWritable} +import org.apache.spark.{HashPartitioner, Partitioner} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.streaming.StreamingContext.rddToFileName +import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf} /** * Extra functions available on DStream of (key, value) pairs through an implicit conversion. @@ -688,7 +689,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) conf: JobConf = new JobConf(ssc.sparkContext.hadoopConfiguration) ): Unit = ssc.withScope { // Wrap conf in SerializableWritable so that ForeachDStream can be serialized for checkpoints - val serializableConf = new SerializableWritable(conf) + val serializableConf = new SerializableJobConf(conf) val saveFunc = (rdd: RDD[(K, V)], time: Time) => { val file = rddToFileName(prefix, suffix, time) rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, serializableConf.value) @@ -721,7 +722,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) conf: Configuration = ssc.sparkContext.hadoopConfiguration ): Unit = ssc.withScope { // Wrap conf in SerializableWritable so that ForeachDStream can be serialized for checkpoints - val serializableConf = new SerializableWritable(conf) + val serializableConf = new SerializableConfiguration(conf) val saveFunc = (rdd: RDD[(K, V)], time: Time) => { val file = rddToFileName(prefix, suffix, time) rdd.saveAsNewAPIHadoopFile( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index ffce6a4c3c74c..31ce8e1ec14d7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -23,12 +23,11 @@ import java.util.UUID import scala.reflect.ClassTag import scala.util.control.NonFatal -import org.apache.commons.io.FileUtils - import org.apache.spark._ import org.apache.spark.rdd.BlockRDD import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming.util._ +import org.apache.spark.util.SerializableConfiguration /** * Partition class for [[org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD]]. @@ -94,7 +93,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( // Hadoop configuration is not serializable, so broadcast it as a serializable. @transient private val hadoopConfig = sc.hadoopConfiguration - private val broadcastedHadoopConf = new SerializableWritable(hadoopConfig) + private val broadcastedHadoopConf = new SerializableConfiguration(hadoopConfig) override def isValid(): Boolean = true diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 8d73593ab6375..92b51ce39234c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.util.RecurringTimer -import org.apache.spark.util.{SystemClock, Utils} +import org.apache.spark.util.SystemClock /** Listener object for BlockGenerator events */ private[streaming] trait BlockGeneratorListener { @@ -80,6 +80,8 @@ private[streaming] class BlockGenerator( private val clock = new SystemClock() private val blockIntervalMs = conf.getTimeAsMs("spark.streaming.blockInterval", "200ms") + require(blockIntervalMs > 0, s"'spark.streaming.blockInterval' should be a positive value") + private val blockIntervalTimer = new RecurringTimer(clock, blockIntervalMs, updateCurrentBuffer, "BlockGenerator") private val blockQueueSize = conf.getInt("spark.streaming.blockQueueSize", 10) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 207d64d9414ee..c8dd6e06812dc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -32,7 +32,10 @@ import org.apache.spark.{Logging, SparkConf, SparkException} /** Trait that represents the metadata related to storage of blocks */ private[streaming] trait ReceivedBlockStoreResult { - def blockId: StreamBlockId // Any implementation of this trait will store a block id + // Any implementation of this trait will store a block id + def blockId: StreamBlockId + // Any implementation of this trait will have to return the number of records + def numRecords: Option[Long] } /** Trait that represents a class that handles the storage of blocks received by receiver */ @@ -51,7 +54,8 @@ private[streaming] trait ReceivedBlockHandler { * that stores the metadata related to storage of blocks using * [[org.apache.spark.streaming.receiver.BlockManagerBasedBlockHandler]] */ -private[streaming] case class BlockManagerBasedStoreResult(blockId: StreamBlockId) +private[streaming] case class BlockManagerBasedStoreResult( + blockId: StreamBlockId, numRecords: Option[Long]) extends ReceivedBlockStoreResult @@ -64,11 +68,20 @@ private[streaming] class BlockManagerBasedBlockHandler( extends ReceivedBlockHandler with Logging { def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { + + var numRecords = None: Option[Long] + val putResult: Seq[(BlockId, BlockStatus)] = block match { case ArrayBufferBlock(arrayBuffer) => - blockManager.putIterator(blockId, arrayBuffer.iterator, storageLevel, tellMaster = true) + numRecords = Some(arrayBuffer.size.toLong) + blockManager.putIterator(blockId, arrayBuffer.iterator, storageLevel, + tellMaster = true) case IteratorBlock(iterator) => - blockManager.putIterator(blockId, iterator, storageLevel, tellMaster = true) + val countIterator = new CountingIterator(iterator) + val putResult = blockManager.putIterator(blockId, countIterator, storageLevel, + tellMaster = true) + numRecords = countIterator.count + putResult case ByteBufferBlock(byteBuffer) => blockManager.putBytes(blockId, byteBuffer, storageLevel, tellMaster = true) case o => @@ -79,7 +92,7 @@ private[streaming] class BlockManagerBasedBlockHandler( throw new SparkException( s"Could not store $blockId to block manager with storage level $storageLevel") } - BlockManagerBasedStoreResult(blockId) + BlockManagerBasedStoreResult(blockId, numRecords) } def cleanupOldBlocks(threshTime: Long) { @@ -96,6 +109,7 @@ private[streaming] class BlockManagerBasedBlockHandler( */ private[streaming] case class WriteAheadLogBasedStoreResult( blockId: StreamBlockId, + numRecords: Option[Long], walRecordHandle: WriteAheadLogRecordHandle ) extends ReceivedBlockStoreResult @@ -151,12 +165,17 @@ private[streaming] class WriteAheadLogBasedBlockHandler( */ def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { + var numRecords = None: Option[Long] // Serialize the block so that it can be inserted into both val serializedBlock = block match { case ArrayBufferBlock(arrayBuffer) => + numRecords = Some(arrayBuffer.size.toLong) blockManager.dataSerialize(blockId, arrayBuffer.iterator) case IteratorBlock(iterator) => - blockManager.dataSerialize(blockId, iterator) + val countIterator = new CountingIterator(iterator) + val serializedBlock = blockManager.dataSerialize(blockId, countIterator) + numRecords = countIterator.count + serializedBlock case ByteBufferBlock(byteBuffer) => byteBuffer case _ => @@ -181,7 +200,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler( // Combine the futures, wait for both to complete, and return the write ahead log record handle val combinedFuture = storeInBlockManagerFuture.zip(storeInWriteAheadLogFuture).map(_._2) val walRecordHandle = Await.result(combinedFuture, blockStoreTimeout) - WriteAheadLogBasedStoreResult(blockId, walRecordHandle) + WriteAheadLogBasedStoreResult(blockId, numRecords, walRecordHandle) } def cleanupOldBlocks(threshTime: Long) { @@ -199,3 +218,23 @@ private[streaming] object WriteAheadLogBasedBlockHandler { new Path(checkpointDir, new Path("receivedData", streamId.toString)).toString } } + +/** + * A utility that will wrap the Iterator to get the count + */ +private class CountingIterator[T](iterator: Iterator[T]) extends Iterator[T] { + private var _count = 0 + + private def isFullyConsumed: Boolean = !iterator.hasNext + + def hasNext(): Boolean = iterator.hasNext + + def count(): Option[Long] = { + if (isFullyConsumed) Some(_count) else None + } + + def next(): T = { + _count += 1 + iterator.next() + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 8be732b64e3a3..6078cdf8f8790 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -137,15 +137,10 @@ private[streaming] class ReceiverSupervisorImpl( blockIdOption: Option[StreamBlockId] ) { val blockId = blockIdOption.getOrElse(nextBlockId) - val numRecords = receivedBlock match { - case ArrayBufferBlock(arrayBuffer) => Some(arrayBuffer.size.toLong) - case _ => None - } - val time = System.currentTimeMillis val blockStoreResult = receivedBlockHandler.storeBlock(blockId, receivedBlock) logDebug(s"Pushed block $blockId in ${(System.currentTimeMillis - time)} ms") - + val numRecords = blockStoreResult.numRecords val blockInfo = ReceivedBlockInfo(streamId, numRecords, metadataOption, blockStoreResult) trackerEndpoint.askWithRetry[Boolean](AddBlock(blockInfo)) logDebug(s"Reported block $blockId") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index f1504b09c9873..e6cdbec11e94c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -21,10 +21,12 @@ import scala.collection.mutable.{HashMap, SynchronizedMap} import scala.language.existentials import org.apache.spark.streaming.util.WriteAheadLogUtils -import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException} +import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.rpc._ import org.apache.spark.streaming.{StreamingContext, Time} -import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, StopReceiver} +import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, + StopReceiver} +import org.apache.spark.util.SerializableConfiguration /** * Messages used by the NetworkReceiver and the ReceiverTracker to communicate @@ -294,7 +296,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } val checkpointDirOption = Option(ssc.checkpointDir) - val serializableHadoopConf = new SerializableWritable(ssc.sparkContext.hadoopConfiguration) + val serializableHadoopConf = + new SerializableConfiguration(ssc.sparkContext.hadoopConfiguration) // Function to start the receiver on the worker node val startReceiver = (iterator: Iterator[Receiver[_]]) => { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 4ee7a486e370b..87af902428ec8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -310,7 +310,7 @@ private[ui] class StreamingPage(parent: StreamingTab) Timelines (Last {batchTimes.length} batches, {numActiveBatches} active, {numCompletedBatches} completed) - Histograms + Histograms @@ -456,7 +456,7 @@ private[ui] class StreamingPage(parent: StreamingTab) {receiverActive} {receiverLocation} {receiverLastErrorTime} -
{receiverLastError}
+
{receiverLastError}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index cca8cedb1d080..6c0c926755c20 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -49,7 +49,6 @@ class ReceivedBlockHandlerSuite val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") val hadoopConf = new Configuration() - val storageLevel = StorageLevel.MEMORY_ONLY_SER val streamId = 1 val securityMgr = new SecurityManager(conf) val mapOutputTracker = new MapOutputTrackerMaster(conf) @@ -57,10 +56,12 @@ class ReceivedBlockHandlerSuite val serializer = new KryoSerializer(conf) val manualClock = new ManualClock val blockManagerSize = 10000000 + val blockManagerBuffer = new ArrayBuffer[BlockManager]() var rpcEnv: RpcEnv = null var blockManagerMaster: BlockManagerMaster = null var blockManager: BlockManager = null + var storageLevel: StorageLevel = null var tempDirectory: File = null before { @@ -70,20 +71,21 @@ class ReceivedBlockHandlerSuite blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) - blockManager = new BlockManager("bm", rpcEnv, blockManagerMaster, serializer, - blockManagerSize, conf, mapOutputTracker, shuffleManager, - new NioBlockTransferService(conf, securityMgr), securityMgr, 0) - blockManager.initialize("app-id") + storageLevel = StorageLevel.MEMORY_ONLY_SER + blockManager = createBlockManager(blockManagerSize, conf) tempDirectory = Utils.createTempDir() manualClock.setTime(0) } after { - if (blockManager != null) { - blockManager.stop() - blockManager = null + for ( blockManager <- blockManagerBuffer ) { + if (blockManager != null) { + blockManager.stop() + } } + blockManager = null + blockManagerBuffer.clear() if (blockManagerMaster != null) { blockManagerMaster.stop() blockManagerMaster = null @@ -174,6 +176,130 @@ class ReceivedBlockHandlerSuite } } + test("Test Block - count messages") { + // Test count with BlockManagedBasedBlockHandler + testCountWithBlockManagerBasedBlockHandler(true) + // Test count with WriteAheadLogBasedBlockHandler + testCountWithBlockManagerBasedBlockHandler(false) + } + + test("Test Block - isFullyConsumed") { + val sparkConf = new SparkConf() + sparkConf.set("spark.storage.unrollMemoryThreshold", "512") + // spark.storage.unrollFraction set to 0.4 for BlockManager + sparkConf.set("spark.storage.unrollFraction", "0.4") + // Block Manager with 12000 * 0.4 = 4800 bytes of free space for unroll + blockManager = createBlockManager(12000, sparkConf) + + // there is not enough space to store this block in MEMORY, + // But BlockManager will be able to sereliaze this block to WAL + // and hence count returns correct value. + testRecordcount(false, StorageLevel.MEMORY_ONLY, + IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator), blockManager, Some(70)) + + // there is not enough space to store this block in MEMORY, + // But BlockManager will be able to sereliaze this block to DISK + // and hence count returns correct value. + testRecordcount(true, StorageLevel.MEMORY_AND_DISK, + IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator), blockManager, Some(70)) + + // there is not enough space to store this block With MEMORY_ONLY StorageLevel. + // BlockManager will not be able to unroll this block + // and hence it will not tryToPut this block, resulting the SparkException + storageLevel = StorageLevel.MEMORY_ONLY + withBlockManagerBasedBlockHandler { handler => + val thrown = intercept[SparkException] { + storeSingleBlock(handler, IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator)) + } + } + } + + private def testCountWithBlockManagerBasedBlockHandler(isBlockManagerBasedBlockHandler: Boolean) { + // ByteBufferBlock-MEMORY_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY, + ByteBufferBlock(ByteBuffer.wrap(Array.tabulate(100)(i => i.toByte))), blockManager, None) + // ByteBufferBlock-MEMORY_ONLY_SER + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER, + ByteBufferBlock(ByteBuffer.wrap(Array.tabulate(100)(i => i.toByte))), blockManager, None) + // ArrayBufferBlock-MEMORY_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY, + ArrayBufferBlock(ArrayBuffer.fill(25)(0)), blockManager, Some(25)) + // ArrayBufferBlock-MEMORY_ONLY_SER + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER, + ArrayBufferBlock(ArrayBuffer.fill(25)(0)), blockManager, Some(25)) + // ArrayBufferBlock-DISK_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.DISK_ONLY, + ArrayBufferBlock(ArrayBuffer.fill(50)(0)), blockManager, Some(50)) + // ArrayBufferBlock-MEMORY_AND_DISK + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_AND_DISK, + ArrayBufferBlock(ArrayBuffer.fill(75)(0)), blockManager, Some(75)) + // IteratorBlock-MEMORY_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY, + IteratorBlock((ArrayBuffer.fill(100)(0)).iterator), blockManager, Some(100)) + // IteratorBlock-MEMORY_ONLY_SER + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER, + IteratorBlock((ArrayBuffer.fill(100)(0)).iterator), blockManager, Some(100)) + // IteratorBlock-DISK_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.DISK_ONLY, + IteratorBlock((ArrayBuffer.fill(125)(0)).iterator), blockManager, Some(125)) + // IteratorBlock-MEMORY_AND_DISK + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_AND_DISK, + IteratorBlock((ArrayBuffer.fill(150)(0)).iterator), blockManager, Some(150)) + } + + private def createBlockManager( + maxMem: Long, + conf: SparkConf, + name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { + val transfer = new NioBlockTransferService(conf, securityMgr) + val manager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, maxMem, conf, + mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + manager.initialize("app-id") + blockManagerBuffer += manager + manager + } + + /** + * Test storing of data using different types of Handler, StorageLevle and ReceivedBlocks + * and verify the correct record count + */ + private def testRecordcount(isBlockManagedBasedBlockHandler: Boolean, + sLevel: StorageLevel, + receivedBlock: ReceivedBlock, + bManager: BlockManager, + expectedNumRecords: Option[Long] + ) { + blockManager = bManager + storageLevel = sLevel + var bId: StreamBlockId = null + try { + if (isBlockManagedBasedBlockHandler) { + // test received block with BlockManager based handler + withBlockManagerBasedBlockHandler { handler => + val (blockId, blockStoreResult) = storeSingleBlock(handler, receivedBlock) + bId = blockId + assert(blockStoreResult.numRecords === expectedNumRecords, + "Message count not matches for a " + + receivedBlock.getClass.getName + + " being inserted using BlockManagerBasedBlockHandler with " + sLevel) + } + } else { + // test received block with WAL based handler + withWriteAheadLogBasedBlockHandler { handler => + val (blockId, blockStoreResult) = storeSingleBlock(handler, receivedBlock) + bId = blockId + assert(blockStoreResult.numRecords === expectedNumRecords, + "Message count not matches for a " + + receivedBlock.getClass.getName + + " being inserted using WriteAheadLogBasedBlockHandler with " + sLevel) + } + } + } finally { + // Removing the Block Id to use same blockManager for next test + blockManager.removeBlock(bId, true) + } + } + /** * Test storing of data using different forms of ReceivedBlocks and verify that they succeeded * using the given verification function @@ -251,9 +377,21 @@ class ReceivedBlockHandlerSuite (blockIds, storeResults) } + /** Store single block using a handler */ + private def storeSingleBlock( + handler: ReceivedBlockHandler, + block: ReceivedBlock + ): (StreamBlockId, ReceivedBlockStoreResult) = { + val blockId = generateBlockId + val blockStoreResult = handler.storeBlock(blockId, block) + logDebug("Done inserting") + (blockId, blockStoreResult) + } + private def getWriteAheadLogFiles(): Seq[String] = { getLogFilesInDirectory(checkpointDirToLogDir(tempDirectory.toString, streamId)) } private def generateBlockId(): StreamBlockId = StreamBlockId(streamId, scala.util.Random.nextLong) } + diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index be305b5e0dfea..f793a12843b2f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -225,7 +225,7 @@ class ReceivedBlockTrackerSuite /** Generate blocks infos using random ids */ def generateBlockInfos(): Seq[ReceivedBlockInfo] = { List.fill(5)(ReceivedBlockInfo(streamId, Some(0L), None, - BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt))))) + BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt)), Some(0L)))) } /** Get all the data written in the given write ahead log file. */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 1dc8960d60528..7bc7727a9fbe4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -116,7 +116,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { ssc.start() try { - eventually(timeout(2000 millis), interval(20 millis)) { + eventually(timeout(30 seconds), interval(20 millis)) { collector.startedReceiverStreamIds.size should equal (1) collector.startedReceiverStreamIds(0) should equal (0) collector.stoppedReceiverStreamIds should have size 1 diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index cbc24aee4fa1e..a08578680cff9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -27,9 +27,10 @@ import org.scalatest.selenium.WebBrowser import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.ui.SparkUICssErrorHandler /** - * Selenium tests for the Spark Web UI. + * Selenium tests for the Spark Streaming Web UI. */ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { @@ -37,7 +38,9 @@ class UISeleniumSuite implicit var webDriver: WebDriver = _ override def beforeAll(): Unit = { - webDriver = new HtmlUnitDriver + webDriver = new HtmlUnitDriver { + getWebClient.setCssErrorHandler(new SparkUICssErrorHandler) + } } override def afterAll(): Unit = { diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 62c6354f1e203..33782c6c66f90 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -67,7 +67,7 @@
org.mockito - mockito-all + mockito-core test @@ -80,7 +80,7 @@ net.alchim31.maven scala-maven-plugin - + -XDignore.symbol.file diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java new file mode 100644 index 0000000000000..9302b472925ed --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.types; + +import javax.annotation.Nonnull; +import java.io.Serializable; +import java.io.UnsupportedEncodingException; +import java.util.Arrays; + +import org.apache.spark.unsafe.PlatformDependent; + +/** + * A UTF-8 String for internal Spark use. + *

+ * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison, + * search, see http://en.wikipedia.org/wiki/UTF-8 for details. + *

+ * Note: This is not designed for general use cases, should not be used outside SQL. + */ +public final class UTF8String implements Comparable, Serializable { + + @Nonnull + private byte[] bytes; + + private static int[] bytesOfCodePointInUTF8 = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, + 5, 5, 5, 5, + 6, 6, 6, 6}; + + public static UTF8String fromBytes(byte[] bytes) { + return (bytes != null) ? new UTF8String().set(bytes) : null; + } + + public static UTF8String fromString(String str) { + return (str != null) ? new UTF8String().set(str) : null; + } + + /** + * Updates the UTF8String with String. + */ + protected UTF8String set(final String str) { + try { + bytes = str.getBytes("utf-8"); + } catch (UnsupportedEncodingException e) { + // Turn the exception into unchecked so we can find out about it at runtime, but + // don't need to add lots of boilerplate code everywhere. + PlatformDependent.throwException(e); + } + return this; + } + + /** + * Updates the UTF8String with byte[], which should be encoded in UTF-8. + */ + protected UTF8String set(final byte[] bytes) { + this.bytes = bytes; + return this; + } + + /** + * Returns the number of bytes for a code point with the first byte as `b` + * @param b The first byte of a code point + */ + public int numBytes(final byte b) { + final int offset = (b & 0xFF) - 192; + return (offset >= 0) ? bytesOfCodePointInUTF8[offset] : 1; + } + + /** + * Returns the number of code points in it. + * + * This is only used by Substring() when `start` is negative. + */ + public int length() { + int len = 0; + for (int i = 0; i < bytes.length; i+= numBytes(bytes[i])) { + len += 1; + } + return len; + } + + public byte[] getBytes() { + return bytes; + } + + /** + * Returns a substring of this. + * @param start the position of first code point + * @param until the position after last code point, exclusive. + */ + public UTF8String substring(final int start, final int until) { + if (until <= start || start >= bytes.length) { + return UTF8String.fromBytes(new byte[0]); + } + + int i = 0; + int c = 0; + for (; i < bytes.length && c < start; i += numBytes(bytes[i])) { + c += 1; + } + + int j = i; + for (; j < bytes.length && c < until; j += numBytes(bytes[i])) { + c += 1; + } + + return UTF8String.fromBytes(Arrays.copyOfRange(bytes, i, j)); + } + + public boolean contains(final UTF8String substring) { + final byte[] b = substring.getBytes(); + if (b.length == 0) { + return true; + } + + for (int i = 0; i <= bytes.length - b.length; i++) { + if (bytes[i] == b[0] && startsWith(b, i)) { + return true; + } + } + return false; + } + + private boolean startsWith(final byte[] prefix, int offsetInBytes) { + if (prefix.length + offsetInBytes > bytes.length || offsetInBytes < 0) { + return false; + } + int i = 0; + while (i < prefix.length && prefix[i] == bytes[i + offsetInBytes]) { + i++; + } + return i == prefix.length; + } + + public boolean startsWith(final UTF8String prefix) { + return startsWith(prefix.getBytes(), 0); + } + + public boolean endsWith(final UTF8String suffix) { + return startsWith(suffix.getBytes(), bytes.length - suffix.getBytes().length); + } + + public UTF8String toUpperCase() { + return UTF8String.fromString(toString().toUpperCase()); + } + + public UTF8String toLowerCase() { + return UTF8String.fromString(toString().toLowerCase()); + } + + @Override + public String toString() { + try { + return new String(bytes, "utf-8"); + } catch (UnsupportedEncodingException e) { + // Turn the exception into unchecked so we can find out about it at runtime, but + // don't need to add lots of boilerplate code everywhere. + PlatformDependent.throwException(e); + return "unknown"; // we will never reach here. + } + } + + @Override + public UTF8String clone() { + return new UTF8String().set(bytes); + } + + @Override + public int compareTo(final UTF8String other) { + final byte[] b = other.getBytes(); + for (int i = 0; i < bytes.length && i < b.length; i++) { + int res = bytes[i] - b[i]; + if (res != 0) { + return res; + } + } + return bytes.length - b.length; + } + + public int compare(final UTF8String other) { + return compareTo(other); + } + + @Override + public boolean equals(final Object other) { + if (other instanceof UTF8String) { + return Arrays.equals(bytes, ((UTF8String) other).getBytes()); + } else { + return false; + } + } + + @Override + public int hashCode() { + return Arrays.hashCode(bytes); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java index 18393db9f382f..a93fc0ee297c4 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java @@ -18,7 +18,6 @@ package org.apache.spark.unsafe.bitset; import junit.framework.Assert; -import org.apache.spark.unsafe.bitset.BitSet; import org.junit.Test; import org.apache.spark.unsafe.memory.MemoryBlock; diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java new file mode 100644 index 0000000000000..796cdc9dbebdb --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -0,0 +1,91 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.unsafe.types; + +import java.io.UnsupportedEncodingException; + +import junit.framework.Assert; +import org.junit.Test; + +public class UTF8StringSuite { + + private void checkBasic(String str, int len) throws UnsupportedEncodingException { + Assert.assertEquals(UTF8String.fromString(str).length(), len); + Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")).length(), len); + + Assert.assertEquals(UTF8String.fromString(str).toString(), str); + Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")).toString(), str); + Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")), UTF8String.fromString(str)); + + Assert.assertEquals(UTF8String.fromString(str).hashCode(), + UTF8String.fromBytes(str.getBytes("utf8")).hashCode()); + } + + @Test + public void basicTest() throws UnsupportedEncodingException { + checkBasic("hello", 5); + checkBasic("世 界", 3); + } + + @Test + public void contains() { + Assert.assertTrue(UTF8String.fromString("hello").contains(UTF8String.fromString("ello"))); + Assert.assertFalse(UTF8String.fromString("hello").contains(UTF8String.fromString("vello"))); + Assert.assertFalse(UTF8String.fromString("hello").contains(UTF8String.fromString("hellooo"))); + Assert.assertTrue(UTF8String.fromString("大千世界").contains(UTF8String.fromString("千世"))); + Assert.assertFalse(UTF8String.fromString("大千世界").contains(UTF8String.fromString("世千"))); + Assert.assertFalse( + UTF8String.fromString("大千世界").contains(UTF8String.fromString("大千世界好"))); + } + + @Test + public void startsWith() { + Assert.assertTrue(UTF8String.fromString("hello").startsWith(UTF8String.fromString("hell"))); + Assert.assertFalse(UTF8String.fromString("hello").startsWith(UTF8String.fromString("ell"))); + Assert.assertFalse(UTF8String.fromString("hello").startsWith(UTF8String.fromString("hellooo"))); + Assert.assertTrue(UTF8String.fromString("数据砖头").startsWith(UTF8String.fromString("数据"))); + Assert.assertFalse(UTF8String.fromString("大千世界").startsWith(UTF8String.fromString("千"))); + Assert.assertFalse( + UTF8String.fromString("大千世界").startsWith(UTF8String.fromString("大千世界好"))); + } + + @Test + public void endsWith() { + Assert.assertTrue(UTF8String.fromString("hello").endsWith(UTF8String.fromString("ello"))); + Assert.assertFalse(UTF8String.fromString("hello").endsWith(UTF8String.fromString("ellov"))); + Assert.assertFalse(UTF8String.fromString("hello").endsWith(UTF8String.fromString("hhhello"))); + Assert.assertTrue(UTF8String.fromString("大千世界").endsWith(UTF8String.fromString("世界"))); + Assert.assertFalse(UTF8String.fromString("大千世界").endsWith(UTF8String.fromString("世"))); + Assert.assertFalse( + UTF8String.fromString("数据砖头").endsWith(UTF8String.fromString("我的数据砖头"))); + } + + @Test + public void substring() { + Assert.assertEquals( + UTF8String.fromString("hello").substring(0, 0), UTF8String.fromString("")); + Assert.assertEquals( + UTF8String.fromString("hello").substring(1, 3), UTF8String.fromString("el")); + Assert.assertEquals( + UTF8String.fromString("数据砖头").substring(0, 1), UTF8String.fromString("数")); + Assert.assertEquals( + UTF8String.fromString("数据砖头").substring(1, 3), UTF8String.fromString("据砖")); + Assert.assertEquals( + UTF8String.fromString("数据砖头").substring(3, 5), UTF8String.fromString("头")); + } +} diff --git a/yarn/pom.xml b/yarn/pom.xml index 644def7501dc8..2aeed98285aa8 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -107,7 +107,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 002d7b6eaf498..83dafa4a125d2 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -32,7 +32,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv} import org.apache.spark.SparkException -import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ @@ -46,6 +46,14 @@ private[spark] class ApplicationMaster( client: YarnRMClient) extends Logging { + // Load the properties file with the Spark configuration and set entries as system properties, + // so that user code run inside the AM also has access to them. + if (args.propertiesFile != null) { + Utils.getPropertiesFromFile(args.propertiesFile).foreach { case (k, v) => + sys.props(k) = v + } + } + // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be // optimal as more containers are available. Might need to handle this better. @@ -490,9 +498,11 @@ private[spark] class ApplicationMaster( new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) } + var userArgs = args.userArgs if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) { - System.setProperty("spark.submit.pyFiles", - PythonRunner.formatPaths(args.pyFiles).mkString(",")) + // When running pyspark, the app is run using PythonRunner. The second argument is the list + // of files to add to PYTHONPATH, which Client.scala already handles, so it's empty. + userArgs = Seq(args.primaryPyFile, "") ++ userArgs } if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { // TODO(davies): add R dependencies here @@ -503,9 +513,7 @@ private[spark] class ApplicationMaster( val userThread = new Thread { override def run() { try { - val mainArgs = new Array[String](args.userArgs.size) - args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size) - mainMethod.invoke(null, mainArgs) + mainMethod.invoke(null, userArgs.toArray) finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) logDebug("Done running users class") } catch { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index ae6dc1094d724..68e9f6b4db7f4 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -26,11 +26,11 @@ class ApplicationMasterArguments(val args: Array[String]) { var userClass: String = null var primaryPyFile: String = null var primaryRFile: String = null - var pyFiles: String = null - var userArgs: Seq[String] = Seq[String]() + var userArgs: Seq[String] = Nil var executorMemory = 1024 var executorCores = 1 var numExecutors = DEFAULT_NUMBER_EXECUTORS + var propertiesFile: String = null parseArgs(args.toList) @@ -59,10 +59,6 @@ class ApplicationMasterArguments(val args: Array[String]) { primaryRFile = value args = tail - case ("--py-files") :: value :: tail => - pyFiles = value - args = tail - case ("--args" | "--arg") :: value :: tail => userArgsBuffer += value args = tail @@ -79,6 +75,10 @@ class ApplicationMasterArguments(val args: Array[String]) { executorCores = value args = tail + case ("--properties-file") :: value :: tail => + propertiesFile = value + args = tail + case _ => printUsageAndExit(1, args) } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 234051eb7d3bb..67a5c95400e53 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -17,18 +17,21 @@ package org.apache.spark.deploy.yarn -import java.io.{ByteArrayInputStream, DataInputStream, File, FileOutputStream, IOException} +import java.io.{ByteArrayInputStream, DataInputStream, File, FileOutputStream, IOException, + OutputStreamWriter} import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException} import java.nio.ByteBuffer import java.security.PrivilegedExceptionAction -import java.util.UUID +import java.util.{Properties, UUID} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map} import scala.reflect.runtime.universe import scala.util.{Try, Success, Failure} +import scala.util.control.NonFatal +import com.google.common.base.Charsets.UTF_8 import com.google.common.base.Objects import com.google.common.io.Files @@ -121,24 +124,31 @@ private[spark] class Client( } catch { case e: Throwable => if (appId != null) { - val appStagingDir = getAppStagingDir(appId) - try { - val preserveFiles = sparkConf.getBoolean("spark.yarn.preserve.staging.files", false) - val stagingDirPath = new Path(appStagingDir) - val fs = FileSystem.get(hadoopConf) - if (!preserveFiles && fs.exists(stagingDirPath)) { - logInfo("Deleting staging directory " + stagingDirPath) - fs.delete(stagingDirPath, true) - } - } catch { - case ioe: IOException => - logWarning("Failed to cleanup staging dir " + appStagingDir, ioe) - } + cleanupStagingDir(appId) } throw e } } + /** + * Cleanup application staging directory. + */ + private def cleanupStagingDir(appId: ApplicationId): Unit = { + val appStagingDir = getAppStagingDir(appId) + try { + val preserveFiles = sparkConf.getBoolean("spark.yarn.preserve.staging.files", false) + val stagingDirPath = new Path(appStagingDir) + val fs = FileSystem.get(hadoopConf) + if (!preserveFiles && fs.exists(stagingDirPath)) { + logInfo("Deleting staging directory " + stagingDirPath) + fs.delete(stagingDirPath, true) + } + } catch { + case ioe: IOException => + logWarning("Failed to cleanup staging dir " + appStagingDir, ioe) + } + } + /** * Set up the context for submitting our ApplicationMaster. * This uses the YarnClientApplication not available in the Yarn alpha API. @@ -240,7 +250,9 @@ private[spark] class Client( * This is used for setting up a container launch context for our ApplicationMaster. * Exposed for testing. */ - def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = { + def prepareLocalResources( + appStagingDir: String, + pySparkArchives: Seq[String]): HashMap[String, LocalResource] = { logInfo("Preparing resources for our AM container") // Upload Spark and the application JAR to the remote file system if necessary, // and add them as local resources to the application master. @@ -270,20 +282,6 @@ private[spark] class Client( "for alternatives.") } - // If we passed in a keytab, make sure we copy the keytab to the staging directory on - // HDFS, and setup the relevant environment vars, so the AM can login again. - if (loginFromKeytab) { - logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" + - " via the YARN Secure Distributed Cache.") - val localUri = new URI(args.keytab) - val localPath = getQualifiedLocalPath(localUri, hadoopConf) - val destinationPath = copyFileToRemote(dst, localPath, replication) - val destFs = FileSystem.get(destinationPath.toUri(), hadoopConf) - distCacheMgr.addResource( - destFs, hadoopConf, destinationPath, localResources, LocalResourceType.FILE, - sparkConf.get("spark.yarn.keytab"), statCache, appMasterOnly = true) - } - def addDistributedUri(uri: URI): Boolean = { val uriStr = uri.toString() if (distributedUris.contains(uriStr)) { @@ -295,6 +293,57 @@ private[spark] class Client( } } + /** + * Distribute a file to the cluster. + * + * If the file's path is a "local:" URI, it's actually not distributed. Other files are copied + * to HDFS (if not already there) and added to the application's distributed cache. + * + * @param path URI of the file to distribute. + * @param resType Type of resource being distributed. + * @param destName Name of the file in the distributed cache. + * @param targetDir Subdirectory where to place the file. + * @param appMasterOnly Whether to distribute only to the AM. + * @return A 2-tuple. First item is whether the file is a "local:" URI. Second item is the + * localized path for non-local paths, or the input `path` for local paths. + * The localized path will be null if the URI has already been added to the cache. + */ + def distribute( + path: String, + resType: LocalResourceType = LocalResourceType.FILE, + destName: Option[String] = None, + targetDir: Option[String] = None, + appMasterOnly: Boolean = false): (Boolean, String) = { + val localURI = new URI(path.trim()) + if (localURI.getScheme != LOCAL_SCHEME) { + if (addDistributedUri(localURI)) { + val localPath = getQualifiedLocalPath(localURI, hadoopConf) + val linkname = targetDir.map(_ + "/").getOrElse("") + + destName.orElse(Option(localURI.getFragment())).getOrElse(localPath.getName()) + val destPath = copyFileToRemote(dst, localPath, replication) + distCacheMgr.addResource( + fs, hadoopConf, destPath, localResources, resType, linkname, statCache, + appMasterOnly = appMasterOnly) + (false, linkname) + } else { + (false, null) + } + } else { + (true, path.trim()) + } + } + + // If we passed in a keytab, make sure we copy the keytab to the staging directory on + // HDFS, and setup the relevant environment vars, so the AM can login again. + if (loginFromKeytab) { + logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" + + " via the YARN Secure Distributed Cache.") + val (_, localizedPath) = distribute(args.keytab, + destName = Some(sparkConf.get("spark.yarn.keytab")), + appMasterOnly = true) + require(localizedPath != null, "Keytab file already distributed.") + } + /** * Copy the given main resource to the distributed cache if the scheme is not "local". * Otherwise, set the corresponding key in our SparkConf to handle it downstream. @@ -307,33 +356,18 @@ private[spark] class Client( (SPARK_JAR, sparkJar(sparkConf), CONF_SPARK_JAR), (APP_JAR, args.userJar, CONF_SPARK_USER_JAR), ("log4j.properties", oldLog4jConf.orNull, null) - ).foreach { case (destName, _localPath, confKey) => - val localPath: String = if (_localPath != null) _localPath.trim() else "" - if (!localPath.isEmpty()) { - val localURI = new URI(localPath) - if (localURI.getScheme != LOCAL_SCHEME) { - if (addDistributedUri(localURI)) { - val src = getQualifiedLocalPath(localURI, hadoopConf) - val destPath = copyFileToRemote(dst, src, replication) - val destFs = FileSystem.get(destPath.toUri(), hadoopConf) - distCacheMgr.addResource(destFs, hadoopConf, destPath, - localResources, LocalResourceType.FILE, destName, statCache) - } - } else if (confKey != null) { + ).foreach { case (destName, path, confKey) => + if (path != null && !path.trim().isEmpty()) { + val (isLocal, localizedPath) = distribute(path, destName = Some(destName)) + if (isLocal && confKey != null) { + require(localizedPath != null, s"Path $path already distributed.") // If the resource is intended for local use only, handle this downstream // by setting the appropriate property - sparkConf.set(confKey, localPath) + sparkConf.set(confKey, localizedPath) } } } - createConfArchive().foreach { file => - require(addDistributedUri(file.toURI())) - val destPath = copyFileToRemote(dst, new Path(file.toURI()), replication) - distCacheMgr.addResource(fs, hadoopConf, destPath, localResources, LocalResourceType.ARCHIVE, - LOCALIZED_HADOOP_CONF_DIR, statCache, appMasterOnly = true) - } - /** * Do the same for any additional resources passed in through ClientArguments. * Each resource category is represented by a 3-tuple of: @@ -349,21 +383,10 @@ private[spark] class Client( ).foreach { case (flist, resType, addToClasspath) => if (flist != null && !flist.isEmpty()) { flist.split(',').foreach { file => - val localURI = new URI(file.trim()) - if (localURI.getScheme != LOCAL_SCHEME) { - if (addDistributedUri(localURI)) { - val localPath = new Path(localURI) - val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) - val destPath = copyFileToRemote(dst, localPath, replication) - distCacheMgr.addResource( - fs, hadoopConf, destPath, localResources, resType, linkname, statCache) - if (addToClasspath) { - cachedSecondaryJarLinks += linkname - } - } - } else if (addToClasspath) { - // Resource is intended for local use only and should be added to the class path - cachedSecondaryJarLinks += file.trim() + val (_, localizedPath) = distribute(file, resType = resType) + require(localizedPath != null) + if (addToClasspath) { + cachedSecondaryJarLinks += localizedPath } } } @@ -372,11 +395,31 @@ private[spark] class Client( sparkConf.set(CONF_SPARK_YARN_SECONDARY_JARS, cachedSecondaryJarLinks.mkString(",")) } + if (isClusterMode && args.primaryPyFile != null) { + distribute(args.primaryPyFile, appMasterOnly = true) + } + + pySparkArchives.foreach { f => distribute(f) } + + // The python files list needs to be treated especially. All files that are not an + // archive need to be placed in a subdirectory that will be added to PYTHONPATH. + args.pyFiles.foreach { f => + val targetDir = if (f.endsWith(".py")) Some(LOCALIZED_PYTHON_DIR) else None + distribute(f, targetDir = targetDir) + } + + // Distribute an archive with Hadoop and Spark configuration for the AM. + val (_, confLocalizedPath) = distribute(createConfArchive().getAbsolutePath(), + resType = LocalResourceType.ARCHIVE, + destName = Some(LOCALIZED_CONF_DIR), + appMasterOnly = true) + require(confLocalizedPath != null) + localResources } /** - * Create an archive with the Hadoop config files for distribution. + * Create an archive with the config files for distribution. * * These are only used by the AM, since executors will use the configuration object broadcast by * the driver. The files are zipped and added to the job as an archive, so that YARN will explode @@ -388,8 +431,11 @@ private[spark] class Client( * * Currently this makes a shallow copy of the conf directory. If there are cases where a * Hadoop config directory contains subdirectories, this code will have to be fixed. + * + * The archive also contains some Spark configuration. Namely, it saves the contents of + * SparkConf in a file to be loaded by the AM process. */ - private def createConfArchive(): Option[File] = { + private def createConfArchive(): File = { val hadoopConfFiles = new HashMap[String, File]() Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey => sys.env.get(envKey).foreach { path => @@ -404,28 +450,32 @@ private[spark] class Client( } } - if (!hadoopConfFiles.isEmpty) { - val hadoopConfArchive = File.createTempFile(LOCALIZED_HADOOP_CONF_DIR, ".zip", - new File(Utils.getLocalDir(sparkConf))) + val confArchive = File.createTempFile(LOCALIZED_CONF_DIR, ".zip", + new File(Utils.getLocalDir(sparkConf))) + val confStream = new ZipOutputStream(new FileOutputStream(confArchive)) - val hadoopConfStream = new ZipOutputStream(new FileOutputStream(hadoopConfArchive)) - try { - hadoopConfStream.setLevel(0) - hadoopConfFiles.foreach { case (name, file) => - if (file.canRead()) { - hadoopConfStream.putNextEntry(new ZipEntry(name)) - Files.copy(file, hadoopConfStream) - hadoopConfStream.closeEntry() - } + try { + confStream.setLevel(0) + hadoopConfFiles.foreach { case (name, file) => + if (file.canRead()) { + confStream.putNextEntry(new ZipEntry(name)) + Files.copy(file, confStream) + confStream.closeEntry() } - } finally { - hadoopConfStream.close() } - Some(hadoopConfArchive) - } else { - None + // Save Spark configuration to a file in the archive. + val props = new Properties() + sparkConf.getAll.foreach { case (k, v) => props.setProperty(k, v) } + confStream.putNextEntry(new ZipEntry(SPARK_CONF_FILE)) + val writer = new OutputStreamWriter(confStream, UTF_8) + props.store(writer, "Spark configuration.") + writer.flush() + confStream.closeEntry() + } finally { + confStream.close() } + confArchive } /** @@ -453,7 +503,9 @@ private[spark] class Client( /** * Set up the environment for launching our ApplicationMaster container. */ - private def setupLaunchEnv(stagingDir: String): HashMap[String, String] = { + private def setupLaunchEnv( + stagingDir: String, + pySparkArchives: Seq[String]): HashMap[String, String] = { logInfo("Setting up the launch environment for our AM container") val env = new HashMap[String, String]() val extraCp = sparkConf.getOption("spark.driver.extraClassPath") @@ -471,9 +523,6 @@ private[spark] class Client( val renewalInterval = getTokenRenewalInterval(stagingDirPath) sparkConf.set("spark.yarn.token.renewal.interval", renewalInterval.toString) } - // Set the environment variables to be passed on to the executors. - distCacheMgr.setDistFilesEnv(env) - distCacheMgr.setDistArchivesEnv(env) // Pick up any environment variables for the AM provided through spark.yarn.appMasterEnv.* val amEnvPrefix = "spark.yarn.appMasterEnv." @@ -490,15 +539,32 @@ private[spark] class Client( env("SPARK_YARN_USER_ENV") = userEnvs } - // if spark.submit.pyArchives is in sparkConf, append pyArchives to PYTHONPATH - // that can be passed on to the ApplicationMaster and the executors. - if (sparkConf.contains("spark.submit.pyArchives")) { - var pythonPath = sparkConf.get("spark.submit.pyArchives") - if (env.contains("PYTHONPATH")) { - pythonPath = Seq(env.get("PYTHONPATH"), pythonPath).mkString(File.pathSeparator) + // If pyFiles contains any .py files, we need to add LOCALIZED_PYTHON_DIR to the PYTHONPATH + // of the container processes too. Add all non-.py files directly to PYTHONPATH. + // + // NOTE: the code currently does not handle .py files defined with a "local:" scheme. + val pythonPath = new ListBuffer[String]() + val (pyFiles, pyArchives) = args.pyFiles.partition(_.endsWith(".py")) + if (pyFiles.nonEmpty) { + pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), + LOCALIZED_PYTHON_DIR) + } + (pySparkArchives ++ pyArchives).foreach { path => + val uri = new URI(path) + if (uri.getScheme != LOCAL_SCHEME) { + pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), + new Path(path).getName()) + } else { + pythonPath += uri.getPath() } - env("PYTHONPATH") = pythonPath - sparkConf.setExecutorEnv("PYTHONPATH", pythonPath) + } + + // Finally, update the Spark config to propagate PYTHONPATH to the AM and executors. + if (pythonPath.nonEmpty) { + val pythonPathStr = (sys.env.get("PYTHONPATH") ++ pythonPath) + .mkString(YarnSparkHadoopUtil.getClassPathSeparator) + env("PYTHONPATH") = pythonPathStr + sparkConf.setExecutorEnv("PYTHONPATH", pythonPathStr) } // In cluster mode, if the deprecated SPARK_JAVA_OPTS is set, we need to propagate it to @@ -548,8 +614,19 @@ private[spark] class Client( logInfo("Setting up container launch context for our AM") val appId = newAppResponse.getApplicationId val appStagingDir = getAppStagingDir(appId) - val localResources = prepareLocalResources(appStagingDir) - val launchEnv = setupLaunchEnv(appStagingDir) + val pySparkArchives = + if (sys.props.getOrElse("spark.yarn.isPython", "false").toBoolean) { + findPySparkArchives() + } else { + Nil + } + val launchEnv = setupLaunchEnv(appStagingDir, pySparkArchives) + val localResources = prepareLocalResources(appStagingDir, pySparkArchives) + + // Set the environment variables to be passed on to the executors. + distCacheMgr.setDistFilesEnv(launchEnv) + distCacheMgr.setDistArchivesEnv(launchEnv) + val amContainer = Records.newRecord(classOf[ContainerLaunchContext]) amContainer.setLocalResources(localResources) amContainer.setEnvironment(launchEnv) @@ -589,13 +666,6 @@ private[spark] class Client( javaOpts += "-XX:CMSIncrementalDutyCycle=10" } - // Forward the Spark configuration to the application master / executors. - // TODO: it might be nicer to pass these as an internal environment variable rather than - // as Java options, due to complications with string parsing of nested quotes. - for ((k, v) <- sparkConf.getAll) { - javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") - } - // Include driver-specific java options if we are launching a driver if (isClusterMode) { val driverOpts = sparkConf.getOption("spark.driver.extraJavaOptions") @@ -606,7 +676,7 @@ private[spark] class Client( val libraryPaths = Seq(sys.props.get("spark.driver.extraLibraryPath"), sys.props.get("spark.driver.libraryPath")).flatten if (libraryPaths.nonEmpty) { - prefixEnv = Some(Utils.libraryPathEnvPrefix(libraryPaths)) + prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(libraryPaths))) } if (sparkConf.getOption("spark.yarn.am.extraJavaOptions").isDefined) { logWarning("spark.yarn.am.extraJavaOptions will not take effect in cluster mode") @@ -628,7 +698,7 @@ private[spark] class Client( } sparkConf.getOption("spark.yarn.am.extraLibraryPath").foreach { paths => - prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(paths))) + prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths)))) } } @@ -648,14 +718,8 @@ private[spark] class Client( Nil } val primaryPyFile = - if (args.primaryPyFile != null) { - Seq("--primary-py-file", args.primaryPyFile) - } else { - Nil - } - val pyFiles = - if (args.pyFiles != null) { - Seq("--py-files", args.pyFiles) + if (isClusterMode && args.primaryPyFile != null) { + Seq("--primary-py-file", new Path(args.primaryPyFile).getName()) } else { Nil } @@ -671,9 +735,6 @@ private[spark] class Client( } else { Class.forName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName } - if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) { - args.userArgs = ArrayBuffer(args.primaryPyFile, args.pyFiles) ++ args.userArgs - } if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { args.userArgs = ArrayBuffer(args.primaryRFile) ++ args.userArgs } @@ -681,11 +742,13 @@ private[spark] class Client( Seq("--arg", YarnSparkHadoopUtil.escapeForShell(arg)) } val amArgs = - Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ pyFiles ++ primaryRFile ++ + Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ primaryRFile ++ userArgs ++ Seq( "--executor-memory", args.executorMemory.toString + "m", "--executor-cores", args.executorCores.toString, - "--num-executors ", args.numExecutors.toString) + "--num-executors ", args.numExecutors.toString, + "--properties-file", buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), + LOCALIZED_CONF_DIR, SPARK_CONF_FILE)) // Command for the ApplicationMaster val commands = prefixEnv ++ Seq( @@ -764,6 +827,9 @@ private[spark] class Client( case e: ApplicationNotFoundException => logError(s"Application $appId not found.") return (YarnApplicationState.KILLED, FinalApplicationStatus.KILLED) + case NonFatal(e) => + logError(s"Failed to contact YARN for application $appId.", e) + return (YarnApplicationState.FAILED, FinalApplicationStatus.FAILED) } val state = report.getYarnApplicationState @@ -782,6 +848,7 @@ private[spark] class Client( if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) { + cleanupStagingDir(appId) return (state, report.getFinalApplicationStatus) } @@ -849,6 +916,22 @@ private[spark] class Client( } } } + + private def findPySparkArchives(): Seq[String] = { + sys.env.get("PYSPARK_ARCHIVES_PATH") + .map(_.split(",").toSeq) + .getOrElse { + val pyLibPath = Seq(sys.env("SPARK_HOME"), "python", "lib").mkString(File.separator) + val pyArchivesFile = new File(pyLibPath, "pyspark.zip") + require(pyArchivesFile.exists(), + "pyspark.zip not found; cannot run pyspark application in YARN mode.") + val py4jFile = new File(pyLibPath, "py4j-0.8.2.1-src.zip") + require(py4jFile.exists(), + "py4j-0.8.2.1-src.zip not found; cannot run pyspark application in YARN mode.") + Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath()) + } + } + } object Client extends Logging { @@ -899,8 +982,14 @@ object Client extends Logging { // Distribution-defined classpath to add to processes val ENV_DIST_CLASSPATH = "SPARK_DIST_CLASSPATH" - // Subdirectory where the user's hadoop config files will be placed. - val LOCALIZED_HADOOP_CONF_DIR = "__hadoop_conf__" + // Subdirectory where the user's Spark and Hadoop config files will be placed. + val LOCALIZED_CONF_DIR = "__spark_conf__" + + // Name of the file in the conf archive containing Spark configuration. + val SPARK_CONF_FILE = "__spark_conf__.properties" + + // Subdirectory where the user's python files (not archives) will be placed. + val LOCALIZED_PYTHON_DIR = "__pyfiles__" /** * Find the user-defined Spark jar if configured, or return the jar containing this @@ -1017,15 +1106,15 @@ object Client extends Logging { env: HashMap[String, String], isAM: Boolean, extraClassPath: Option[String] = None): Unit = { - extraClassPath.foreach(addClasspathEntry(_, env)) - addClasspathEntry( - YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env - ) + extraClassPath.foreach { cp => + addClasspathEntry(getClusterPath(sparkConf, cp), env) + } + addClasspathEntry(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env) if (isAM) { addClasspathEntry( YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR + - LOCALIZED_HADOOP_CONF_DIR, env) + LOCALIZED_CONF_DIR, env) } if (sparkConf.getBoolean("spark.yarn.user.classpath.first", false)) { @@ -1036,12 +1125,14 @@ object Client extends Logging { getUserClasspath(sparkConf) } userClassPath.foreach { x => - addFileToClasspath(x, null, env) + addFileToClasspath(sparkConf, x, null, env) } } - addFileToClasspath(new URI(sparkJar(sparkConf)), SPARK_JAR, env) + addFileToClasspath(sparkConf, new URI(sparkJar(sparkConf)), SPARK_JAR, env) populateHadoopClasspath(conf, env) - sys.env.get(ENV_DIST_CLASSPATH).foreach(addClasspathEntry(_, env)) + sys.env.get(ENV_DIST_CLASSPATH).foreach { cp => + addClasspathEntry(getClusterPath(sparkConf, cp), env) + } } /** @@ -1070,16 +1161,18 @@ object Client extends Logging { * * If not a "local:" file and no alternate name, the environment is not modified. * + * @parma conf Spark configuration. * @param uri URI to add to classpath (optional). * @param fileName Alternate name for the file (optional). * @param env Map holding the environment variables. */ private def addFileToClasspath( + conf: SparkConf, uri: URI, fileName: String, env: HashMap[String, String]): Unit = { if (uri != null && uri.getScheme == LOCAL_SCHEME) { - addClasspathEntry(uri.getPath, env) + addClasspathEntry(getClusterPath(conf, uri.getPath), env) } else if (fileName != null) { addClasspathEntry(buildPath( YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), fileName), env) @@ -1093,6 +1186,29 @@ object Client extends Logging { private def addClasspathEntry(path: String, env: HashMap[String, String]): Unit = YarnSparkHadoopUtil.addPathToEnvironment(env, Environment.CLASSPATH.name, path) + /** + * Returns the path to be sent to the NM for a path that is valid on the gateway. + * + * This method uses two configuration values: + * + * - spark.yarn.config.gatewayPath: a string that identifies a portion of the input path that may + * only be valid in the gateway node. + * - spark.yarn.config.replacementPath: a string with which to replace the gateway path. This may + * contain, for example, env variable references, which will be expanded by the NMs when + * starting containers. + * + * If either config is not available, the input path is returned. + */ + def getClusterPath(conf: SparkConf, path: String): String = { + val localPath = conf.get("spark.yarn.config.gatewayPath", null) + val clusterPath = conf.get("spark.yarn.config.replacementPath", null) + if (localPath != null && clusterPath != null) { + path.replace(localPath, clusterPath) + } else { + path + } + } + /** * Obtains token for the Hive metastore and adds them to the credentials. */ diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 9c7b1b3988082..35e990602a6cf 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -30,7 +30,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var archives: String = null var userJar: String = null var userClass: String = null - var pyFiles: String = null + var pyFiles: Seq[String] = Nil var primaryPyFile: String = null var primaryRFile: String = null var userArgs: ArrayBuffer[String] = new ArrayBuffer[String]() @@ -228,7 +228,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) args = tail case ("--py-files") :: value :: tail => - pyFiles = value + pyFiles = value.split(",") args = tail case ("--files") :: value :: tail => diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 9d04d241dae9e..78e27fb7f3337 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -146,7 +146,7 @@ class ExecutorRunnable( javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } sys.props.get("spark.executor.extraLibraryPath").foreach { p => - prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(p))) + prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p)))) } javaOpts += "-Djava.io.tmpdir=" + @@ -195,7 +195,7 @@ class ExecutorRunnable( val userClassPath = Client.getUserClasspath(sparkConf).flatMap { uri => val absPath = if (new File(uri.getPath()).isAbsolute()) { - uri.getPath() + Client.getClusterPath(sparkConf, uri.getPath()) } else { Client.buildPath(Environment.PWD.$(), uri.getPath()) } @@ -303,8 +303,8 @@ class ExecutorRunnable( val address = container.getNodeHttpAddress val baseUrl = s"$httpScheme$address/node/containerlogs/$containerId/$user" - env("SPARK_LOG_URL_STDERR") = s"$baseUrl/stderr?start=0" - env("SPARK_LOG_URL_STDOUT") = s"$baseUrl/stdout?start=0" + env("SPARK_LOG_URL_STDERR") = s"$baseUrl/stderr?start=-4096" + env("SPARK_LOG_URL_STDOUT") = s"$baseUrl/stdout?start=-4096" } System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k, v) => env(k) = v } diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 99c05329b4d73..1c8d7ec57635f 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -76,7 +76,8 @@ private[spark] class YarnClientSchedulerBackend( ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"), ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"), ("--executor-cores", "SPARK_EXECUTOR_CORES", "spark.executor.cores"), - ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue") + ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue"), + ("--py-files", null, "spark.submit.pyFiles") ) // Warn against the following deprecated environment variables: env var -> suggestion val deprecatedEnvVars = Map( @@ -86,7 +87,7 @@ private[spark] class YarnClientSchedulerBackend( optionTuples.foreach { case (optionName, envVar, sparkProp) => if (sc.getConf.contains(sparkProp)) { extraArgs += (optionName, sc.getConf.get(sparkProp)) - } else if (System.getenv(envVar) != null) { + } else if (envVar != null && System.getenv(envVar) != null) { extraArgs += (optionName, System.getenv(envVar)) if (deprecatedEnvVars.contains(envVar)) { logWarning(s"NOTE: $envVar is deprecated. Use ${deprecatedEnvVars(envVar)} instead.") diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 1ace1a97d5156..33f580aaebdc0 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -115,8 +115,9 @@ private[spark] class YarnClusterSchedulerBackend( val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" val baseUrl = s"$httpScheme$httpAddress/node/containerlogs/$containerId/$user" logDebug(s"Base URL for logs: $baseUrl") - driverLogs = Some( - Map("stderr" -> s"$baseUrl/stderr?start=0", "stdout" -> s"$baseUrl/stdout?start=0")) + driverLogs = Some(Map( + "stderr" -> s"$baseUrl/stderr?start=-4096", + "stdout" -> s"$baseUrl/stdout?start=-4096")) } } } catch { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 01d33c9ce9297..837f8d3fa55a7 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -113,7 +113,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { Environment.PWD.$() } cp should contain(pwdVar) - cp should contain (s"$pwdVar${Path.SEPARATOR}${Client.LOCALIZED_HADOOP_CONF_DIR}") + cp should contain (s"$pwdVar${Path.SEPARATOR}${Client.LOCALIZED_CONF_DIR}") cp should not contain (Client.SPARK_JAR) cp should not contain (Client.APP_JAR) } @@ -129,7 +129,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { val tempDir = Utils.createTempDir() try { - client.prepareLocalResources(tempDir.getAbsolutePath()) + client.prepareLocalResources(tempDir.getAbsolutePath(), Nil) sparkConf.getOption(Client.CONF_SPARK_USER_JAR) should be (Some(USER)) // The non-local path should be propagated by name only, since it will end up in the app's @@ -151,6 +151,25 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { } } + test("Cluster path translation") { + val conf = new Configuration() + val sparkConf = new SparkConf() + .set(Client.CONF_SPARK_JAR, "local:/localPath/spark.jar") + .set("spark.yarn.config.gatewayPath", "/localPath") + .set("spark.yarn.config.replacementPath", "/remotePath") + + Client.getClusterPath(sparkConf, "/localPath") should be ("/remotePath") + Client.getClusterPath(sparkConf, "/localPath/1:/localPath/2") should be ( + "/remotePath/1:/remotePath/2") + + val env = new MutableHashMap[String, String]() + Client.populateClasspath(null, conf, sparkConf, env, false, + extraClassPath = Some("/localPath/my1.jar")) + val cp = classpath(env) + cp should contain ("/remotePath/spark.jar") + cp should contain ("/remotePath/my1.jar") + } + object Fixtures { val knownDefYarnAppCP: Seq[String] = diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 93d587d0cb36a..335e966519c7c 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -56,6 +56,7 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher """.stripMargin private val TEST_PYFILE = """ + |import mod1, mod2 |import sys |from operator import add | @@ -67,7 +68,7 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher | sc = SparkContext(conf=SparkConf()) | status = open(sys.argv[1],'w') | result = "failure" - | rdd = sc.parallelize(range(10)) + | rdd = sc.parallelize(range(10)).map(lambda x: x * mod1.func() * mod2.func()) | cnt = rdd.count() | if cnt == 10: | result = "success" @@ -76,6 +77,11 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher | sc.stop() """.stripMargin + private val TEST_PYMODULE = """ + |def func(): + | return 42 + """.stripMargin + private var yarnCluster: MiniYARNCluster = _ private var tempDir: File = _ private var fakeSparkJar: File = _ @@ -124,7 +130,7 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) - hadoopConfDir = new File(tempDir, Client.LOCALIZED_HADOOP_CONF_DIR) + hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR) assert(hadoopConfDir.mkdir()) File.createTempFile("token", ".txt", hadoopConfDir) } @@ -151,26 +157,12 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher } } - // Enable this once fix SPARK-6700 - test("run Python application in yarn-cluster mode") { - val primaryPyFile = new File(tempDir, "test.py") - Files.write(TEST_PYFILE, primaryPyFile, UTF_8) - val pyFile = new File(tempDir, "test2.py") - Files.write(TEST_PYFILE, pyFile, UTF_8) - var result = File.createTempFile("result", null, tempDir) + test("run Python application in yarn-client mode") { + testPySpark(true) + } - // The sbt assembly does not include pyspark / py4j python dependencies, so we need to - // propagate SPARK_HOME so that those are added to PYTHONPATH. See PythonUtils.scala. - val sparkHome = sys.props("spark.test.home") - val extraConf = Map( - "spark.executorEnv.SPARK_HOME" -> sparkHome, - "spark.yarn.appMasterEnv.SPARK_HOME" -> sparkHome) - - runSpark(false, primaryPyFile.getAbsolutePath(), - sparkArgs = Seq("--py-files", pyFile.getAbsolutePath()), - appArgs = Seq(result.getAbsolutePath()), - extraConf = extraConf) - checkResult(result) + test("run Python application in yarn-cluster mode") { + testPySpark(false) } test("user class path first in client mode") { @@ -188,6 +180,33 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher checkResult(result) } + private def testPySpark(clientMode: Boolean): Unit = { + val primaryPyFile = new File(tempDir, "test.py") + Files.write(TEST_PYFILE, primaryPyFile, UTF_8) + + val moduleDir = + if (clientMode) { + // In client-mode, .py files added with --py-files are not visible in the driver. + // This is something that the launcher library would have to handle. + tempDir + } else { + val subdir = new File(tempDir, "pyModules") + subdir.mkdir() + subdir + } + val pyModule = new File(moduleDir, "mod1.py") + Files.write(TEST_PYMODULE, pyModule, UTF_8) + + val mod2Archive = TestUtils.createJarWithFiles(Map("mod2.py" -> TEST_PYMODULE), moduleDir) + val pyFiles = Seq(pyModule.getAbsolutePath(), mod2Archive.getPath()).mkString(",") + val result = File.createTempFile("result", null, tempDir) + + runSpark(clientMode, primaryPyFile.getAbsolutePath(), + sparkArgs = Seq("--py-files", pyFiles), + appArgs = Seq(result.getAbsolutePath())) + checkResult(result) + } + private def testUseClassPathFirst(clientMode: Boolean): Unit = { // Create a jar file that contains a different version of "test.resource". val originalJar = TestUtils.createJarWithFiles(Map("test.resource" -> "ORIGINAL"), tempDir) @@ -357,7 +376,7 @@ private object YarnClusterDriver extends Logging with Matchers { new URL(urlStr) val containerId = YarnSparkHadoopUtil.get.getContainerId val user = Utils.getCurrentUserName() - assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=0")) + assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=-4096")) } }