diff --git a/R/pkg/.lintr b/R/pkg/.lintr
index b10ebd35c4ca7..038236fc149e6 100644
--- a/R/pkg/.lintr
+++ b/R/pkg/.lintr
@@ -1,2 +1,2 @@
-linters: with_defaults(line_length_linter(100), camel_case_linter = NULL)
+linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE))
exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R")
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 0af5cb8881e35..6feabf4189c2d 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -38,7 +38,7 @@ setClass("DataFrame",
setMethod("initialize", "DataFrame", function(.Object, sdf, isCached) {
.Object@env <- new.env()
.Object@env$isCached <- isCached
-
+
.Object@sdf <- sdf
.Object
})
@@ -55,11 +55,11 @@ dataFrame <- function(sdf, isCached = FALSE) {
############################ DataFrame Methods ##############################################
#' Print Schema of a DataFrame
-#'
+#'
#' Prints out the schema in tree format
-#'
+#'
#' @param x A SparkSQL DataFrame
-#'
+#'
#' @rdname printSchema
#' @export
#' @examples
@@ -78,11 +78,11 @@ setMethod("printSchema",
})
#' Get schema object
-#'
+#'
#' Returns the schema of this DataFrame as a structType object.
-#'
+#'
#' @param x A SparkSQL DataFrame
-#'
+#'
#' @rdname schema
#' @export
#' @examples
@@ -100,9 +100,9 @@ setMethod("schema",
})
#' Explain
-#'
+#'
#' Print the logical and physical Catalyst plans to the console for debugging.
-#'
+#'
#' @param x A SparkSQL DataFrame
#' @param extended Logical. If extended is False, explain() only prints the physical plan.
#' @rdname explain
@@ -200,11 +200,11 @@ setMethod("show", "DataFrame",
})
#' DataTypes
-#'
+#'
#' Return all column names and their data types as a list
-#'
+#'
#' @param x A SparkSQL DataFrame
-#'
+#'
#' @rdname dtypes
#' @export
#' @examples
@@ -224,11 +224,11 @@ setMethod("dtypes",
})
#' Column names
-#'
+#'
#' Return all column names as a list
-#'
+#'
#' @param x A SparkSQL DataFrame
-#'
+#'
#' @rdname columns
#' @export
#' @examples
@@ -256,12 +256,12 @@ setMethod("names",
})
#' Register Temporary Table
-#'
+#'
#' Registers a DataFrame as a Temporary Table in the SQLContext
-#'
+#'
#' @param x A SparkSQL DataFrame
#' @param tableName A character vector containing the name of the table
-#'
+#'
#' @rdname registerTempTable
#' @export
#' @examples
@@ -306,11 +306,11 @@ setMethod("insertInto",
})
#' Cache
-#'
+#'
#' Persist with the default storage level (MEMORY_ONLY).
-#'
+#'
#' @param x A SparkSQL DataFrame
-#'
+#'
#' @rdname cache-methods
#' @export
#' @examples
@@ -400,7 +400,7 @@ setMethod("repartition",
signature(x = "DataFrame", numPartitions = "numeric"),
function(x, numPartitions) {
sdf <- callJMethod(x@sdf, "repartition", numToInt(numPartitions))
- dataFrame(sdf)
+ dataFrame(sdf)
})
# toJSON
@@ -489,7 +489,7 @@ setMethod("distinct",
#' sqlContext <- sparkRSQL.init(sc)
#' path <- "path/to/file.json"
#' df <- jsonFile(sqlContext, path)
-#' collect(sample(df, FALSE, 0.5))
+#' collect(sample(df, FALSE, 0.5))
#' collect(sample(df, TRUE, 0.5))
#'}
setMethod("sample",
@@ -513,11 +513,11 @@ setMethod("sample_frac",
})
#' Count
-#'
+#'
#' Returns the number of rows in a DataFrame
-#'
+#'
#' @param x A SparkSQL DataFrame
-#'
+#'
#' @rdname count
#' @export
#' @examples
@@ -568,13 +568,13 @@ setMethod("collect",
})
#' Limit
-#'
+#'
#' Limit the resulting DataFrame to the number of rows specified.
-#'
+#'
#' @param x A SparkSQL DataFrame
#' @param num The number of rows to return
#' @return A new DataFrame containing the number of rows specified.
-#'
+#'
#' @rdname limit
#' @export
#' @examples
@@ -593,7 +593,7 @@ setMethod("limit",
})
 #' Take the first NUM rows of a DataFrame and return the results as a data.frame
-#'
+#'
#' @rdname take
#' @export
#' @examples
@@ -613,8 +613,8 @@ setMethod("take",
#' Head
#'
-#' Return the first NUM rows of a DataFrame as a data.frame. If NUM is NULL,
-#' then head() returns the first 6 rows in keeping with the current data.frame
+#' Return the first NUM rows of a DataFrame as a data.frame. If NUM is NULL,
+#' then head() returns the first 6 rows in keeping with the current data.frame
#' convention in R.
#'
#' @param x A SparkSQL DataFrame
@@ -659,11 +659,11 @@ setMethod("first",
})
# toRDD()
-#
+#
# Converts a Spark DataFrame to an RDD while preserving column names.
-#
+#
# @param x A Spark DataFrame
-#
+#
# @rdname DataFrame
# @export
# @examples
@@ -1167,7 +1167,7 @@ setMethod("where",
#'
#' @param x A Spark DataFrame
#' @param y A Spark DataFrame
-#' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a
+#' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a
 #'           Column expression. If joinExpr is omitted, join() will perform a Cartesian join
#' @param joinType The type of join to perform. The following join types are available:
#' 'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'. The default joinType is "inner".
@@ -1303,7 +1303,7 @@ setMethod("except",
#' @param source A name for external data source
#' @param mode One of 'append', 'overwrite', 'error', 'ignore'
#'
-#' @rdname write.df
+#' @rdname write.df
#' @export
#' @examples
#'\dontrun{
@@ -1401,7 +1401,7 @@ setMethod("saveAsTable",
#' @param col A string of name
#' @param ... Additional expressions
#' @return A DataFrame
-#' @rdname describe
+#' @rdname describe
#' @export
#' @examples
#'\dontrun{
@@ -1444,7 +1444,7 @@ setMethod("describe",
#' This overwrites the how parameter.
#' @param cols Optional list of column names to consider.
#' @return A DataFrame
-#'
+#'
#' @rdname nafunctions
#' @export
#' @examples
@@ -1465,7 +1465,7 @@ setMethod("dropna",
if (is.null(minNonNulls)) {
minNonNulls <- if (how == "any") { length(cols) } else { 1 }
}
-
+
naFunctions <- callJMethod(x@sdf, "na")
sdf <- callJMethod(naFunctions, "drop",
as.integer(minNonNulls), listToSeq(as.list(cols)))
@@ -1488,16 +1488,16 @@ setMethod("na.omit",
#' @param value Value to replace null values with.
#' Should be an integer, numeric, character or named list.
#' If the value is a named list, then cols is ignored and
-#' value must be a mapping from column name (character) to
+#' value must be a mapping from column name (character) to
#' replacement value. The replacement value must be an
#' integer, numeric or character.
#' @param cols optional list of column names to consider.
#' Columns specified in cols that do not have matching data
-#' type are ignored. For example, if value is a character, and
+#' type are ignored. For example, if value is a character, and
#' subset contains a non-character column, then the non-character
#' column is simply ignored.
#' @return A DataFrame
-#'
+#'
#' @rdname nafunctions
#' @export
#' @examples
@@ -1515,14 +1515,14 @@ setMethod("fillna",
if (!(class(value) %in% c("integer", "numeric", "character", "list"))) {
stop("value should be an integer, numeric, charactor or named list.")
}
-
+
if (class(value) == "list") {
# Check column names in the named list
colNames <- names(value)
if (length(colNames) == 0 || !all(colNames != "")) {
stop("value should be an a named list with each name being a column name.")
}
-
+
# Convert to the named list to an environment to be passed to JVM
valueMap <- new.env()
for (col in colNames) {
@@ -1533,19 +1533,19 @@ setMethod("fillna",
}
valueMap[[col]] <- v
}
-
+
# When value is a named list, caller is expected not to pass in cols
if (!is.null(cols)) {
warning("When value is a named list, cols is ignored!")
cols <- NULL
}
-
+
value <- valueMap
} else if (is.integer(value)) {
# Cast an integer to a numeric
value <- as.numeric(value)
}
-
+
naFunctions <- callJMethod(x@sdf, "na")
sdf <- if (length(cols) == 0) {
callJMethod(naFunctions, "fill", value)
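
Note on the fillna() change above: when value is a named list, each name is a column and each element is that column's replacement value, and the list is converted to a JVM map before calling DataFrameNaFunctions.fill. For comparison, the Python DataFrame API exposes the same per-column semantics; a minimal sketch, assuming a running SQLContext named sqlContext (the rows below are illustrative only):

```python
from pyspark.sql import Row

# Assumes an existing SQLContext `sqlContext`; the data is made up for illustration.
df = sqlContext.createDataFrame([
    Row(age=30, height=176.5, name="Michael"),
    Row(age=None, height=None, name=None),
])

# Per-column replacements, the analogue of
# fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown")) in SparkR.
filled = df.na.fill({"age": 50, "height": 50.6, "name": "unknown"})
filled.show()

# A single scalar instead fills every column of a compatible type.
df.na.fill(50).show()
```
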
diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R
index 0513299515644..89511141d3ef7 100644
--- a/R/pkg/R/RDD.R
+++ b/R/pkg/R/RDD.R
@@ -48,7 +48,7 @@ setMethod("initialize", "RDD", function(.Object, jrdd, serializedMode,
# byte: The RDD stores data serialized in R.
# string: The RDD stores data as strings.
# row: The RDD stores the serialized rows of a DataFrame.
-
+
# We use an environment to store mutable states inside an RDD object.
# Note that R's call-by-value semantics makes modifying slots inside an
# object (passed as an argument into a function, such as cache()) difficult:
@@ -363,7 +363,7 @@ setMethod("collectPartition",
# @description
# \code{collectAsMap} returns a named list as a map that contains all of the elements
-# in a key-value pair RDD.
+# in a key-value pair RDD.
# @examples
#\dontrun{
# sc <- sparkR.init()
@@ -666,7 +666,7 @@ setMethod("minimum",
# rdd <- parallelize(sc, 1:10)
# sumRDD(rdd) # 55
#}
-# @rdname sumRDD
+# @rdname sumRDD
# @aliases sumRDD,RDD
setMethod("sumRDD",
signature(x = "RDD"),
@@ -1090,11 +1090,11 @@ setMethod("sortBy",
# Return:
# A list of the first N elements from the RDD in the specified order.
#
-takeOrderedElem <- function(x, num, ascending = TRUE) {
+takeOrderedElem <- function(x, num, ascending = TRUE) {
if (num <= 0L) {
return(list())
}
-
+
partitionFunc <- function(part) {
if (num < length(part)) {
# R limitation: order works only on primitive types!
@@ -1152,7 +1152,7 @@ takeOrderedElem <- function(x, num, ascending = TRUE) {
# @aliases takeOrdered,RDD,RDD-method
setMethod("takeOrdered",
signature(x = "RDD", num = "integer"),
- function(x, num) {
+ function(x, num) {
takeOrderedElem(x, num)
})
@@ -1173,7 +1173,7 @@ setMethod("takeOrdered",
# @aliases top,RDD,RDD-method
setMethod("top",
signature(x = "RDD", num = "integer"),
- function(x, num) {
+ function(x, num) {
takeOrderedElem(x, num, FALSE)
})
@@ -1181,7 +1181,7 @@ setMethod("top",
#
# Aggregate the elements of each partition, and then the results for all the
# partitions, using a given associative function and a neutral "zero value".
-#
+#
# @param x An RDD.
# @param zeroValue A neutral "zero value".
# @param op An associative function for the folding operation.
@@ -1207,7 +1207,7 @@ setMethod("fold",
#
# Aggregate the elements of each partition, and then the results for all the
# partitions, using given combine functions and a neutral "zero value".
-#
+#
# @param x An RDD.
# @param zeroValue A neutral "zero value".
# @param seqOp A function to aggregate the RDD elements. It may return a different
@@ -1230,11 +1230,11 @@ setMethod("fold",
# @aliases aggregateRDD,RDD,RDD-method
setMethod("aggregateRDD",
signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", combOp = "ANY"),
- function(x, zeroValue, seqOp, combOp) {
+ function(x, zeroValue, seqOp, combOp) {
partitionFunc <- function(part) {
Reduce(seqOp, part, zeroValue)
}
-
+
partitionList <- collect(lapplyPartition(x, partitionFunc),
flatten = FALSE)
Reduce(combOp, partitionList, zeroValue)
@@ -1330,7 +1330,7 @@ setMethod("setName",
#\dontrun{
# sc <- sparkR.init()
# rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L)
-# collect(zipWithUniqueId(rdd))
+# collect(zipWithUniqueId(rdd))
# # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2))
#}
# @rdname zipWithUniqueId
@@ -1426,7 +1426,7 @@ setMethod("glom",
partitionFunc <- function(part) {
list(part)
}
-
+
lapplyPartition(x, partitionFunc)
})
@@ -1498,16 +1498,16 @@ setMethod("zipRDD",
# The jrdd's elements are of scala Tuple2 type. The serialized
# flag here is used for the elements inside the tuples.
rdd <- RDD(jrdd, getSerializedMode(rdds[[1]]))
-
+
mergePartitions(rdd, TRUE)
})
# Cartesian product of this RDD and another one.
#
-# Return the Cartesian product of this RDD and another one,
-# that is, the RDD of all pairs of elements (a, b) where a
+# Return the Cartesian product of this RDD and another one,
+# that is, the RDD of all pairs of elements (a, b) where a
# is in this and b is in other.
-#
+#
# @param x An RDD.
# @param other An RDD.
# @return A new RDD which is the Cartesian product of these two RDDs.
@@ -1515,7 +1515,7 @@ setMethod("zipRDD",
#\dontrun{
# sc <- sparkR.init()
# rdd <- parallelize(sc, 1:2)
-# sortByKey(cartesian(rdd, rdd))
+# sortByKey(cartesian(rdd, rdd))
# # list(list(1, 1), list(1, 2), list(2, 1), list(2, 2))
#}
# @rdname cartesian
@@ -1528,7 +1528,7 @@ setMethod("cartesian",
# The jrdd's elements are of scala Tuple2 type. The serialized
# flag here is used for the elements inside the tuples.
rdd <- RDD(jrdd, getSerializedMode(rdds[[1]]))
-
+
mergePartitions(rdd, FALSE)
})
@@ -1598,11 +1598,11 @@ setMethod("intersection",
# Zips an RDD's partitions with one (or more) RDD(s).
# Same as zipPartitions in Spark.
-#
+#
# @param ... RDDs to be zipped.
# @param func A function to transform zipped partitions.
-# @return A new RDD by applying a function to the zipped partitions.
-# Assumes that all the RDDs have the *same number of partitions*, but
+# @return A new RDD by applying a function to the zipped partitions.
+# Assumes that all the RDDs have the *same number of partitions*, but
# does *not* require them to have the same number of elements in each partition.
# @examples
#\dontrun{
@@ -1610,7 +1610,7 @@ setMethod("intersection",
# rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2
# rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4
# rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6
-# collect(zipPartitions(rdd1, rdd2, rdd3,
+# collect(zipPartitions(rdd1, rdd2, rdd3,
# func = function(x, y, z) { list(list(x, y, z))} ))
# # list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6)))
#}
@@ -1627,7 +1627,7 @@ setMethod("zipPartitions",
if (length(unique(nPart)) != 1) {
stop("Can only zipPartitions RDDs which have the same number of partitions.")
}
-
+
rrdds <- lapply(rrdds, function(rdd) {
mapPartitionsWithIndex(rdd, function(partIndex, part) {
print(length(part))
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 22a4b5bf86ebd..9a743a3411533 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -182,7 +182,7 @@ setMethod("toDF", signature(x = "RDD"),
#' Create a DataFrame from a JSON file.
#'
-#' Loads a JSON file (one object per line), returning the result as a DataFrame
+#' Loads a JSON file (one object per line), returning the result as a DataFrame
#' It goes through the entire dataset once to determine the schema.
#'
#' @param sqlContext SQLContext to use
@@ -238,7 +238,7 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) {
#' Create a DataFrame from a Parquet file.
-#'
+#'
#' Loads a Parquet file, returning the result as a DataFrame.
#'
#' @param sqlContext SQLContext to use
@@ -278,7 +278,7 @@ sql <- function(sqlContext, sqlQuery) {
}
#' Create a DataFrame from a SparkSQL Table
-#'
+#'
#' Returns the specified Table as a DataFrame. The Table must have already been registered
#' in the SQLContext.
#'
@@ -298,7 +298,7 @@ sql <- function(sqlContext, sqlQuery) {
table <- function(sqlContext, tableName) {
sdf <- callJMethod(sqlContext, "table", tableName)
- dataFrame(sdf)
+ dataFrame(sdf)
}
@@ -352,7 +352,7 @@ tableNames <- function(sqlContext, databaseName = NULL) {
#' Cache Table
-#'
+#'
#' Caches the specified table in-memory.
#'
#' @param sqlContext SQLContext to use
@@ -370,11 +370,11 @@ tableNames <- function(sqlContext, databaseName = NULL) {
#' }
cacheTable <- function(sqlContext, tableName) {
- callJMethod(sqlContext, "cacheTable", tableName)
+ callJMethod(sqlContext, "cacheTable", tableName)
}
#' Uncache Table
-#'
+#'
#' Removes the specified table from the in-memory cache.
#'
#' @param sqlContext SQLContext to use
diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R
index 23dc38780716e..2403925b267c8 100644
--- a/R/pkg/R/broadcast.R
+++ b/R/pkg/R/broadcast.R
@@ -27,9 +27,9 @@
# @description Broadcast variables can be created using the broadcast
# function from a \code{SparkContext}.
# @rdname broadcast-class
-# @seealso broadcast
+# @seealso broadcast
#
-# @param id Id of the backing Spark broadcast variable
+# @param id Id of the backing Spark broadcast variable
# @export
setClass("Broadcast", slots = list(id = "character"))
@@ -68,7 +68,7 @@ setMethod("value",
# variable on workers. Not intended for use outside the package.
#
# @rdname broadcast-internal
-# @seealso broadcast, value
+# @seealso broadcast, value
# @param bcastId The id of broadcast variable to set
# @param value The value to be set
diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R
index 80e92d3105a36..8e4b0f5bf1c4d 100644
--- a/R/pkg/R/column.R
+++ b/R/pkg/R/column.R
@@ -210,6 +210,22 @@ setMethod("cast",
}
})
+#' Match a column with given values.
+#'
+#' @rdname column
+#' @return A Column of boolean values indicating whether the values of the column are contained in the given set.
+#' \dontrun{
+#' filter(df, "age in (10, 30)")
+#' where(df, df$age %in% c(10, 30))
+#' }
+setMethod("%in%",
+ signature(x = "Column"),
+ function(x, table) {
+ table <- listToSeq(as.list(table))
+ jc <- callJMethod(x@jc, "in", table)
+ return(column(jc))
+ })
+
#' Approx Count Distinct
#'
#' @rdname column
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index 257b435607ce8..d961bbc383688 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -18,7 +18,7 @@
# Utility functions to deserialize objects from Java.
# Type mapping from Java to R
-#
+#
# void -> NULL
# Int -> integer
# String -> character
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 12e09176c9f92..79055b7f18558 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -130,7 +130,7 @@ setGeneric("maximum", function(x) { standardGeneric("maximum") })
# @export
setGeneric("minimum", function(x) { standardGeneric("minimum") })
-# @rdname sumRDD
+# @rdname sumRDD
# @export
setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") })
@@ -219,7 +219,7 @@ setGeneric("zipRDD", function(x, other) { standardGeneric("zipRDD") })
# @rdname zipRDD
# @export
-setGeneric("zipPartitions", function(..., func) { standardGeneric("zipPartitions") },
+setGeneric("zipPartitions", function(..., func) { standardGeneric("zipPartitions") },
signature = "...")
# @rdname zipWithIndex
@@ -364,7 +364,7 @@ setGeneric("subtract",
# @rdname subtractByKey
# @export
-setGeneric("subtractByKey",
+setGeneric("subtractByKey",
function(x, other, numPartitions = 1) {
standardGeneric("subtractByKey")
})
@@ -399,15 +399,15 @@ setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })
#' @rdname nafunctions
#' @export
setGeneric("dropna",
- function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) {
- standardGeneric("dropna")
+ function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) {
+ standardGeneric("dropna")
})
#' @rdname nafunctions
#' @export
setGeneric("na.omit",
- function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) {
- standardGeneric("na.omit")
+ function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) {
+ standardGeneric("na.omit")
})
#' @rdname schema
@@ -656,4 +656,3 @@ setGeneric("toRadians", function(x) { standardGeneric("toRadians") })
#' @rdname column
#' @export
setGeneric("upper", function(x) { standardGeneric("upper") })
-
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index b758481997574..8f1c68f7c4d28 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -136,4 +136,3 @@ createMethods <- function() {
}
createMethods()
-
diff --git a/R/pkg/R/jobj.R b/R/pkg/R/jobj.R
index a8a25230b636d..0838a7bb35e0d 100644
--- a/R/pkg/R/jobj.R
+++ b/R/pkg/R/jobj.R
@@ -16,7 +16,7 @@
#
# References to objects that exist on the JVM backend
-# are maintained using the jobj.
+# are maintained using the jobj.
#' @include generics.R
NULL
diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R
index 1e24286dbcae2..7f902ba8e683e 100644
--- a/R/pkg/R/pairRDD.R
+++ b/R/pkg/R/pairRDD.R
@@ -784,7 +784,7 @@ setMethod("sortByKey",
newRDD <- partitionBy(x, numPartitions, rangePartitionFunc)
lapplyPartition(newRDD, partitionFunc)
})
-
+
# Subtract a pair RDD with another pair RDD.
#
# Return an RDD with the pairs from x whose keys are not in other.
@@ -820,7 +820,7 @@ setMethod("subtractByKey",
})
# Return a subset of this RDD sampled by key.
-#
+#
# @description
# \code{sampleByKey} Create a sample of this RDD using variable sampling rates
# for different keys as specified by fractions, a key to sampling rate map.
diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R
index e442119086b17..15e2bdbd55d79 100644
--- a/R/pkg/R/schema.R
+++ b/R/pkg/R/schema.R
@@ -20,7 +20,7 @@
#' structType
#'
-#' Create a structType object that contains the metadata for a DataFrame. Intended for
+#' Create a structType object that contains the metadata for a DataFrame. Intended for
#' use with createDataFrame and toDF.
#'
#' @param x a structField object (created with the field() function)
diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R
index 3169d7968f8fe..78535eff0d2f6 100644
--- a/R/pkg/R/serialize.R
+++ b/R/pkg/R/serialize.R
@@ -175,7 +175,7 @@ writeGenericList <- function(con, list) {
writeObject(con, elem)
}
}
-
+
# Used to pass in hash maps required on Java side.
writeEnv <- function(con, env) {
len <- length(env)
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index 2efd4f0742e77..dbde0c44c55d5 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -43,7 +43,7 @@ sparkR.stop <- function() {
callJMethod(sc, "stop")
rm(".sparkRjsc", envir = env)
}
-
+
if (exists(".backendLaunched", envir = env)) {
callJStatic("SparkRHandler", "stopBackend")
}
@@ -174,7 +174,7 @@ sparkR.init <- function(
for (varname in names(sparkEnvir)) {
sparkEnvirMap[[varname]] <- sparkEnvir[[varname]]
}
-
+
sparkExecutorEnvMap <- new.env()
if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) {
sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH"))
@@ -214,7 +214,7 @@ sparkR.init <- function(
#' Initialize a new SQLContext.
#'
-#' This function creates a SparkContext from an existing JavaSparkContext and
+#' This function creates a SparkContext from an existing JavaSparkContext and
#' then uses it to initialize a new SQLContext
#'
#' @param jsc The existing JavaSparkContext created with SparkR.init()
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index 69b2700191c9a..13cec0f712fb4 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -368,21 +368,21 @@ listToSeq <- function(l) {
}
# Utility function to recursively traverse the Abstract Syntax Tree (AST) of a
-# user defined function (UDF), and to examine variables in the UDF to decide
+# user defined function (UDF), and to examine variables in the UDF to decide
# if their values should be included in the new function environment.
# param
# node The current AST node in the traversal.
# oldEnv The original function environment.
# defVars An Accumulator of variables names defined in the function's calling environment,
# including function argument and local variable names.
-# checkedFunc An environment of function objects examined during cleanClosure. It can
+# checkedFunc An environment of function objects examined during cleanClosure. It can
# be considered as a "name"-to-"list of functions" mapping.
# newEnv A new function environment to store necessary function dependencies, an output argument.
processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
nodeLen <- length(node)
-
+
if (nodeLen > 1 && typeof(node) == "language") {
- # Recursive case: current AST node is an internal node, check for its children.
+ # Recursive case: current AST node is an internal node, check for its children.
if (length(node[[1]]) > 1) {
for (i in 1:nodeLen) {
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
@@ -393,7 +393,7 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
for (i in 2:nodeLen) {
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
}
- } else if (nodeChar == "<-" || nodeChar == "=" ||
+ } else if (nodeChar == "<-" || nodeChar == "=" ||
nodeChar == "<<-") { # Assignment Ops.
defVar <- node[[2]]
if (length(defVar) == 1 && typeof(defVar) == "symbol") {
@@ -422,21 +422,21 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
}
}
}
- } else if (nodeLen == 1 &&
+ } else if (nodeLen == 1 &&
(typeof(node) == "symbol" || typeof(node) == "language")) {
# Base case: current AST node is a leaf node and a symbol or a function call.
nodeChar <- as.character(node)
if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable.
func.env <- oldEnv
topEnv <- parent.env(.GlobalEnv)
- # Search in function environment, and function's enclosing environments
+ # Search in function environment, and function's enclosing environments
# up to global environment. There is no need to look into package environments
- # above the global or namespace environment that is not SparkR below the global,
+ # above the global or namespace environment that is not SparkR below the global,
# as they are assumed to be loaded on workers.
while (!identical(func.env, topEnv)) {
# Namespaces other than "SparkR" will not be searched.
- if (!isNamespace(func.env) ||
- (getNamespaceName(func.env) == "SparkR" &&
+ if (!isNamespace(func.env) ||
+ (getNamespaceName(func.env) == "SparkR" &&
!(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals.
# Set parameter 'inherits' to FALSE since we do not need to search in
# attached package environments.
@@ -444,7 +444,7 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
error = function(e) { FALSE })) {
obj <- get(nodeChar, envir = func.env, inherits = FALSE)
if (is.function(obj)) { # If the node is a function call.
- funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F,
+ funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F,
ifnotfound = list(list(NULL)))[[1]]
found <- sapply(funcList, function(func) {
ifelse(identical(func, obj), TRUE, FALSE)
@@ -453,7 +453,7 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
break
}
# Function has not been examined, record it and recursively clean its closure.
- assign(nodeChar,
+ assign(nodeChar,
if (is.null(funcList[[1]])) {
list(obj)
} else {
@@ -466,7 +466,7 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
break
}
}
-
+
# Continue to search in enclosure.
func.env <- parent.env(func.env)
}
@@ -474,8 +474,8 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
}
}
-# Utility function to get user defined function (UDF) dependencies (closure).
-# More specifically, this function captures the values of free variables defined
+# Utility function to get user defined function (UDF) dependencies (closure).
+# More specifically, this function captures the values of free variables defined
# outside a UDF, and stores them in the function's environment.
# param
# func A function whose closure needs to be captured.
@@ -488,7 +488,7 @@ cleanClosure <- function(func, checkedFuncs = new.env()) {
newEnv <- new.env(parent = .GlobalEnv)
func.body <- body(func)
oldEnv <- environment(func)
- # defVars is an Accumulator of variables names defined in the function's calling
+ # defVars is an Accumulator of variables names defined in the function's calling
# environment. First, function's arguments are added to defVars.
defVars <- initAccumulator()
argNames <- names(as.list(args(func)))
@@ -509,15 +509,15 @@ cleanClosure <- function(func, checkedFuncs = new.env()) {
# return value
# A list of two result RDDs.
appendPartitionLengths <- function(x, other) {
- if (getSerializedMode(x) != getSerializedMode(other) ||
+ if (getSerializedMode(x) != getSerializedMode(other) ||
getSerializedMode(x) == "byte") {
# Append the number of elements in each partition to that partition so that we can later
# know the boundary of elements from x and other.
#
- # Note that this appending also serves the purpose of reserialization, because even if
+ # Note that this appending also serves the purpose of reserialization, because even if
# any RDD is serialized, we need to reserialize it to make sure its partitions are encoded
# as a single byte array. For example, partitions of an RDD generated from partitionBy()
- # may be encoded as multiple byte arrays.
+ # may be encoded as multiple byte arrays.
appendLength <- function(part) {
len <- length(part)
part[[len + 1]] <- len + 1
@@ -544,23 +544,23 @@ mergePartitions <- function(rdd, zip) {
lengthOfValues <- part[[len]]
lengthOfKeys <- part[[len - lengthOfValues]]
stopifnot(len == lengthOfKeys + lengthOfValues)
-
+
# For zip operation, check if corresponding partitions of both RDDs have the same number of elements.
if (zip && lengthOfKeys != lengthOfValues) {
stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.")
}
-
+
if (lengthOfKeys > 1) {
keys <- part[1 : (lengthOfKeys - 1)]
} else {
keys <- list()
}
if (lengthOfValues > 1) {
- values <- part[(lengthOfKeys + 1) : (len - 1)]
+ values <- part[(lengthOfKeys + 1) : (len - 1)]
} else {
values <- list()
}
-
+
if (!zip) {
return(mergeCompactLists(keys, values))
}
@@ -578,6 +578,6 @@ mergePartitions <- function(rdd, zip) {
part
}
}
-
+
PipelinedRDD(rdd, partitionFunc)
}
diff --git a/R/pkg/R/zzz.R b/R/pkg/R/zzz.R
index 80d796d467943..301feade65fa3 100644
--- a/R/pkg/R/zzz.R
+++ b/R/pkg/R/zzz.R
@@ -18,4 +18,3 @@
.onLoad <- function(libname, pkgname) {
sparkR.onLoad(libname, pkgname)
}
-
diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R
index ca4218f3819f8..4db7266abc8e2 100644
--- a/R/pkg/inst/tests/test_binaryFile.R
+++ b/R/pkg/inst/tests/test_binaryFile.R
@@ -59,15 +59,15 @@ test_that("saveAsObjectFile()/objectFile() following RDD transformations works",
wordCount <- lapply(words, function(word) { list(word, 1L) })
counts <- reduceByKey(wordCount, "+", 2L)
-
+
saveAsObjectFile(counts, fileName2)
counts <- objectFile(sc, fileName2)
-
+
output <- collect(counts)
expected <- list(list("awesome.", 1), list("Spark", 2), list("pretty.", 1),
list("is", 2))
expect_equal(sortKeyValueList(output), sortKeyValueList(expected))
-
+
unlink(fileName1)
unlink(fileName2, recursive = TRUE)
})
@@ -87,4 +87,3 @@ test_that("saveAsObjectFile()/objectFile() works with multiple paths", {
unlink(fileName1, recursive = TRUE)
unlink(fileName2, recursive = TRUE)
})
-
diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R
index 6785a7bdae8cb..a1e354e567be5 100644
--- a/R/pkg/inst/tests/test_binary_function.R
+++ b/R/pkg/inst/tests/test_binary_function.R
@@ -30,7 +30,7 @@ mockFile <- c("Spark is pretty.", "Spark is awesome.")
test_that("union on two RDDs", {
actual <- collect(unionRDD(rdd, rdd))
expect_equal(actual, as.list(rep(nums, 2)))
-
+
fileName <- tempfile(pattern="spark-test", fileext=".tmp")
writeLines(mockFile, fileName)
@@ -52,14 +52,14 @@ test_that("union on two RDDs", {
test_that("cogroup on two RDDs", {
rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4)))
rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3)))
- cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L)
+ cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L)
actual <- collect(cogroup.rdd)
- expect_equal(actual,
+ expect_equal(actual,
list(list(1, list(list(1), list(2, 3))), list(2, list(list(4), list()))))
-
+
rdd1 <- parallelize(sc, list(list("a", 1), list("a", 4)))
rdd2 <- parallelize(sc, list(list("b", 2), list("a", 3)))
- cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L)
+ cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L)
actual <- collect(cogroup.rdd)
expected <- list(list("b", list(list(), list(2))), list("a", list(list(1, 4), list(3))))
@@ -71,31 +71,31 @@ test_that("zipPartitions() on RDDs", {
rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2
rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4
rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6
- actual <- collect(zipPartitions(rdd1, rdd2, rdd3,
+ actual <- collect(zipPartitions(rdd1, rdd2, rdd3,
func = function(x, y, z) { list(list(x, y, z))} ))
expect_equal(actual,
list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))))
-
+
mockFile = c("Spark is pretty.", "Spark is awesome.")
fileName <- tempfile(pattern="spark-test", fileext=".tmp")
writeLines(mockFile, fileName)
-
+
rdd <- textFile(sc, fileName, 1)
- actual <- collect(zipPartitions(rdd, rdd,
+ actual <- collect(zipPartitions(rdd, rdd,
func = function(x, y) { list(paste(x, y, sep = "\n")) }))
expected <- list(paste(mockFile, mockFile, sep = "\n"))
expect_equal(actual, expected)
-
+
rdd1 <- parallelize(sc, 0:1, 1)
- actual <- collect(zipPartitions(rdd1, rdd,
+ actual <- collect(zipPartitions(rdd1, rdd,
func = function(x, y) { list(x + nchar(y)) }))
expected <- list(0:1 + nchar(mockFile))
expect_equal(actual, expected)
-
+
rdd <- map(rdd, function(x) { x })
- actual <- collect(zipPartitions(rdd, rdd1,
+ actual <- collect(zipPartitions(rdd, rdd1,
func = function(x, y) { list(y + nchar(x)) }))
expect_equal(actual, expected)
-
+
unlink(fileName)
})
diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R
index 03207353c31c6..4fe653856756e 100644
--- a/R/pkg/inst/tests/test_rdd.R
+++ b/R/pkg/inst/tests/test_rdd.R
@@ -477,7 +477,7 @@ test_that("cartesian() on RDDs", {
list(1, 1), list(1, 2), list(1, 3),
list(2, 1), list(2, 2), list(2, 3),
list(3, 1), list(3, 2), list(3, 3)))
-
+
# test case where one RDD is empty
emptyRdd <- parallelize(sc, list())
actual <- collect(cartesian(rdd, emptyRdd))
@@ -486,7 +486,7 @@ test_that("cartesian() on RDDs", {
mockFile = c("Spark is pretty.", "Spark is awesome.")
fileName <- tempfile(pattern="spark-test", fileext=".tmp")
writeLines(mockFile, fileName)
-
+
rdd <- textFile(sc, fileName)
actual <- collect(cartesian(rdd, rdd))
expected <- list(
@@ -495,7 +495,7 @@ test_that("cartesian() on RDDs", {
list("Spark is pretty.", "Spark is pretty."),
list("Spark is pretty.", "Spark is awesome."))
expect_equal(sortKeyValueList(actual), expected)
-
+
rdd1 <- parallelize(sc, 0:1)
actual <- collect(cartesian(rdd1, rdd))
expect_equal(sortKeyValueList(actual),
@@ -504,11 +504,11 @@ test_that("cartesian() on RDDs", {
list(0, "Spark is awesome."),
list(1, "Spark is pretty."),
list(1, "Spark is awesome.")))
-
+
rdd1 <- map(rdd, function(x) { x })
actual <- collect(cartesian(rdd, rdd1))
expect_equal(sortKeyValueList(actual), expected)
-
+
unlink(fileName)
})
@@ -760,7 +760,7 @@ test_that("collectAsMap() on a pairwise RDD", {
})
test_that("show()", {
- rdd <- parallelize(sc, list(1:10))
+ rdd <- parallelize(sc, list(1:10))
expect_output(show(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+")
})
diff --git a/R/pkg/inst/tests/test_shuffle.R b/R/pkg/inst/tests/test_shuffle.R
index d7dedda553c56..adf0b91d25fe9 100644
--- a/R/pkg/inst/tests/test_shuffle.R
+++ b/R/pkg/inst/tests/test_shuffle.R
@@ -106,39 +106,39 @@ test_that("aggregateByKey", {
zeroValue <- list(0, 0)
seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) }
combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) }
- aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L)
-
+ aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L)
+
actual <- collect(aggregatedRDD)
-
+
expected <- list(list(1, list(3, 2)), list(2, list(7, 2)))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
# test aggregateByKey for string keys
rdd <- parallelize(sc, list(list("a", 1), list("a", 2), list("b", 3), list("b", 4)))
-
+
zeroValue <- list(0, 0)
seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) }
combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) }
- aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L)
+ aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L)
actual <- collect(aggregatedRDD)
-
+
expected <- list(list("a", list(3, 2)), list("b", list(7, 2)))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
})
-test_that("foldByKey", {
+test_that("foldByKey", {
# test foldByKey for int keys
folded <- foldByKey(intRdd, 0, "+", 2L)
-
+
actual <- collect(folded)
-
+
expected <- list(list(2L, 101), list(1L, 199))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
# test foldByKey for double keys
folded <- foldByKey(doubleRdd, 0, "+", 2L)
-
+
actual <- collect(folded)
expected <- list(list(1.5, 199), list(2.5, 101))
@@ -146,15 +146,15 @@ test_that("foldByKey", {
# test foldByKey for string keys
stringKeyPairs <- list(list("a", -1), list("b", 100), list("b", 1), list("a", 200))
-
+
stringKeyRDD <- parallelize(sc, stringKeyPairs)
folded <- foldByKey(stringKeyRDD, 0, "+", 2L)
-
+
actual <- collect(folded)
-
+
expected <- list(list("b", 101), list("a", 199))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
-
+
# test foldByKey for empty pair RDD
rdd <- parallelize(sc, list())
folded <- foldByKey(rdd, 0, "+", 2L)
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 8946348ef801c..417153dc0985c 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -67,7 +67,7 @@ test_that("structType and structField", {
expect_true(inherits(testField, "structField"))
expect_true(testField$name() == "a")
expect_true(testField$nullable())
-
+
testSchema <- structType(testField, structField("b", "integer"))
expect_true(inherits(testSchema, "structType"))
expect_true(inherits(testSchema$fields()[[2]], "structField"))
@@ -598,7 +598,7 @@ test_that("column functions", {
c3 <- lower(c) + upper(c) + first(c) + last(c)
c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string")
c5 <- n(c) + n_distinct(c)
- c5 <- acos(c) + asin(c) + atan(c) + cbrt(c)
+ c5 <- acos(c) + asin(c) + atan(c) + cbrt(c)
c6 <- ceiling(c) + cos(c) + cosh(c) + exp(c) + expm1(c)
c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c)
c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c)
@@ -693,6 +693,16 @@ test_that("filter() on a DataFrame", {
filtered2 <- where(df, df$name != "Michael")
expect_true(count(filtered2) == 2)
expect_true(collect(filtered2)$age[2] == 19)
+
+ # test suites for %in%
+ filtered3 <- filter(df, "age in (19)")
+ expect_equal(count(filtered3), 1)
+ filtered4 <- filter(df, "age in (19, 30)")
+ expect_equal(count(filtered4), 2)
+ filtered5 <- where(df, df$age %in% c(19))
+ expect_equal(count(filtered5), 1)
+ filtered6 <- where(df, df$age %in% c(19, 30))
+ expect_equal(count(filtered6), 2)
})
test_that("join() on a DataFrame", {
@@ -829,7 +839,7 @@ test_that("dropna() on a DataFrame", {
rows <- collect(df)
# drop with columns
-
+
expected <- rows[!is.na(rows$name),]
actual <- collect(dropna(df, cols = "name"))
expect_true(identical(expected, actual))
@@ -842,7 +852,7 @@ test_that("dropna() on a DataFrame", {
expect_true(identical(expected$age, actual$age))
expect_true(identical(expected$height, actual$height))
expect_true(identical(expected$name, actual$name))
-
+
expected <- rows[!is.na(rows$age) & !is.na(rows$height),]
actual <- collect(dropna(df, cols = c("age", "height")))
expect_true(identical(expected, actual))
@@ -850,7 +860,7 @@ test_that("dropna() on a DataFrame", {
expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),]
actual <- collect(dropna(df))
expect_true(identical(expected, actual))
-
+
# drop with how
expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),]
@@ -860,7 +870,7 @@ test_that("dropna() on a DataFrame", {
expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),]
actual <- collect(dropna(df, "all"))
expect_true(identical(expected, actual))
-
+
expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),]
actual <- collect(dropna(df, "any"))
expect_true(identical(expected, actual))
@@ -872,14 +882,14 @@ test_that("dropna() on a DataFrame", {
expected <- rows[!is.na(rows$age) | !is.na(rows$height),]
actual <- collect(dropna(df, "all", cols = c("age", "height")))
expect_true(identical(expected, actual))
-
+
# drop with threshold
-
+
expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,]
actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height")))
- expect_true(identical(expected, actual))
+ expect_true(identical(expected, actual))
- expected <- rows[as.integer(!is.na(rows$age)) +
+ expected <- rows[as.integer(!is.na(rows$age)) +
as.integer(!is.na(rows$height)) +
as.integer(!is.na(rows$name)) >= 3,]
actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height")))
@@ -889,9 +899,9 @@ test_that("dropna() on a DataFrame", {
test_that("fillna() on a DataFrame", {
df <- jsonFile(sqlContext, jsonPathNa)
rows <- collect(df)
-
+
# fill with value
-
+
expected <- rows
expected$age[is.na(expected$age)] <- 50
expected$height[is.na(expected$height)] <- 50.6
@@ -912,7 +922,7 @@ test_that("fillna() on a DataFrame", {
expected$name[is.na(expected$name)] <- "unknown"
actual <- collect(fillna(df, "unknown", c("age", "name")))
expect_true(identical(expected, actual))
-
+
# fill with named list
expected <- rows
@@ -920,7 +930,7 @@ test_that("fillna() on a DataFrame", {
expected$height[is.na(expected$height)] <- 50.6
expected$name[is.na(expected$name)] <- "unknown"
actual <- collect(fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown")))
- expect_true(identical(expected, actual))
+ expect_true(identical(expected, actual))
})
unlink(parquetPath)
diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/test_take.R
index 7f4c7c315d787..c5eb417b40159 100644
--- a/R/pkg/inst/tests/test_take.R
+++ b/R/pkg/inst/tests/test_take.R
@@ -64,4 +64,3 @@ test_that("take() gives back the original elements in correct count and order",
expect_true(length(take(numListRDD, 0)) == 0)
expect_true(length(take(numVectorRDD, 0)) == 0)
})
-
diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R
index 6b87b4b3e0b08..092ad9dc10c2e 100644
--- a/R/pkg/inst/tests/test_textFile.R
+++ b/R/pkg/inst/tests/test_textFile.R
@@ -58,7 +58,7 @@ test_that("textFile() word count works as expected", {
expected <- list(list("pretty.", 1), list("is", 2), list("awesome.", 1),
list("Spark", 2))
expect_equal(sortKeyValueList(output), sortKeyValueList(expected))
-
+
unlink(fileName)
})
@@ -115,13 +115,13 @@ test_that("textFile() and saveAsTextFile() word count works as expected", {
saveAsTextFile(counts, fileName2)
rdd <- textFile(sc, fileName2)
-
+
output <- collect(rdd)
expected <- list(list("awesome.", 1), list("Spark", 2),
list("pretty.", 1), list("is", 2))
expectedStr <- lapply(expected, function(x) { toString(x) })
expect_equal(sortKeyValueList(output), sortKeyValueList(expectedStr))
-
+
unlink(fileName1)
unlink(fileName2)
})
@@ -159,4 +159,3 @@ test_that("Pipelined operations on RDDs created using textFile", {
unlink(fileName)
})
-
diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R
index 539e3a3c19df3..15030e6f1d77e 100644
--- a/R/pkg/inst/tests/test_utils.R
+++ b/R/pkg/inst/tests/test_utils.R
@@ -43,13 +43,13 @@ test_that("serializeToBytes on RDD", {
mockFile <- c("Spark is pretty.", "Spark is awesome.")
fileName <- tempfile(pattern="spark-test", fileext=".tmp")
writeLines(mockFile, fileName)
-
+
text.rdd <- textFile(sc, fileName)
expect_true(getSerializedMode(text.rdd) == "string")
ser.rdd <- serializeToBytes(text.rdd)
expect_equal(collect(ser.rdd), as.list(mockFile))
expect_true(getSerializedMode(ser.rdd) == "byte")
-
+
unlink(fileName)
})
@@ -64,7 +64,7 @@ test_that("cleanClosure on R functions", {
expect_equal(actual, y)
actual <- get("g", envir = env, inherits = FALSE)
expect_equal(actual, g)
-
+
# Test for nested enclosures and package variables.
env2 <- new.env()
funcEnv <- new.env(parent = env2)
@@ -106,7 +106,7 @@ test_that("cleanClosure on R functions", {
expect_equal(length(ls(env)), 1)
actual <- get("y", envir = env, inherits = FALSE)
expect_equal(actual, y)
-
+
# Test for function (and variable) definitions.
f <- function(x) {
g <- function(y) { y * 2 }
@@ -115,7 +115,7 @@ test_that("cleanClosure on R functions", {
newF <- cleanClosure(f)
env <- environment(newF)
expect_equal(length(ls(env)), 0) # "y" and "g" should not be included.
-
+
# Test for overriding variables in base namespace (Issue: SparkR-196).
nums <- as.list(1:10)
rdd <- parallelize(sc, nums, 2L)
@@ -128,7 +128,7 @@ test_that("cleanClosure on R functions", {
actual <- collect(lapply(rdd, f))
expected <- as.list(c(rep(FALSE, 4), rep(TRUE, 6)))
expect_equal(actual, expected)
-
+
# Test for broadcast variables.
a <- matrix(nrow=10, ncol=10, data=rnorm(100))
aBroadcast <- broadcast(sc, a)
diff --git a/dev/run-tests.py b/dev/run-tests.py
index 2cccfed75edee..de1b4537eda5f 100755
--- a/dev/run-tests.py
+++ b/dev/run-tests.py
@@ -179,14 +179,14 @@ def contains_file(self, filename):
)
-streaming_mqqt = Module(
- name="streaming-mqqt",
+streaming_mqtt = Module(
+ name="streaming-mqtt",
dependencies=[streaming],
source_file_regexes=[
- "external/mqqt",
+ "external/mqtt",
],
sbt_test_goals=[
- "streaming-mqqt/test",
+ "streaming-mqtt/test",
]
)
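
The rename above matters beyond spelling: contains_file() (named in the hunk header) matches changed files against source_file_regexes, so the old "external/mqqt" prefix could never match anything under external/mqtt and the module's tests would not be selected. A simplified, self-contained sketch of that matching logic (function and field names assumed from this file):

```python
import re

# Hypothetical mirror of the Module matching in dev/run-tests.py.
source_file_regexes = ["external/mqtt"]

def contains_file(filename):
    return any(re.match(regex, filename) for regex in source_file_regexes)

print(contains_file("external/mqtt/src/main/scala/Example.scala"))   # True
print(contains_file("external/kafka/src/main/scala/Example.scala"))  # False
```
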
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 56087499464e0..63e2c79669763 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -289,6 +289,10 @@ def parse_args():
parser.add_option(
"--additional-security-group", type="string", default="",
help="Additional security group to place the machines in")
+ parser.add_option(
+ "--additional-tags", type="string", default="",
+ help="Additional tags to set on the machines; tags are comma-separated, while name and " +
+ "value are colon separated; ex: \"Task:MySparkProject,Env:production\"")
parser.add_option(
"--copy-aws-credentials", action="store_true", default=False,
help="Add AWS credentials to hadoop configuration to allow Spark to access S3")
@@ -358,7 +362,7 @@ def get_validate_spark_version(version, repo):
# Source: http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/
-# Last Updated: 2015-05-08
+# Last Updated: 2015-06-19
# For easy maintainability, please keep this manually-inputted dictionary sorted by key.
EC2_INSTANCE_TYPES = {
"c1.medium": "pvm",
@@ -400,6 +404,11 @@ def get_validate_spark_version(version, repo):
"m3.large": "hvm",
"m3.xlarge": "hvm",
"m3.2xlarge": "hvm",
+ "m4.large": "hvm",
+ "m4.xlarge": "hvm",
+ "m4.2xlarge": "hvm",
+ "m4.4xlarge": "hvm",
+ "m4.10xlarge": "hvm",
"r3.large": "hvm",
"r3.xlarge": "hvm",
"r3.2xlarge": "hvm",
@@ -409,6 +418,7 @@ def get_validate_spark_version(version, repo):
"t2.micro": "hvm",
"t2.small": "hvm",
"t2.medium": "hvm",
+ "t2.large": "hvm",
}
@@ -684,16 +694,24 @@ def launch_cluster(conn, opts, cluster_name):
# This wait time corresponds to SPARK-4983
print("Waiting for AWS to propagate instance metadata...")
- time.sleep(5)
- # Give the instances descriptive names
+ time.sleep(15)
+
+ # Give the instances descriptive names and set additional tags
+ additional_tags = {}
+ if opts.additional_tags.strip():
+ additional_tags = dict(
+ map(str.strip, tag.split(':', 1)) for tag in opts.additional_tags.split(',')
+ )
+
for master in master_nodes:
- master.add_tag(
- key='Name',
- value='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id))
+ master.add_tags(
+ dict(additional_tags, Name='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id))
+ )
+
for slave in slave_nodes:
- slave.add_tag(
- key='Name',
- value='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id))
+ slave.add_tags(
+ dict(additional_tags, Name='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id))
+ )
# Return all the instances
return (master_nodes, slave_nodes)
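
The new --additional-tags option is parsed by splitting the string on commas and then on the first colon of each entry, exactly as in the added lines above; a small self-contained illustration of that parsing using the example value from the help text:

```python
def parse_additional_tags(raw):
    # Mirrors the dict(map(str.strip, tag.split(':', 1)) ...) expression above.
    if not raw.strip():
        return {}
    return dict(map(str.strip, tag.split(':', 1)) for tag in raw.split(','))

tags = parse_additional_tags("Task:MySparkProject,Env:production")
print(tags)  # {'Task': 'MySparkProject', 'Env': 'production'}

# Each instance then receives these tags plus its generated Name tag, e.g.
# dict(tags, Name='mycluster-master-i-0abc123')  (instance id is illustrative)
```
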
@@ -911,7 +929,7 @@ def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state):
# Get number of local disks available for a given EC2 instance type.
def get_num_disks(instance_type):
# Source: http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html
- # Last Updated: 2015-05-08
+ # Last Updated: 2015-06-19
# For easy maintainability, please keep this manually-inputted dictionary sorted by key.
disks_by_instance = {
"c1.medium": 1,
@@ -953,6 +971,11 @@ def get_num_disks(instance_type):
"m3.large": 1,
"m3.xlarge": 2,
"m3.2xlarge": 2,
+ "m4.large": 0,
+ "m4.xlarge": 0,
+ "m4.2xlarge": 0,
+ "m4.4xlarge": 0,
+ "m4.10xlarge": 0,
"r3.large": 1,
"r3.xlarge": 1,
"r3.2xlarge": 1,
@@ -962,6 +985,7 @@ def get_num_disks(instance_type):
"t2.micro": 0,
"t2.small": 0,
"t2.medium": 0,
+ "t2.large": 0,
}
if instance_type in disks_by_instance:
return disks_by_instance[instance_type]
diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml
index 7a7dccc3d0922..0664cfb2021e1 100644
--- a/external/flume-sink/pom.xml
+++ b/external/flume-sink/pom.xml
@@ -35,10 +35,6 @@
   <url>http://spark.apache.org/</url>
 
   <dependencies>
-    <dependency>
-      <groupId>org.apache.commons</groupId>
-      <artifactId>commons-lang3</artifactId>
-    </dependency>
     <dependency>
       <groupId>org.apache.flume</groupId>
       <artifactId>flume-ng-sdk</artifactId>
diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala
index dc2a4ab138e18..719fca0938b3a 100644
--- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala
+++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala
@@ -16,13 +16,13 @@
*/
package org.apache.spark.streaming.flume.sink
+import java.util.UUID
import java.util.concurrent.{CountDownLatch, Executors}
import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable
import org.apache.flume.Channel
-import org.apache.commons.lang3.RandomStringUtils
/**
* Class that implements the SparkFlumeProtocol, that is used by the Avro Netty Server to process
@@ -53,7 +53,7 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha
// Since the new txn may not have the same sequence number we must guard against accidentally
// committing a new transaction. To reduce the probability of that happening a random string is
// prepended to the sequence number. Does not change for life of sink
- private val seqBase = RandomStringUtils.randomAlphanumeric(8)
+ private val seqBase = UUID.randomUUID().toString.substring(0, 8)
private val seqCounter = new AtomicLong(0)
// Protected by `sequenceNumberToProcessor`
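
The change above drops the commons-lang3 dependency by building the random per-sink prefix from a UUID instead of RandomStringUtils.randomAlphanumeric(8); the first eight characters of a UUID are hex digits only, slightly less varied than a full alphanumeric string but still plenty to make accidental sequence-number collisions unlikely. A rough Python sketch of the same idea (names are illustrative, not the sink's actual API):

```python
import itertools
import uuid

# Random 8-character prefix, mirroring UUID.randomUUID().toString.substring(0, 8).
seq_base = str(uuid.uuid4())[:8]
seq_counter = itertools.count()

def next_sequence_number():
    # Prefix + monotonically increasing counter, as in SparkAvroCallbackHandler.
    return "%s%d" % (seq_base, next(seq_counter))

print(next_sequence_number())  # e.g. '3f2b8c1a0'
```
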
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
new file mode 100644
index 0000000000000..8de10eb51f923
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
+
+/**
+ * :: Experimental ::
+ * A feature transformer that converts the input array of strings into an array of n-grams. Null
+ * values in the input array are ignored.
+ * It returns an array of n-grams where each n-gram is represented by a space-separated string of
+ * words.
+ *
+ * When the input is empty, an empty array is returned.
+ * When the input array length is less than n (number of elements per n-gram), no n-grams are
+ * returned.
+ */
+@Experimental
+class NGram(override val uid: String)
+ extends UnaryTransformer[Seq[String], Seq[String], NGram] {
+
+ def this() = this(Identifiable.randomUID("ngram"))
+
+ /**
+ * Minimum n-gram length, >= 1.
+ * Default: 2, bigram features
+ * @group param
+ */
+ val n: IntParam = new IntParam(this, "n", "number elements per n-gram (>=1)",
+ ParamValidators.gtEq(1))
+
+ /** @group setParam */
+ def setN(value: Int): this.type = set(n, value)
+
+ /** @group getParam */
+ def getN: Int = $(n)
+
+ setDefault(n -> 2)
+
+ override protected def createTransformFunc: Seq[String] => Seq[String] = {
+ _.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq
+ }
+
+ override protected def validateInputType(inputType: DataType): Unit = {
+ require(inputType.sameType(ArrayType(StringType)),
+ s"Input type must be ArrayType(StringType) but got $inputType.")
+ }
+
+ override protected def outputDataType: DataType = new ArrayType(StringType, false)
+}
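
The scaladoc above boils down to a sliding window of n tokens joined by spaces, with partial windows dropped, so inputs shorter than n produce an empty result. A plain-Python sketch of that logic for illustration only (it skips the null-token handling mentioned in the scaladoc):

```python
def ngrams(tokens, n=2):
    # Space-joined n-grams with no partial windows, matching the behavior
    # described above: len(tokens) < n yields an empty list.
    return [" ".join(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

print(ngrams(["Test", "for", "ngram", "."]))      # ['Test for', 'for ngram', 'ngram .']
print(ngrams(["a", "b", "c", "d", "e"], n=4))     # ['a b c d', 'b c d e']
print(ngrams([], n=4))                            # []
print(ngrams(["a", "b", "c"], n=6))               # []
```
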
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index e73c14fd5c4db..876a9f9f28242 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -705,12 +705,14 @@ private[python] class PythonMLLibAPI extends Serializable {
lossStr: String,
numIterations: Int,
learningRate: Double,
- maxDepth: Int): GradientBoostedTreesModel = {
+ maxDepth: Int,
+ maxBins: Int): GradientBoostedTreesModel = {
val boostingStrategy = BoostingStrategy.defaultParams(algoStr)
boostingStrategy.setLoss(Losses.fromString(lossStr))
boostingStrategy.setNumIterations(numIterations)
boostingStrategy.setLearningRate(learningRate)
boostingStrategy.treeStrategy.setMaxDepth(maxDepth)
+ boostingStrategy.treeStrategy.setMaxBins(maxBins)
boostingStrategy.treeStrategy.categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap
val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK)
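
With maxBins now forwarded to boostingStrategy.treeStrategy, callers of the Python gradient-boosted-trees API can control how finely continuous features are binned. A hedged usage sketch, assuming the matching maxBins keyword was added on the Python side in the same change (that part is not shown in this diff) and that a SparkContext sc is available:

```python
from pyspark.mllib.tree import GradientBoostedTrees
from pyspark.mllib.util import MLUtils

# Assumes a running SparkContext `sc` and the sample data shipped with Spark.
data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

# maxBins reaches the JVM-side tree strategy via the change above.
model = GradientBoostedTrees.trainClassifier(
    data,
    categoricalFeaturesInfo={},
    numIterations=10,
    maxDepth=3,
    maxBins=32)
print(model.numTrees())
```
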
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
new file mode 100644
index 0000000000000..ab97e3dbc6ee0
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.beans.BeanInfo
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, Row}
+
+@BeanInfo
+case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String])
+
+class NGramSuite extends SparkFunSuite with MLlibTestSparkContext {
+ import org.apache.spark.ml.feature.NGramSuite._
+
+ test("default behavior yields bigram features") {
+ val nGram = new NGram()
+ .setInputCol("inputTokens")
+ .setOutputCol("nGrams")
+ val dataset = sqlContext.createDataFrame(Seq(
+ NGramTestData(
+ Array("Test", "for", "ngram", "."),
+ Array("Test for", "for ngram", "ngram .")
+ )))
+ testNGram(nGram, dataset)
+ }
+
+ test("NGramLength=4 yields length 4 n-grams") {
+ val nGram = new NGram()
+ .setInputCol("inputTokens")
+ .setOutputCol("nGrams")
+ .setN(4)
+ val dataset = sqlContext.createDataFrame(Seq(
+ NGramTestData(
+ Array("a", "b", "c", "d", "e"),
+ Array("a b c d", "b c d e")
+ )))
+ testNGram(nGram, dataset)
+ }
+
+ test("empty input yields empty output") {
+ val nGram = new NGram()
+ .setInputCol("inputTokens")
+ .setOutputCol("nGrams")
+ .setN(4)
+ val dataset = sqlContext.createDataFrame(Seq(
+ NGramTestData(
+ Array(),
+ Array()
+ )))
+ testNGram(nGram, dataset)
+ }
+
+ test("input array < n yields empty output") {
+ val nGram = new NGram()
+ .setInputCol("inputTokens")
+ .setOutputCol("nGrams")
+ .setN(6)
+ val dataset = sqlContext.createDataFrame(Seq(
+ NGramTestData(
+ Array("a", "b", "c", "d", "e"),
+ Array()
+ )))
+ testNGram(nGram, dataset)
+ }
+}
+
+object NGramSuite extends SparkFunSuite {
+
+ def testNGram(t: NGram, dataset: DataFrame): Unit = {
+ t.transform(dataset)
+ .select("nGrams", "wantedNGrams")
+ .collect()
+ .foreach { case Row(actualNGrams, wantedNGrams) =>
+ assert(actualNGrams === wantedNGrams)
+ }
+ }
+}
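The suite above pins down the NGram transformer's expected outputs; a plain-Python sketch of the same n-gram rule (illustrative only, not part of this patch) makes those expectations easy to check by hand:

    def ngrams(tokens, n=2):
        # Consecutive windows of n tokens joined by single spaces; fewer than n
        # tokens (including empty input) yields an empty list.
        return [" ".join(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

    assert ngrams(["Test", "for", "ngram", "."]) == ["Test for", "for ngram", "ngram ."]
    assert ngrams(["a", "b", "c", "d", "e"], n=4) == ["a b c d", "b c d e"]
    assert ngrams(["a", "b", "c", "d", "e"], n=6) == []
    assert ngrams([], n=4) == []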
diff --git a/pom.xml b/pom.xml
index 6d4f717d4931b..80cacb5ace2d4 100644
--- a/pom.xml
+++ b/pom.xml
@@ -156,7 +156,6 @@
<scala.binary.version>2.10</scala.binary.version>
<jline.version>${scala.version}</jline.version>
<jline.groupid>org.scala-lang</jline.groupid>
- <jodd.version>3.6.3</jodd.version>
<codehaus.jackson.version>1.9.13</codehaus.jackson.version>
<fasterxml.jackson.version>2.4.4</fasterxml.jackson.version>
<snappy.version>1.1.1.7</snappy.version>
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 015d0296dd369..7a748fb5e38bd 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -54,7 +54,17 @@ object MimaExcludes {
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"),
// SQL execution is considered private.
- excludePackage("org.apache.spark.sql.execution")
+ excludePackage("org.apache.spark.sql.execution"),
+ // NanoTime and CatalystTimestampConverter are only used inside catalyst
+ // and are no longer needed
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.sql.parquet.timestamp.NanoTime"),
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.sql.parquet.timestamp.NanoTime$"),
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.sql.parquet.CatalystTimestampConverter"),
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.sql.parquet.CatalystTimestampConverter$")
)
case v if v.startsWith("1.4") =>
Seq(
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index 42e41397bf4bc..758accf4b41eb 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -135,8 +135,9 @@ class LogisticRegressionModel(LinearClassificationModel):
1
>>> sameModel.predict(SparseVector(2, {0: 1.0}))
0
+ >>> from shutil import rmtree
>>> try:
- ... os.removedirs(path)
+ ... rmtree(path)
... except:
... pass
>>> multi_class_data = [
@@ -387,8 +388,9 @@ class SVMModel(LinearClassificationModel):
1
>>> sameModel.predict(SparseVector(2, {0: -1.0}))
0
+ >>> from shutil import rmtree
>>> try:
- ... os.removedirs(path)
+ ... rmtree(path)
... except:
... pass
"""
@@ -515,8 +517,9 @@ class NaiveBayesModel(Saveable, Loader):
>>> sameModel = NaiveBayesModel.load(sc, path)
>>> sameModel.predict(SparseVector(2, {0: 1.0})) == model.predict(SparseVector(2, {0: 1.0}))
True
+ >>> from shutil import rmtree
>>> try:
- ... os.removedirs(path)
+ ... rmtree(path)
... except OSError:
... pass
"""
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index c38229864d3b4..e6ef72942ce77 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -79,8 +79,9 @@ class KMeansModel(Saveable, Loader):
>>> sameModel = KMeansModel.load(sc, path)
>>> sameModel.predict(sparse_data[0]) == model.predict(sparse_data[0])
True
+ >>> from shutil import rmtree
>>> try:
- ... os.removedirs(path)
+ ... rmtree(path)
... except OSError:
... pass
"""
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 9c4647ddfdcfd..506ca2151cce7 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -106,8 +106,9 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
0.4...
>>> sameModel.predictAll(testset).collect()
[Rating(...
+ >>> from shutil import rmtree
>>> try:
- ... os.removedirs(path)
+ ... rmtree(path)
... except OSError:
... pass
"""
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 0c4d7d3bbee02..5ddbbee4babdd 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -133,10 +133,11 @@ class LinearRegressionModel(LinearRegressionModelBase):
True
>>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
+ >>> from shutil import rmtree
>>> try:
- ... os.removedirs(path)
+ ... rmtree(path)
... except:
- ... pass
+ ... pass
>>> data = [
... LabeledPoint(0.0, SparseVector(1, {0: 0.0})),
... LabeledPoint(1.0, SparseVector(1, {0: 1.0})),
@@ -275,8 +276,9 @@ class LassoModel(LinearRegressionModelBase):
True
>>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
+ >>> from shutil import rmtree
>>> try:
- ... os.removedirs(path)
+ ... rmtree(path)
... except:
... pass
>>> data = [
@@ -389,8 +391,9 @@ class RidgeRegressionModel(LinearRegressionModelBase):
True
>>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
+ >>> from shutil import rmtree
>>> try:
- ... os.removedirs(path)
+ ... rmtree(path)
... except:
... pass
>>> data = [
@@ -500,8 +503,9 @@ class IsotonicRegressionModel(Saveable, Loader):
2.0
>>> sameModel.predict(5)
16.5
+ >>> from shutil import rmtree
>>> try:
- ... os.removedirs(path)
+ ... rmtree(path)
... except OSError:
... pass
"""
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 886dbcf9aa982..bb375d08a3ad6 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -24,6 +24,7 @@
import tempfile
import array as pyarray
from time import time, sleep
+from shutil import rmtree
from numpy import array, array_equal, zeros, inf, all, random
from numpy import sum as array_sum
@@ -399,7 +400,7 @@ def test_classification(self):
self.assertEqual(same_gbt_model.toDebugString(), gbt_model.toDebugString())
try:
- os.removedirs(temp_dir)
+ rmtree(temp_dir)
except OSError:
pass
@@ -463,6 +464,13 @@ def test_regression(self):
except ValueError:
self.fail()
+ # Verify that maxBins is being passed through
+ GradientBoostedTrees.trainRegressor(
+ rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=32)
+ with self.assertRaises(Exception) as cm:
+ GradientBoostedTrees.trainRegressor(
+ rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=1)
+
class StatTests(MLlibTestCase):
# SPARK-4023
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index cfcbea573fd22..372b86a7c95d9 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -299,7 +299,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees,
1 internal node + 2 leaf nodes. (default: 4)
:param maxBins: maximum number of bins used for splitting
features
- (default: 100)
+ (default: 32)
:param seed: Random seed for bootstrapping and choosing feature
subsets.
:return: RandomForestModel that can be used for prediction
@@ -377,7 +377,7 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt
1 leaf node; depth 1 means 1 internal node + 2 leaf
nodes. (default: 4)
:param maxBins: maximum number of bins used for splitting
- features (default: 100)
+ features (default: 32)
:param seed: Random seed for bootstrapping and choosing feature
subsets.
:return: RandomForestModel that can be used for prediction
@@ -435,16 +435,17 @@ class GradientBoostedTrees(object):
@classmethod
def _train(cls, data, algo, categoricalFeaturesInfo,
- loss, numIterations, learningRate, maxDepth):
+ loss, numIterations, learningRate, maxDepth, maxBins):
first = data.first()
assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint"
model = callMLlibFunc("trainGradientBoostedTreesModel", data, algo, categoricalFeaturesInfo,
- loss, numIterations, learningRate, maxDepth)
+ loss, numIterations, learningRate, maxDepth, maxBins)
return GradientBoostedTreesModel(model)
@classmethod
def trainClassifier(cls, data, categoricalFeaturesInfo,
- loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3):
+ loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3,
+ maxBins=32):
"""
Method to train a gradient-boosted trees model for
classification.
@@ -467,6 +468,8 @@ def trainClassifier(cls, data, categoricalFeaturesInfo,
:param maxDepth: Maximum depth of the tree. E.g., depth 0 means
1 leaf node; depth 1 means 1 internal node + 2 leaf
nodes. (default: 3)
+ :param maxBins: maximum number of bins used for splitting
+ features (default: 32). DecisionTree requires maxBins >= max categories.
:return: GradientBoostedTreesModel that can be used for
prediction
@@ -499,11 +502,12 @@ def trainClassifier(cls, data, categoricalFeaturesInfo,
[1.0, 0.0]
"""
return cls._train(data, "classification", categoricalFeaturesInfo,
- loss, numIterations, learningRate, maxDepth)
+ loss, numIterations, learningRate, maxDepth, maxBins)
@classmethod
def trainRegressor(cls, data, categoricalFeaturesInfo,
- loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3):
+ loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3,
+ maxBins=32):
"""
Method to train a gradient-boosted trees model for regression.
@@ -522,6 +526,8 @@ def trainRegressor(cls, data, categoricalFeaturesInfo,
contribution of each estimator. The learning rate
should be between in the interval (0, 1].
(default: 0.1)
+ :param maxBins: maximum number of bins used for splitting
+ features (default: 32). DecisionTree requires maxBins >= max categories.
:param maxDepth: Maximum depth of the tree. E.g., depth 0 means
1 leaf node; depth 1 means 1 internal node + 2 leaf
nodes. (default: 3)
@@ -556,7 +562,7 @@ def trainRegressor(cls, data, categoricalFeaturesInfo,
[1.0, 0.0]
"""
return cls._train(data, "regression", categoricalFeaturesInfo,
- loss, numIterations, learningRate, maxDepth)
+ loss, numIterations, learningRate, maxDepth, maxBins)
def _test():
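As a rough usage sketch of the new parameter (toy data and an active SparkContext `sc` are assumed), maxBins is forwarded unchanged to the Scala trainer and must cover the largest categorical arity:

    from pyspark.mllib.regression import LabeledPoint
    from pyspark.mllib.tree import GradientBoostedTrees

    # Feature 0 is categorical with 3 categories, so maxBins must be >= 3.
    data = sc.parallelize([
        LabeledPoint(0.0, [0.0]),
        LabeledPoint(1.0, [1.0]),
        LabeledPoint(2.0, [2.0]),
    ])
    model = GradientBoostedTrees.trainRegressor(
        data, categoricalFeaturesInfo={0: 3}, numIterations=4, maxBins=32)
    # maxBins=1 would be rejected, which is exactly what the new test in tests.py asserts.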
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 20c0bc93f413c..1b64be23a667e 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2198,7 +2198,7 @@ def sumApprox(self, timeout, confidence=0.95):
>>> rdd = sc.parallelize(range(1000), 10)
>>> r = sum(range(1000))
- >>> (rdd.sumApprox(1000) - r) / r < 0.05
+ >>> abs(rdd.sumApprox(1000) - r) / r < 0.05
True
"""
jrdd = self.mapPartitions(lambda it: [float(sum(it))])._to_java_object_rdd()
@@ -2215,7 +2215,7 @@ def meanApprox(self, timeout, confidence=0.95):
>>> rdd = sc.parallelize(range(1000), 10)
>>> r = sum(range(1000)) / 1000.0
- >>> (rdd.meanApprox(1000) - r) / r < 0.05
+ >>> abs(rdd.meanApprox(1000) - r) / r < 0.05
True
"""
jrdd = self.map(float)._to_java_object_rdd()
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 599c9ac5794a2..dc239226e6d3c 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -86,7 +86,8 @@ def __init__(self, sparkContext, sqlContext=None):
>>> df.registerTempTable("allTypes")
>>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
- [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
+ [Row(_c0=2, _c1=2.0, _c2=False, _c3=2, _c4=0, \
+ time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
>>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
[(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
"""
@@ -176,17 +177,17 @@ def registerFunction(self, name, f, returnType=StringType()):
>>> sqlContext.registerFunction("stringLengthString", lambda x: len(x))
>>> sqlContext.sql("SELECT stringLengthString('test')").collect()
- [Row(c0=u'4')]
+ [Row(_c0=u'4')]
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
- [Row(c0=4)]
+ [Row(_c0=4)]
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
- [Row(c0=4)]
+ [Row(_c0=4)]
"""
func = lambda _, it: map(lambda x: f(*x), it)
ser = AutoBatchedSerializer(PickleSerializer())
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index f036644acc961..1b7bc0f9a12be 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -218,7 +218,10 @@ def mode(self, saveMode):
>>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
"""
- self._jwrite = self._jwrite.mode(saveMode)
+ # On the JVM side, the default value of mode is already set to "error".
+ # So, if the given saveMode is None, we do not call the JVM-side mode method.
+ if saveMode is not None:
+ self._jwrite = self._jwrite.mode(saveMode)
return self
@since(1.4)
@@ -253,11 +256,12 @@ def partitionBy(self, *cols):
"""
if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
cols = cols[0]
- self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
+ if len(cols) > 0:
+ self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
return self
@since(1.4)
- def save(self, path=None, format=None, mode="error", **options):
+ def save(self, path=None, format=None, mode=None, partitionBy=(), **options):
"""Saves the contents of the :class:`DataFrame` to a data source.
The data source is specified by the ``format`` and a set of ``options``.
@@ -272,11 +276,12 @@ def save(self, path=None, format=None, mode="error", **options):
* ``overwrite``: Overwrite existing data.
* ``ignore``: Silently ignore this operation if data already exists.
* ``error`` (default case): Throw an exception if data already exists.
+ :param partitionBy: names of partitioning columns
:param options: all other string options
>>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
"""
- self.mode(mode).options(**options)
+ self.partitionBy(partitionBy).mode(mode).options(**options)
if format is not None:
self.format(format)
if path is None:
@@ -296,7 +301,7 @@ def insertInto(self, tableName, overwrite=False):
self._jwrite.mode("overwrite" if overwrite else "append").insertInto(tableName)
@since(1.4)
- def saveAsTable(self, name, format=None, mode="error", **options):
+ def saveAsTable(self, name, format=None, mode=None, partitionBy=(), **options):
"""Saves the content of the :class:`DataFrame` as the specified table.
In the case the table already exists, behavior of this function depends on the
@@ -312,15 +317,16 @@ def saveAsTable(self, name, format=None, mode="error", **options):
:param name: the table name
:param format: the format used to save
:param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
+ :param partitionBy: names of partitioning columns
:param options: all other string options
"""
- self.mode(mode).options(**options)
+ self.partitionBy(partitionBy).mode(mode).options(**options)
if format is not None:
self.format(format)
self._jwrite.saveAsTable(name)
@since(1.4)
- def json(self, path, mode="error"):
+ def json(self, path, mode=None):
"""Saves the content of the :class:`DataFrame` in JSON format at the specified path.
:param path: the path in any Hadoop supported file system
@@ -333,10 +339,10 @@ def json(self, path, mode="error"):
>>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
"""
- self._jwrite.mode(mode).json(path)
+ self.mode(mode)._jwrite.json(path)
@since(1.4)
- def parquet(self, path, mode="error"):
+ def parquet(self, path, mode=None, partitionBy=()):
"""Saves the content of the :class:`DataFrame` in Parquet format at the specified path.
:param path: the path in any Hadoop supported file system
@@ -346,13 +352,15 @@ def parquet(self, path, mode="error"):
* ``overwrite``: Overwrite existing data.
* ``ignore``: Silently ignore this operation if data already exists.
* ``error`` (default case): Throw an exception if data already exists.
+ :param partitionBy: names of partitioning columns
>>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data'))
"""
- self._jwrite.mode(mode).parquet(path)
+ self.partitionBy(partitionBy).mode(mode)
+ self._jwrite.parquet(path)
@since(1.4)
- def jdbc(self, url, table, mode="error", properties={}):
+ def jdbc(self, url, table, mode=None, properties={}):
"""Saves the content of the :class:`DataFrame` to a external database table via JDBC.
.. note:: Don't create too many partitions in parallel on a large cluster;\
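Taken together, the writer changes let mode and partitionBy be supplied either through the builder or as keyword arguments. A hedged sketch of both styles (an existing DataFrame `df` with a hypothetical `year` column and a scratch path are assumed):

    import os
    import tempfile

    path = os.path.join(tempfile.mkdtemp(), 'data')
    # Builder style: an explicit mode overrides the JVM-side default of "error".
    df.write.format('json').mode('overwrite').save(path)
    # Keyword style: mode=None keeps the JVM default; partitionBy is forwarded for you.
    df.write.parquet(path, mode='overwrite', partitionBy=('year',))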
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b5fbb7d098820..13f4556943ac8 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -539,6 +539,38 @@ def test_save_and_load(self):
shutil.rmtree(tmpPath)
+ def test_save_and_load_builder(self):
+ df = self.df
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ df.write.json(tmpPath)
+ actual = self.sqlCtx.read.json(tmpPath)
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+ schema = StructType([StructField("value", StringType(), True)])
+ actual = self.sqlCtx.read.json(tmpPath, schema)
+ self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
+
+ df.write.mode("overwrite").json(tmpPath)
+ actual = self.sqlCtx.read.json(tmpPath)
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+ df.write.mode("overwrite").options(noUse="this options will not be used in save.")\
+ .format("json").save(path=tmpPath)
+ actual =\
+ self.sqlCtx.read.format("json")\
+ .load(path=tmpPath, noUse="this option will not be used in load.")
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+ defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+ actual = self.sqlCtx.load(path=tmpPath)
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+ self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+
+ shutil.rmtree(tmpPath)
+
def test_help_command(self):
# Regression test for SPARK-5464
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
index f7849ebebc573..83f2a312972fb 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.expressions;
-import java.util.Arrays;
import java.util.Iterator;
import org.apache.spark.sql.catalyst.InternalRow;
@@ -142,14 +141,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey);
// Make sure that the buffer is large enough to hold the key. If it's not, grow it:
if (groupingKeySize > groupingKeyConversionScratchSpace.length) {
- // This new array will be initially zero, so there's no need to zero it out here
groupingKeyConversionScratchSpace = new byte[groupingKeySize];
- } else {
- // Zero out the buffer that's used to hold the current row. This is necessary in order
- // to ensure that rows hash properly, since garbage data from the previous row could
- // otherwise end up as padding in this row. As a performance optimization, we only zero out
- // the portion of the buffer that we'll actually write to.
- Arrays.fill(groupingKeyConversionScratchSpace, 0, groupingKeySize, (byte) 0);
}
final int actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow(
groupingKey,
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index ed04d2e50ec84..bb2f2079b40f0 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -47,7 +47,8 @@
* In the `values` region, we store one 8-byte word per field. For fields that hold fixed-length
* primitive types, such as long, double, or int, we store the value directly in the word. For
* fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the
- * base address of the row) that points to the beginning of the variable-length field.
+ * base address of the row) that points to the beginning of the variable-length field, along
+ * with its length (the offset and the length are packed into a single long).
*
* Instances of `UnsafeRow` act as pointers to row data stored in this format.
*/
@@ -92,6 +93,7 @@ public static int calculateBitSetWidthInBytes(int numFields) {
*/
public static final Set<DataType> readableFieldTypes;
+ // TODO: support DecimalType
static {
settableFieldTypes = Collections.unmodifiableSet(
new HashSet<DataType>(
@@ -111,7 +113,8 @@ public static int calculateBitSetWidthInBytes(int numFields) {
// We support get() on a superset of the types for which we support set():
final Set<DataType> _readableFieldTypes = new HashSet<DataType>(
Arrays.asList(new DataType[]{
- StringType
+ StringType,
+ BinaryType
}));
_readableFieldTypes.addAll(settableFieldTypes);
readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
@@ -221,11 +224,6 @@ public void setFloat(int ordinal, float value) {
PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value);
}
- @Override
- public void setString(int ordinal, String value) {
- throw new UnsupportedOperationException();
- }
-
@Override
public int size() {
return numFields;
@@ -249,6 +247,8 @@ public Object get(int i) {
return null;
} else if (dataType == StringType) {
return getUTF8String(i);
+ } else if (dataType == BinaryType) {
+ return getBinary(i);
} else {
throw new UnsupportedOperationException();
}
@@ -311,19 +311,23 @@ public double getDouble(int i) {
}
public UTF8String getUTF8String(int i) {
+ return UTF8String.fromBytes(getBinary(i));
+ }
+
+ public byte[] getBinary(int i) {
assertIndexIsValid(i);
- final long offsetToStringSize = getLong(i);
- final int stringSizeInBytes =
- (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offsetToStringSize);
- final byte[] strBytes = new byte[stringSizeInBytes];
+ final long offsetAndSize = getLong(i);
+ final int offset = (int)(offsetAndSize >> 32);
+ final int size = (int)(offsetAndSize & ((1L << 32) - 1));
+ final byte[] bytes = new byte[size];
PlatformDependent.copyMemory(
baseObject,
- baseOffset + offsetToStringSize + 8, // The `+ 8` is to skip past the size to get the data
- strBytes,
+ baseOffset + offset,
+ bytes,
PlatformDependent.BYTE_ARRAY_OFFSET,
- stringSizeInBytes
+ size
);
- return UTF8String.fromBytes(strBytes);
+ return bytes;
}
@Override
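The new layout packs a variable-length field's byte offset and length into one 8-byte word. A plain-Python illustration of the packing stored by the writer and unpacked by getBinary() (the values are made up):

    offset, size = 64, 13                      # byte offset from the row base, payload length
    offset_and_size = (offset << 32) | size    # the long stored in the field's word

    assert offset_and_size >> 32 == offset             # recovered as in getBinary()
    assert offset_and_size & ((1 << 32) - 1) == size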
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 620e8de83a96c..429fc4077be9a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -19,15 +19,15 @@ package org.apache.spark.sql.catalyst
import java.lang.{Iterable => JavaIterable}
import java.math.{BigDecimal => JavaBigDecimal}
-import java.sql.{Timestamp, Date}
+import java.sql.{Date, Timestamp}
import java.util.{Map => JavaMap}
import javax.annotation.Nullable
import scala.collection.mutable.HashMap
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -272,18 +272,18 @@ object CatalystTypeConverters {
}
private object DateConverter extends CatalystTypeConverter[Date, Date, Any] {
- override def toCatalystImpl(scalaValue: Date): Int = DateUtils.fromJavaDate(scalaValue)
+ override def toCatalystImpl(scalaValue: Date): Int = DateTimeUtils.fromJavaDate(scalaValue)
override def toScala(catalystValue: Any): Date =
- if (catalystValue == null) null else DateUtils.toJavaDate(catalystValue.asInstanceOf[Int])
+ if (catalystValue == null) null else DateTimeUtils.toJavaDate(catalystValue.asInstanceOf[Int])
override def toScalaImpl(row: InternalRow, column: Int): Date = toScala(row.getInt(column))
}
private object TimestampConverter extends CatalystTypeConverter[Timestamp, Timestamp, Any] {
override def toCatalystImpl(scalaValue: Timestamp): Long =
- DateUtils.fromJavaTimestamp(scalaValue)
+ DateTimeUtils.fromJavaTimestamp(scalaValue)
override def toScala(catalystValue: Any): Timestamp =
if (catalystValue == null) null
- else DateUtils.toJavaTimestamp(catalystValue.asInstanceOf[Long])
+ else DateTimeUtils.toJavaTimestamp(catalystValue.asInstanceOf[Long])
override def toScalaImpl(row: InternalRow, column: Int): Timestamp =
toScala(row.getLong(column))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index da3a717f90058..79f526e823cd4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -99,13 +99,6 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
protected val WHERE = Keyword("WHERE")
protected val WITH = Keyword("WITH")
- protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = {
- exprs.zipWithIndex.map {
- case (ne: NamedExpression, _) => ne
- case (e, i) => Alias(e, s"c$i")()
- }
- }
-
protected lazy val start: Parser[LogicalPlan] =
start1 | insert | cte
@@ -130,8 +123,8 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
val base = r.getOrElse(OneRowRelation)
val withFilter = f.map(Filter(_, base)).getOrElse(base)
val withProjection = g
- .map(Aggregate(_, assignAliases(p), withFilter))
- .getOrElse(Project(assignAliases(p), withFilter))
+ .map(Aggregate(_, p.map(UnresolvedAlias(_)), withFilter))
+ .getOrElse(Project(p.map(UnresolvedAlias(_)), withFilter))
val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection)
val withHaving = h.map(Filter(_, withDistinct)).getOrElse(withDistinct)
val withOrder = o.map(_(withHaving)).getOrElse(withHaving)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 21b05760256b4..6311784422a91 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.analysis
-import scala.collection.mutable.ArrayBuffer
-
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
import org.apache.spark.sql.catalyst.expressions._
@@ -74,10 +72,10 @@ class Analyzer(
ResolveSortReferences ::
ResolveGenerate ::
ResolveFunctions ::
+ ResolveAliases ::
ExtractWindowExpressions ::
GlobalAggregates ::
UnresolvedHavingClauseAttributes ::
- TrimGroupingAliases ::
typeCoercionRules ++
extendedResolutionRules : _*)
)
@@ -132,12 +130,38 @@ class Analyzer(
}
/**
- * Removes no-op Alias expressions from the plan.
+ * Replaces [[UnresolvedAlias]]s with concrete aliases.
*/
- object TrimGroupingAliases extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case Aggregate(groups, aggs, child) =>
- Aggregate(groups.map(_.transform { case Alias(c, _) => c }), aggs, child)
+ object ResolveAliases extends Rule[LogicalPlan] {
+ private def assignAliases(exprs: Seq[NamedExpression]) = {
+ // `UnresolvedAlias`s appear only at the root of an expression tree, so we don't need
+ // to transform down the whole tree.
+ exprs.zipWithIndex.map {
+ case (u @ UnresolvedAlias(child), i) =>
+ child match {
+ case _: UnresolvedAttribute => u
+ case ne: NamedExpression => ne
+ case ev: ExtractValueWithStruct => Alias(ev, ev.field.name)()
+ case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil)
+ case e if !e.resolved => u
+ case other => Alias(other, s"_c$i")()
+ }
+ case (other, _) => other
+ }
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ case Aggregate(groups, aggs, child)
+ if child.resolved && aggs.exists(_.isInstanceOf[UnresolvedAlias]) =>
+ Aggregate(groups, assignAliases(aggs), child)
+
+ case g: GroupingAnalytics
+ if g.child.resolved && g.aggregations.exists(_.isInstanceOf[UnresolvedAlias]) =>
+ g.withNewAggs(assignAliases(g.aggregations))
+
+ case Project(projectList, child)
+ if child.resolved && projectList.exists(_.isInstanceOf[UnresolvedAlias]) =>
+ Project(assignAliases(projectList), child)
}
}
@@ -228,7 +252,7 @@ class Analyzer(
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case i@InsertIntoTable(u: UnresolvedRelation, _, _, _, _) =>
+ case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _, _) =>
i.copy(table = EliminateSubQueries(getTable(u)))
case u: UnresolvedRelation =>
getTable(u)
@@ -248,24 +272,24 @@ class Analyzer(
Project(
projectList.flatMap {
case s: Star => s.expand(child.output, resolver)
- case Alias(f @ UnresolvedFunction(_, args), name) if containsStar(args) =>
+ case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
}
- Alias(child = f.copy(children = expandedArgs), name)() :: Nil
- case Alias(c @ CreateArray(args), name) if containsStar(args) =>
+ UnresolvedAlias(child = f.copy(children = expandedArgs)) :: Nil
+ case UnresolvedAlias(c @ CreateArray(args)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
}
- Alias(c.copy(children = expandedArgs), name)() :: Nil
- case Alias(c @ CreateStruct(args), name) if containsStar(args) =>
+ UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
+ case UnresolvedAlias(c @ CreateStruct(args)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
}
- Alias(c.copy(children = expandedArgs), name)() :: Nil
+ UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
case o => o :: Nil
},
child)
@@ -353,7 +377,9 @@ class Analyzer(
case u @ UnresolvedAttribute(nameParts) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result =
- withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
+ withPosition(u) {
+ q.resolveChildren(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
+ }
logDebug(s"Resolving $u to $result")
result
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
@@ -379,6 +405,11 @@ class Analyzer(
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
}
+ private def trimUnresolvedAlias(ne: NamedExpression) = ne match {
+ case UnresolvedAlias(child) => child
+ case other => other
+ }
+
private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = {
ordering.map { order =>
// Resolve SortOrder in one round.
@@ -388,7 +419,7 @@ class Analyzer(
try {
val newOrder = order transformUp {
case u @ UnresolvedAttribute(nameParts) =>
- plan.resolve(nameParts, resolver).getOrElse(u)
+ plan.resolve(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
case UnresolvedExtractValue(child, fieldName) if child.resolved =>
ExtractValue(child, fieldName, resolver)
}
@@ -586,18 +617,6 @@ class Analyzer(
/** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */
private object AliasedGenerator {
def unapply(e: Expression): Option[(Generator, Seq[String])] = e match {
- case Alias(g: Generator, name)
- if g.resolved &&
- g.elementTypes.size > 1 &&
- java.util.regex.Pattern.matches("_c[0-9]+", name) => {
- // Assume the default name given by parser is "_c[0-9]+",
- // TODO in long term, move the naming logic from Parser to Analyzer.
- // In projection, Parser gave default name for TGF as does for normal UDF,
- // but the TGF probably have multiple output columns/names.
- // e.g. SELECT explode(map(key, value)) FROM src;
- // Let's simply ignore the default given name for this case.
- Some((g, Nil))
- }
case Alias(g: Generator, name) if g.resolved && g.elementTypes.size > 1 =>
// If not given the default names, and the TGF with multiple output columns
failAnalysis(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 7fabd2bfc80ab..c5a1437be6d05 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -95,14 +95,7 @@ trait CheckAnalysis {
case e => e.children.foreach(checkValidAggregateExpression)
}
- val cleaned = aggregateExprs.map(_.transform {
- // Should trim aliases around `GetField`s. These aliases are introduced while
- // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
- // (Should we just turn `GetField` into a `NamedExpression`?)
- case Alias(g, _) => g
- })
-
- cleaned.foreach(checkValidAggregateExpression)
+ aggregateExprs.foreach(checkValidAggregateExpression)
case _ => // Fallbacks to the following checks
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index c9d91425788a8..ae3adbab05108 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.analysis
-import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.{errors, trees}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
@@ -206,3 +205,22 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression)
override def toString: String = s"$child[$extraction]"
}
+
+/**
+ * Holds the expression that has yet to be aliased.
+ */
+case class UnresolvedAlias(child: Expression) extends NamedExpression
+ with trees.UnaryNode[Expression] {
+
+ override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
+ override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
+ override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+ override def name: String = throw new UnresolvedException(this, "name")
+
+ override lazy val resolved = false
+
+ override def eval(input: InternalRow = null): Any =
+ throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index ad920f287820c..d271434a306dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -24,7 +24,7 @@ import java.text.{DateFormat, SimpleDateFormat}
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
-import org.apache.spark.sql.catalyst.util.DateUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -115,9 +115,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// UDFToString
private[this] def castToString(from: DataType): Any => Any = from match {
case BinaryType => buildCast[Array[Byte]](_, UTF8String.fromBytes)
- case DateType => buildCast[Int](_, d => UTF8String.fromString(DateUtils.toString(d)))
+ case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.toString(d)))
case TimestampType => buildCast[Long](_,
- t => UTF8String.fromString(timestampToString(DateUtils.toJavaTimestamp(t))))
+ t => UTF8String.fromString(timestampToString(DateTimeUtils.toJavaTimestamp(t))))
case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString))
}
@@ -162,7 +162,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
if (periodIdx != -1 && n.length() - periodIdx > 9) {
n = n.substring(0, periodIdx + 10)
}
- try DateUtils.fromJavaTimestamp(Timestamp.valueOf(n))
+ try DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(n))
catch { case _: java.lang.IllegalArgumentException => null }
})
case BooleanType =>
@@ -176,7 +176,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case ByteType =>
buildCast[Byte](_, b => longToTimestamp(b.toLong))
case DateType =>
- buildCast[Int](_, d => DateUtils.toMillisSinceEpoch(d) * 10000)
+ buildCast[Int](_, d => DateTimeUtils.toMillisSinceEpoch(d) * 10000)
// TimestampWritable.decimalToTimestamp
case DecimalType() =>
buildCast[Decimal](_, d => decimalToTimestamp(d))
@@ -225,13 +225,13 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
private[this] def castToDate(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s =>
- try DateUtils.fromJavaDate(Date.valueOf(s.toString))
+ try DateTimeUtils.fromJavaDate(Date.valueOf(s.toString))
catch { case _: java.lang.IllegalArgumentException => null }
)
case TimestampType =>
// throw valid precision more than seconds, according to Hive.
// Timestamp.nanos is in 0 to 999,999,999, no more than a second.
- buildCast[Long](_, t => DateUtils.millisToDays(t / 10000L))
+ buildCast[Long](_, t => DateTimeUtils.millisToDays(t / 10000L))
// Hive throws this exception as a Semantic Exception
// It is never possible to compare result when hive return with exception,
// so we can return null
@@ -442,7 +442,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case (DateType, StringType) =>
defineCodeGen(ctx, ev, c =>
s"""${ctx.stringType}.fromString(
- org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""")
+ org.apache.spark.sql.catalyst.util.DateTimeUtils.toString($c))""")
// Special handling required for timestamps in hive test cases since the toString function
// does not match the expected output.
case (TimestampType, StringType) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
index 4aaabff15b6ee..4d6c1c265150d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.Map
-import org.apache.spark.sql.{catalyst, AnalysisException}
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.types._
@@ -41,16 +41,22 @@ object ExtractValue {
resolver: Resolver): ExtractValue = {
(child.dataType, extraction) match {
- case (StructType(fields), Literal(fieldName, StringType)) =>
- val ordinal = findField(fields, fieldName.toString, resolver)
- GetStructField(child, fields(ordinal), ordinal)
- case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) =>
- val ordinal = findField(fields, fieldName.toString, resolver)
- GetArrayStructFields(child, fields(ordinal), ordinal, containsNull)
+ case (StructType(fields), NonNullLiteral(v, StringType)) =>
+ val fieldName = v.toString
+ val ordinal = findField(fields, fieldName, resolver)
+ GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal)
+
+ case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) =>
+ val fieldName = v.toString
+ val ordinal = findField(fields, fieldName, resolver)
+ GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, containsNull)
+
case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] =>
GetArrayItem(child, extraction)
+
case (_: MapType, _) =>
GetMapValue(child, extraction)
+
case (otherType, _) =>
val errorMsg = otherType match {
case StructType(_) | ArrayType(StructType(_), _) =>
@@ -94,16 +100,21 @@ trait ExtractValue extends UnaryExpression {
self: Product =>
}
+abstract class ExtractValueWithStruct extends ExtractValue {
+ self: Product =>
+
+ def field: StructField
+ override def toString: String = s"$child.${field.name}"
+}
+
/**
* Returns the value of fields in the Struct `child`.
*/
case class GetStructField(child: Expression, field: StructField, ordinal: Int)
- extends ExtractValue {
+ extends ExtractValueWithStruct {
override def dataType: DataType = field.dataType
override def nullable: Boolean = child.nullable || field.nullable
- override def foldable: Boolean = child.foldable
- override def toString: String = s"$child.${field.name}"
override def eval(input: InternalRow): Any = {
val baseValue = child.eval(input).asInstanceOf[InternalRow]
@@ -118,12 +129,9 @@ case class GetArrayStructFields(
child: Expression,
field: StructField,
ordinal: Int,
- containsNull: Boolean) extends ExtractValue {
+ containsNull: Boolean) extends ExtractValueWithStruct {
override def dataType: DataType = ArrayType(field.dataType, containsNull)
- override def nullable: Boolean = child.nullable
- override def foldable: Boolean = child.foldable
- override def toString: String = s"$child.${field.name}"
override def eval(input: InternalRow): Any = {
val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]]
@@ -178,7 +186,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
// TODO: consider using Array[_] for ArrayType child to avoid
// boxing of primitives
val baseValue = value.asInstanceOf[Seq[_]]
- val index = ordinal.asInstanceOf[Int]
+ val index = ordinal.asInstanceOf[Number].intValue()
if (index >= baseValue.size || index < 0) {
null
} else {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
index 72f740ecaead3..89adaf053b1a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.util.DateUtils
-import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.array.ByteArrayMethods
@@ -72,6 +70,19 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
*/
def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long): Int = {
unsafeRow.pointTo(baseObject, baseOffset, writers.length, null)
+
+ if (writers.length > 0) {
+ // zero-out the bitset
+ var n = writers.length / 64
+ while (n >= 0) {
+ PlatformDependent.UNSAFE.putLong(
+ unsafeRow.getBaseObject,
+ unsafeRow.getBaseOffset + n * 8,
+ 0L)
+ n -= 1
+ }
+ }
+
var fieldNumber = 0
var appendCursor: Int = fixedLengthSize
while (fieldNumber < writers.length) {
@@ -122,6 +133,7 @@ private object UnsafeColumnWriter {
case FloatType => FloatUnsafeColumnWriter
case DoubleType => DoubleUnsafeColumnWriter
case StringType => StringUnsafeColumnWriter
+ case BinaryType => BinaryUnsafeColumnWriter
case DateType => IntUnsafeColumnWriter
case TimestampType => LongUnsafeColumnWriter
case t =>
@@ -141,6 +153,7 @@ private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
+private object BinaryUnsafeColumnWriter extends BinaryUnsafeColumnWriter
private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter {
// Primitives don't write to the variable-length region:
@@ -235,10 +248,13 @@ private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWr
}
}
-private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter {
+private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter {
+
+ def getBytes(source: InternalRow, column: Int): Array[Byte]
+
def getSize(source: InternalRow, column: Int): Int = {
- val numBytes = source.get(column).asInstanceOf[UTF8String].getBytes.length
- 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
+ val numBytes = getBytes(source, column).length
+ ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
}
override def write(
@@ -246,19 +262,33 @@ private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter {
target: UnsafeRow,
column: Int,
appendCursor: Int): Int = {
- val value = source.get(column).asInstanceOf[UTF8String]
- val baseObject = target.getBaseObject
- val baseOffset = target.getBaseOffset
- val numBytes = value.getBytes.length
- PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes)
+ val offset = target.getBaseOffset + appendCursor
+ val bytes = getBytes(source, column)
+ val numBytes = bytes.length
+ if ((numBytes & 0x07) > 0) {
+ // zero-out the padding bytes
+ PlatformDependent.UNSAFE.putLong(target.getBaseObject, offset + ((numBytes >> 3) << 3), 0L)
+ }
PlatformDependent.copyMemory(
- value.getBytes,
+ bytes,
PlatformDependent.BYTE_ARRAY_OFFSET,
- baseObject,
- baseOffset + appendCursor + 8,
+ target.getBaseObject,
+ offset,
numBytes
)
- target.setLong(column, appendCursor)
- 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
+ target.setLong(column, (appendCursor.toLong << 32L) | numBytes.toLong)
+ ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
+ }
+}
+
+private class StringUnsafeColumnWriter private() extends BytesUnsafeColumnWriter {
+ def getBytes(source: InternalRow, column: Int): Array[Byte] = {
+ source.getAs[UTF8String](column).getBytes
+ }
+}
+
+private class BinaryUnsafeColumnWriter private() extends BytesUnsafeColumnWriter {
+ def getBytes(source: InternalRow, column: Int): Array[Byte] = {
+ source.getAs[Array[Byte]](column)
}
}
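The bytes writers above no longer prepend an 8-byte length word; they reserve only the payload rounded up to a whole word and zero the trailing padding before the copy. A quick plain-Python sketch of that rounding (mirroring roundNumberOfBytesToNearestWord as used here):

    def round_to_word(num_bytes):
        # Round up to the next multiple of 8 so every value ends on a word boundary.
        return (num_bytes + 7) & ~7

    assert round_to_word(13) == 16   # 3 padding bytes are zeroed before copying
    assert round_to_word(16) == 16   # already aligned, nothing to pad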
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
similarity index 98%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 72fdcebb4cbc8..e0bf07ed182f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst
import org.apache.spark.sql.types._
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 6c86a47ba200c..479224af5627a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
-import org.apache.spark.sql.catalyst.util.DateUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -39,8 +39,8 @@ object Literal {
case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited)
case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited)
case d: Decimal => Literal(d, DecimalType.Unlimited)
- case t: Timestamp => Literal(DateUtils.fromJavaTimestamp(t), TimestampType)
- case d: Date => Literal(DateUtils.fromJavaDate(d), DateType)
+ case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
+ case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
case a: Array[Byte] => Literal(a, BinaryType)
case null => Literal(null, NullType)
case _ =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 3b6f8bfd9ff9b..179a348d5baac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -156,12 +156,8 @@ object PartialAggregation {
partialEvaluations(new TreeNodeRef(e)).finalEvaluation
case e: Expression =>
- // Should trim aliases around `GetField`s. These aliases are introduced while
- // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
- // (Should we just turn `GetField` into a `NamedExpression`?)
- val trimmed = e.transform { case Alias(g: ExtractValue, _) => g }
namedGroupingExpressions.collectFirst {
- case (expr, ne) if expr semanticEquals trimmed => ne.toAttribute
+ case (expr, ne) if expr semanticEquals e => ne.toAttribute
}.getOrElse(e)
}).asInstanceOf[Seq[NamedExpression]]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index a853e27c1212d..b009a200b920f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.Logging
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, EliminateSubQueries, Resolver}
+import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.trees.TreeNode
@@ -252,14 +252,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
// The foldLeft adds ExtractValues for every remaining part of the identifier,
- // and aliases it with the last part of the identifier.
+ // and wraps it in an UnresolvedAlias, which will be removed later.
// For example, consider "a.b.c", where "a" is resolved to an existing attribute.
- // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias
- // the final expression as "c".
+ // Then this will add ExtractValue("c", ExtractValue("b", a)), and wrap it as
+ // UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))).
val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
ExtractValue(expr, Literal(fieldName), resolver))
- val aliasName = nestedFields.last
- Some(Alias(fieldExprs, aliasName)())
+ Some(UnresolvedAlias(fieldExprs))
// No matches.
case Seq() =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 963c7820914f3..7814e51628db6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -130,6 +130,14 @@ case class Join(
}
}
+/**
+ * A hint for the optimizer that we should broadcast the `child` if used in a join operator.
+ */
+case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
+ override def output: Seq[Attribute] = child.output
+}
+
+
case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
override def output: Seq[Attribute] = left.output
}
@@ -242,6 +250,8 @@ trait GroupingAnalytics extends UnaryNode {
def aggregations: Seq[NamedExpression]
override def output: Seq[Attribute] = aggregations.map(_.toAttribute)
+
+ def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics
}
/**
@@ -266,7 +276,11 @@ case class GroupingSets(
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
- gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
+ gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
+
+ def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
+ this.copy(aggregations = aggs)
+}
/**
* Cube is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets,
@@ -284,7 +298,11 @@ case class Cube(
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
- gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
+ gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
+
+ def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
+ this.copy(aggregations = aggs)
+}
/**
* Rollup is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets,
@@ -303,7 +321,11 @@ case class Rollup(
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
- gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
+ gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
+
+ def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
+ this.copy(aggregations = aggs)
+}
case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
similarity index 68%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index 5cadc141af1df..ff79884a44d00 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -17,18 +17,28 @@
package org.apache.spark.sql.catalyst.util
-import java.sql.{Timestamp, Date}
+import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import java.util.{Calendar, TimeZone}
import org.apache.spark.sql.catalyst.expressions.Cast
/**
- * Helper function to convert between Int value of days since 1970-01-01 and java.sql.Date
+ * Helper functions for converting between internal and external date and time representations.
+ * Dates are exposed externally as java.sql.Date and are represented internally as the number of
+ * days since the Unix epoch (1970-01-01). Timestamps are exposed externally as java.sql.Timestamp
+ * and are stored internally as longs, which are capable of storing timestamps with 100-nanosecond
+ * precision.
*/
-object DateUtils {
- private val MILLIS_PER_DAY = 86400000
- private val HUNDRED_NANOS_PER_SECOND = 10000000L
+object DateTimeUtils {
+ final val MILLIS_PER_DAY = SECONDS_PER_DAY * 1000L
+
+ // see http://stackoverflow.com/questions/466321/convert-unix-timestamp-to-julian
+ final val JULIAN_DAY_OF_EPOCH = 2440587 // and .5
+ final val SECONDS_PER_DAY = 60 * 60 * 24L
+ final val HUNDRED_NANOS_PER_SECOND = 1000L * 1000L * 10L
+ final val NANOS_PER_SECOND = HUNDRED_NANOS_PER_SECOND * 100
+
// Java TimeZone has no mention of thread safety. Use thread local instance to be safe.
private val LOCAL_TIMEZONE = new ThreadLocal[TimeZone] {
@@ -117,4 +127,25 @@ object DateUtils {
0L
}
}
+
+ /**
+ * Returns the number of 100ns (hundreds of nanoseconds) since the epoch, given a Julian day
+ * and the nanoseconds within that day.
+ */
+ def fromJulianDay(day: Int, nanoseconds: Long): Long = {
+ // use Long to avoid rounding errors
+ val seconds = (day - JULIAN_DAY_OF_EPOCH).toLong * SECONDS_PER_DAY - SECONDS_PER_DAY / 2
+ seconds * HUNDRED_NANOS_PER_SECOND + nanoseconds / 100L
+ }
+
+ /**
+   * Returns the Julian day and the nanoseconds within that day, given the number of 100ns
+   * intervals (hundreds of nanoseconds) since the epoch.
+ */
+ def toJulianDay(num100ns: Long): (Int, Long) = {
+ val seconds = num100ns / HUNDRED_NANOS_PER_SECOND + SECONDS_PER_DAY / 2
+ val day = seconds / SECONDS_PER_DAY + JULIAN_DAY_OF_EPOCH
+ val secondsInDay = seconds % SECONDS_PER_DAY
+ val nanos = (num100ns % HUNDRED_NANOS_PER_SECOND) * 100L
+ (day.toInt, secondsInDay * NANOS_PER_SECOND + nanos)
+ }
}
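
As a quick sanity check of the conversion above (a sketch only, assuming the DateTimeUtils object
from this patch is on the classpath): the Unix epoch starts at noon of Julian day 2440587, so
converting the internal timestamp 0L lands exactly half a day of nanoseconds into that Julian day,
and the conversion round-trips.

  import org.apache.spark.sql.catalyst.util.DateTimeUtils

  val (day, nanos) = DateTimeUtils.toJulianDay(0L)      // internal timestamp of the epoch
  assert(day == DateTimeUtils.JULIAN_DAY_OF_EPOCH)      // 2440587
  assert(nanos == DateTimeUtils.SECONDS_PER_DAY / 2 * DateTimeUtils.NANOS_PER_SECOND)
  assert(DateTimeUtils.fromJulianDay(day, nanos) == 0L) // round-trips back to 0
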
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index a85af9e04aedb..bd9823bc05424 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.types
+import java.math.{MathContext, RoundingMode}
+
import org.apache.spark.annotation.DeveloperApi
/**
@@ -137,9 +139,9 @@ final class Decimal extends Ordered[Decimal] with Serializable {
def toBigDecimal: BigDecimal = {
if (decimalVal.ne(null)) {
- decimalVal
+ decimalVal(MathContext.UNLIMITED)
} else {
- BigDecimal(longVal, _scale)
+ BigDecimal(longVal, _scale)(MathContext.UNLIMITED)
}
}
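
The MathContext.UNLIMITED argument matters because Scala's BigDecimal factories default to
MathContext.DECIMAL128, which keeps only 34 significant digits, so the 38-digit product of two
long-backed decimals would silently be rounded. A minimal sketch of the difference (illustrative
only, not Spark API):

  import java.math.MathContext

  val d = BigDecimal(Long.MaxValue)      // carries the default DECIMAL128 context (34 digits)
  val exact = d(MathContext.UNLIMITED)   // same value, re-wrapped with an unlimited context
  println((d * d).precision)             // 34 -- the exact 38-digit product was rounded
  println((exact * exact).precision)     // 38 -- full precision, as the DecimalSuite test expects
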
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index e407f6f166e86..f3809be722a84 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.sql.{Timestamp, Date}
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.util.DateUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
/**
@@ -156,7 +156,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast(cast(sd, DateType), StringType), sd)
checkEvaluation(cast(cast(d, StringType), DateType), 0)
checkEvaluation(cast(cast(nts, TimestampType), StringType), nts)
- checkEvaluation(cast(cast(ts, StringType), TimestampType), DateUtils.fromJavaTimestamp(ts))
+ checkEvaluation(cast(cast(ts, StringType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts))
// all convert to string type to check
checkEvaluation(cast(cast(cast(nts, TimestampType), DateType), StringType), sd)
@@ -301,9 +301,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast(ts, LongType), 15.toLong)
checkEvaluation(cast(ts, FloatType), 15.002f)
checkEvaluation(cast(ts, DoubleType), 15.002)
- checkEvaluation(cast(cast(tss, ShortType), TimestampType), DateUtils.fromJavaTimestamp(ts))
- checkEvaluation(cast(cast(tss, IntegerType), TimestampType), DateUtils.fromJavaTimestamp(ts))
- checkEvaluation(cast(cast(tss, LongType), TimestampType), DateUtils.fromJavaTimestamp(ts))
+ checkEvaluation(cast(cast(tss, ShortType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts))
+ checkEvaluation(cast(cast(tss, IntegerType), TimestampType),
+ DateTimeUtils.fromJavaTimestamp(ts))
+ checkEvaluation(cast(cast(tss, LongType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts))
checkEvaluation(
cast(cast(millis.toFloat / 1000, TimestampType), FloatType),
millis.toFloat / 1000)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 2b0f4618b21e0..b80911e7257fc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -26,6 +26,26 @@ import org.apache.spark.unsafe.types.UTF8String
class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
+ /**
+ * Runs through the testFunc for all integral data types.
+ *
+ * @param testFunc a test function that accepts a conversion function to convert an integer
+ * into another data type.
+ */
+ private def testIntegralDataTypes(testFunc: (Int => Any) => Unit): Unit = {
+ testFunc(_.toByte)
+ testFunc(_.toShort)
+ testFunc(identity)
+ testFunc(_.toLong)
+ }
+
+ test("GetArrayItem") {
+ testIntegralDataTypes { convert =>
+ val array = Literal.create(Seq("a", "b"), ArrayType(StringType))
+ checkEvaluation(GetArrayItem(array, Literal(convert(1))), "b")
+ }
+ }
+
test("CreateStruct") {
val row = InternalRow(1, 2, 3)
val c1 = 'a.int.at(0).as("a")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index b6261bfba0786..72fec3b86e5e4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -23,7 +23,7 @@ import scala.collection.immutable.HashSet
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.util.DateUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.{IntegerType, BooleanType}
@@ -167,8 +167,8 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Literal(true) <=> Literal.create(null, BooleanType), false, row)
checkEvaluation(Literal.create(null, BooleanType) <=> Literal(true), false, row)
- val d1 = DateUtils.fromJavaDate(Date.valueOf("1970-01-01"))
- val d2 = DateUtils.fromJavaDate(Date.valueOf("1970-01-02"))
+ val d1 = DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-01"))
+ val d2 = DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-02"))
checkEvaluation(Literal(d1) < Literal(d2), true)
val ts1 = new Timestamp(12)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 721ef8a22608c..c0675f4f4dff6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -23,8 +23,8 @@ import java.util.Arrays
import org.scalatest.Matchers
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
-import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.array.ByteArrayMethods
@@ -52,19 +52,19 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
unsafeRow.getInt(2) should be (2)
}
- test("basic conversion with primitive and string types") {
- val fieldTypes: Array[DataType] = Array(LongType, StringType, StringType)
+ test("basic conversion with primitive, string and binary types") {
+ val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType)
val converter = new UnsafeRowConverter(fieldTypes)
val row = new SpecificMutableRow(fieldTypes)
row.setLong(0, 0)
row.setString(1, "Hello")
- row.setString(2, "World")
+ row.update(2, "World".getBytes)
val sizeRequired: Int = converter.getSizeRequirement(row)
sizeRequired should be (8 + (8 * 3) +
- ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8) +
- ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length + 8))
+ ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) +
+ ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length))
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
numBytesWritten should be (sizeRequired)
@@ -73,7 +73,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
unsafeRow.getLong(0) should be (0)
unsafeRow.getString(1) should be ("Hello")
- unsafeRow.getString(2) should be ("World")
+ unsafeRow.getBinary(2) should be ("World".getBytes)
}
test("basic conversion with primitive, string, date and timestamp types") {
@@ -83,12 +83,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val row = new SpecificMutableRow(fieldTypes)
row.setLong(0, 0)
row.setString(1, "Hello")
- row.update(2, DateUtils.fromJavaDate(Date.valueOf("1970-01-01")))
- row.update(3, DateUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25")))
+ row.update(2, DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-01")))
+ row.update(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25")))
val sizeRequired: Int = converter.getSizeRequirement(row)
sizeRequired should be (8 + (8 * 4) +
- ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8))
+ ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length))
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
numBytesWritten should be (sizeRequired)
@@ -98,9 +98,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
unsafeRow.getLong(0) should be (0)
unsafeRow.getString(1) should be ("Hello")
// Date is represented as Int in unsafeRow
- DateUtils.toJavaDate(unsafeRow.getInt(2)) should be (Date.valueOf("1970-01-01"))
+ DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) should be (Date.valueOf("1970-01-01"))
// Timestamp is represented as Long in unsafeRow
- DateUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be
+ DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be
(Timestamp.valueOf("2015-05-08 08:10:25"))
}
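
Working through the adjusted size expectations above (a sketch of the arithmetic only, assuming
roundNumberOfBytesToNearestWord rounds up to the next multiple of 8): the `+ 8` previously added
per variable-length field is gone, so each string or binary value now contributes only its
word-aligned payload.

  def roundToWord(numBytes: Int): Int = ((numBytes + 7) / 8) * 8  // assumed rounding behaviour

  // "primitive, string and binary types": null bit set + 3 fixed-width slots + 2 payloads
  val size1 = 8 + 8 * 3 +
    roundToWord("Hello".getBytes.length) + roundToWord("World".getBytes.length)
  // = 8 + 24 + 8 + 8 = 48 bytes

  // "primitive, string, date and timestamp types": date and timestamp fit in fixed-width slots
  val size2 = 8 + 8 * 4 + roundToWord("Hello".getBytes.length)
  // = 8 + 32 + 8 = 48 bytes
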
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
similarity index 52%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateUtilsSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
index 4d8fe4ac5e78f..03eb64f097a37 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
@@ -21,19 +21,31 @@ import java.sql.Timestamp
import org.apache.spark.SparkFunSuite
-class DateUtilsSuite extends SparkFunSuite {
+class DateTimeUtilsSuite extends SparkFunSuite {
- test("timestamp") {
+ test("timestamp and 100ns") {
val now = new Timestamp(System.currentTimeMillis())
now.setNanos(100)
- val ns = DateUtils.fromJavaTimestamp(now)
- assert(ns % 10000000L == 1)
- assert(DateUtils.toJavaTimestamp(ns) == now)
+ val ns = DateTimeUtils.fromJavaTimestamp(now)
+ assert(ns % 10000000L === 1)
+ assert(DateTimeUtils.toJavaTimestamp(ns) === now)
List(-111111111111L, -1L, 0, 1L, 111111111111L).foreach { t =>
- val ts = DateUtils.toJavaTimestamp(t)
- assert(DateUtils.fromJavaTimestamp(ts) == t)
- assert(DateUtils.toJavaTimestamp(DateUtils.fromJavaTimestamp(ts)) == ts)
+ val ts = DateTimeUtils.toJavaTimestamp(t)
+ assert(DateTimeUtils.fromJavaTimestamp(ts) === t)
+ assert(DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromJavaTimestamp(ts)) === ts)
}
}
+
+ test("100ns and julian day") {
+ val (d, ns) = DateTimeUtils.toJulianDay(0)
+ assert(d === DateTimeUtils.JULIAN_DAY_OF_EPOCH)
+ assert(ns === DateTimeUtils.SECONDS_PER_DAY / 2 * DateTimeUtils.NANOS_PER_SECOND)
+ assert(DateTimeUtils.fromJulianDay(d, ns) == 0L)
+
+ val t = new Timestamp(61394778610000L) // (2015, 6, 11, 10, 10, 10, 100)
+ val (d1, ns1) = DateTimeUtils.toJulianDay(DateTimeUtils.fromJavaTimestamp(t))
+ val t2 = DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromJulianDay(d1, ns1))
+ assert(t.equals(t2))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
index 4c0365cf1b6f9..ccc29c0dc8c35 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
@@ -162,4 +162,9 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
assert(new Decimal().set(100L, 10, 0).toUnscaledLong === 100L)
assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue)
}
+
+ test("accurate precision after multiplication") {
+ val decimal = (Decimal(Long.MaxValue, 38, 0) * Decimal(Long.MaxValue, 38, 0)).toJavaBigDecimal
+ assert(decimal.unscaledValue.toString === "85070591730234615847396907784232501249")
+ }
}
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index ed75475a87067..8fc16928adbd9 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -73,11 +73,6 @@
       <artifactId>jackson-databind</artifactId>
       <version>${fasterxml.jackson.version}</version>
     </dependency>
-    <dependency>
-      <groupId>org.jodd</groupId>
-      <artifactId>jodd-core</artifactId>
-      <version>${jodd.version}</version>
-    </dependency>
     <dependency>
       <groupId>junit</groupId>
       <artifactId>junit</artifactId>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index b4e008a6e8480..f201c8ea8a110 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -21,7 +21,6 @@ import scala.language.implicitConversions
import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
-import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.analysis._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 466258e76f9f6..492a3321bc0bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -32,7 +32,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, _}
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -629,6 +629,10 @@ class DataFrame private[sql](
@scala.annotation.varargs
def select(cols: Column*): DataFrame = {
val namedExpressions = cols.map {
+ // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
+ // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
+ // make it a NamedExpression.
+ case Column(u: UnresolvedAttribute) => UnresolvedAlias(u)
case Column(expr: NamedExpression) => expr
      // Leave an unaliased explode with an empty list of names since the analyzer will generate the
// correct defaults after the nested expression's type has been resolved.
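
The practical effect of the new UnresolvedAlias case, as a hedged usage sketch (`df` is a
hypothetical DataFrame with schema struct<a: struct<b: int>>): selecting a nested field through a
Column still yields a properly named output column after the analyzer strips the intermediate
aliases of the ExtractValue chain.

  val selected = df.select(df("a.b"), df("a").getField("b"))
  selected.printSchema()   // both output columns keep a usable name instead of failing to resolve
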
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 45b3e1bc627d5..99d557b03a033 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -21,7 +21,7 @@ import scala.collection.JavaConversions._
import scala.language.implicitConversions
import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
import org.apache.spark.sql.types.NumericType
@@ -70,27 +70,31 @@ class GroupedData protected[sql](
groupingExprs: Seq[Expression],
private val groupType: GroupedData.GroupType) {
- private[this] def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
+ private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
- val retainedExprs = groupingExprs.map {
- case expr: NamedExpression => expr
- case expr: Expression => Alias(expr, expr.prettyString)()
- }
- retainedExprs ++ aggExprs
- } else {
- aggExprs
- }
+ groupingExprs ++ aggExprs
+ } else {
+ aggExprs
+ }
+ val aliasedAgg = aggregates.map {
+ // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
+ // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
+ // make it a NamedExpression.
+ case u: UnresolvedAttribute => UnresolvedAlias(u)
+ case expr: NamedExpression => expr
+ case expr: Expression => Alias(expr, expr.prettyString)()
+ }
groupType match {
case GroupedData.GroupByType =>
DataFrame(
- df.sqlContext, Aggregate(groupingExprs, aggregates, df.logicalPlan))
+ df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))
case GroupedData.RollupType =>
DataFrame(
- df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregates))
+ df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg))
case GroupedData.CubeType =>
DataFrame(
- df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregates))
+ df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg))
}
}
@@ -112,10 +116,7 @@ class GroupedData protected[sql](
namedExpr
}
}
- toDF(columnExprs.map { c =>
- val a = f(c)
- Alias(a, a.prettyString)()
- })
+ toDF(columnExprs.map(f))
}
private[this] def strToExpr(expr: String): (Expression => Expression) = {
@@ -169,8 +170,7 @@ class GroupedData protected[sql](
*/
def agg(exprs: Map[String, String]): DataFrame = {
toDF(exprs.map { case (colName, expr) =>
- val a = strToExpr(expr)(df(colName).expr)
- Alias(a, a.prettyString)()
+ strToExpr(expr)(df(colName).expr)
}.toSeq)
}
@@ -224,10 +224,7 @@ class GroupedData protected[sql](
*/
@scala.annotation.varargs
def agg(expr: Column, exprs: Column*): DataFrame = {
- toDF((expr +: exprs).map(_.expr).map {
- case expr: NamedExpression => expr
- case expr: Expression => Alias(expr, expr.prettyString)()
- })
+ toDF((expr +: exprs).map(_.expr))
}
/**
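
With the aliasing now centralized in toDF, all three agg entry points hand plain expressions down
and rely on the same wrapping logic. A hedged usage sketch (`df` is assumed to have `dept` and
`salary` columns, with `org.apache.spark.sql.functions._` imported):

  import org.apache.spark.sql.functions._

  df.groupBy(df("dept")).agg(Map("salary" -> "max"))            // string-based aggregates
  df.groupBy(df("dept")).agg(max(df("salary")), avg(df("salary"))) // Column-based aggregates
  // Grouping columns are still prepended when dataFrameRetainGroupColumns is enabled.
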
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 422992d019c7b..5c420eb9d761f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand}
@@ -52,6 +52,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
+ /**
+   * Matches a plan whose output should be small enough to be used in a broadcast join.
+ */
+ object CanBroadcast {
+ def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
+ case BroadcastHint(p) => Some(p)
+ case p if sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
+ p.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => Some(p)
+ case _ => None
+ }
+ }
+
/**
* Uses the ExtractEquiJoinKeys pattern to find joins where at least some of the predicates can be
* evaluated by matching hash keys.
@@ -80,15 +92,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
- if sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
- right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
+ case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight)
- case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
- if sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
- left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
- makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
+ case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
+ makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
// If the sort merge join option is set, we want to use sort merge join prior to hashjoin
// for now let's support inner join first, then add outer join
@@ -329,6 +337,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case e @ EvaluatePython(udf, child, _) =>
BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil
+ case BroadcastHint(child) => apply(child)
case _ => Nil
}
}
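
The CanBroadcast extractor is triggered from user code in two ways, sketched below (`large` and
`small` are assumed DataFrames, `sqlContext` an existing SQLContext): an explicit broadcast hint,
or table statistics that fall at or below the configured threshold.

  import org.apache.spark.sql.functions.broadcast

  // 1. Explicit hint: the equi-join is planned as a broadcast hash join regardless of size.
  large.join(broadcast(small), "key")

  // 2. Size-based: relations whose estimated size does not exceed the threshold (in bytes) are
  //    broadcast automatically; a non-positive value disables this path.
  sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", (10 * 1024 * 1024).toString)
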
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 1ce150ceaf5f9..6db551c543a9c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.util.DateUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -74,7 +74,7 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] {
// Skip EvaluatePython nodes.
case plan: EvaluatePython => plan
- case plan: LogicalPlan =>
+ case plan: LogicalPlan if plan.resolved =>
// Extract any PythonUDFs from the current operator.
val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
if (udfs.isEmpty) {
@@ -148,8 +148,8 @@ object EvaluatePython {
case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType)
- case (date: Int, DateType) => DateUtils.toJavaDate(date)
- case (t: Long, TimestampType) => DateUtils.toJavaTimestamp(t)
+ case (date: Int, DateType) => DateTimeUtils.toJavaDate(date)
+ case (t: Long, TimestampType) => DateTimeUtils.toJavaTimestamp(t)
case (s: UTF8String, StringType) => s.toString
// Pyrolite can handle Timestamp and Decimal
@@ -188,12 +188,12 @@ object EvaluatePython {
}): Row
case (c: java.util.Calendar, DateType) =>
- DateUtils.fromJavaDate(new java.sql.Date(c.getTimeInMillis))
+ DateTimeUtils.fromJavaDate(new java.sql.Date(c.getTimeInMillis))
case (c: java.util.Calendar, TimestampType) =>
c.getTimeInMillis * 10000L
case (t: java.sql.Timestamp, TimestampType) =>
- DateUtils.fromJavaTimestamp(t)
+ DateTimeUtils.fromJavaTimestamp(t)
case (_, udt: UserDefinedType[_]) =>
fromJava(obj, udt.sqlType)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 7e7a099a8318b..38d9085a505fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -24,6 +24,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -565,6 +566,22 @@ object functions {
array((colName +: colNames).map(col) : _*)
}
+ /**
+ * Marks a DataFrame as small enough for use in broadcast joins.
+ *
+   * The following example marks the right DataFrame for a broadcast hash join using `joinKey`.
+ * {{{
+ * // left and right are DataFrames
+ * left.join(broadcast(right), "joinKey")
+ * }}}
+ *
+ * @group normal_funcs
+ * @since 1.5.0
+ */
+ def broadcast(df: DataFrame): DataFrame = {
+ DataFrame(df.sqlContext, BroadcastHint(df.logicalPlan))
+ }
+
/**
* Returns the first column that is not null.
* {{{
@@ -1448,7 +1465,9 @@ object functions {
*
* @group udf_funcs
* @since 1.3.0
+ * @deprecated As of 1.5.0, since it's redundant with udf()
*/
+ @deprecated("Use udf", "1.5.0")
def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = {
ScalaUdf(f, returnType, Seq($argsInUdf))
}""")
@@ -1584,7 +1603,9 @@ object functions {
*
* @group udf_funcs
* @since 1.3.0
+ * @deprecated As of 1.5.0, since it's redundant with udf()
*/
+ @deprecated("Use udf", "1.5.0")
def callUDF(f: Function0[_], returnType: DataType): Column = {
ScalaUdf(f, returnType, Seq())
}
@@ -1595,7 +1616,9 @@ object functions {
*
* @group udf_funcs
* @since 1.3.0
+ * @deprecated As of 1.5.0, since it's redundant with udf()
*/
+ @deprecated("Use udf", "1.5.0")
def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr))
}
@@ -1606,7 +1629,9 @@ object functions {
*
* @group udf_funcs
* @since 1.3.0
+ * @deprecated As of 1.5.0, since it's redundant with udf()
*/
+ @deprecated("Use udf", "1.5.0")
def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr))
}
@@ -1617,7 +1642,9 @@ object functions {
*
* @group udf_funcs
* @since 1.3.0
+ * @deprecated As of 1.5.0, since it's redundant with udf()
*/
+ @deprecated("Use udf", "1.5.0")
def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr))
}
@@ -1628,7 +1655,9 @@ object functions {
*
* @group udf_funcs
* @since 1.3.0
+ * @deprecated As of 1.5.0, since it's redundant with udf()
*/
+ @deprecated("Use udf", "1.5.0")
def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr))
}
@@ -1639,7 +1668,9 @@ object functions {
*
* @group udf_funcs
* @since 1.3.0
+ * @deprecated As of 1.5.0, since it's redundant with udf()
*/
+ @deprecated("Use udf", "1.5.0")
def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr))
}
@@ -1650,7 +1681,9 @@ object functions {
*
* @group udf_funcs
* @since 1.3.0
+ * @deprecated As of 1.5.0, since it's redundant with udf()
*/
+ @deprecated("Use udf", "1.5.0")
def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr))
}
@@ -1661,7 +1694,9 @@ object functions {
*
* @group udf_funcs
* @since 1.3.0
+ * @deprecated As of 1.5.0, since it's redundant with udf()
*/
+ @deprecated("Use udf", "1.5.0")
def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr))
}
@@ -1672,7 +1707,9 @@ object functions {
*
* @group udf_funcs
* @since 1.3.0
+ * @deprecated As of 1.5.0, since it's redundant with udf()
*/
+ @deprecated("Use udf", "1.5.0")
def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr))
}
@@ -1683,7 +1720,9 @@ object functions {
*
* @group udf_funcs
* @since 1.3.0
+ * @deprecated As of 1.5.0, since it's redundant with udf()
*/
+ @deprecated("Use udf", "1.5.0")
def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr))
}
@@ -1694,13 +1733,34 @@ object functions {
*
* @group udf_funcs
* @since 1.3.0
+ * @deprecated As of 1.5.0, since it's redundant with udf()
*/
+ @deprecated("Use udf", "1.5.0")
def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr))
}
// scalastyle:on
+ /**
+   * Call a user-defined function.
+ * Example:
+ * {{{
+ * import org.apache.spark.sql._
+ *
+ * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
+ * val sqlContext = df.sqlContext
+ * sqlContext.udf.register("simpleUdf", (v: Int) => v * v)
+ * df.select($"id", callUDF("simpleUdf", $"value"))
+ * }}}
+ *
+ * @group udf_funcs
+ * @since 1.5.0
+ */
+ def callUDF(udfName: String, cols: Column*): Column = {
+ UnresolvedFunction(udfName, cols.map(_.expr))
+ }
+
/**
   * Call a user-defined function.
* Example:
@@ -1715,7 +1775,9 @@ object functions {
*
* @group udf_funcs
* @since 1.4.0
+ * @deprecated As of 1.5.0, since it was not coherent to have two functions callUdf and callUDF
*/
+ @deprecated("Use callUDF", "1.5.0")
def callUdf(udfName: String, cols: Column*): Column = {
UnresolvedFunction(udfName, cols.map(_.expr))
}
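
A hedged migration sketch for the deprecations above (`df` and `sqlContext` are assumed to exist):
the typed Function0 through Function10 overloads are replaced by udf(), and callUdf by the new
name-based callUDF.

  import org.apache.spark.sql.functions.{callUDF, udf}

  // Instead of callUDF((v: Int) => v * v, IntegerType, df("value")):
  val square = udf((v: Int) => v * v)
  df.select(df("id"), square(df("value")))

  // Instead of callUdf("simpleUDF", df("value")) for a function registered by name:
  sqlContext.udf.register("simpleUDF", (v: Int) => v * v)
  df.select(df("id"), callUDF("simpleUDF", df("value")))
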
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
index 226b143923df6..8b4276b2c364c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -22,13 +22,13 @@ import java.util.Properties
import org.apache.commons.lang3.StringUtils
-import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{InternalRow, SpecificMutableRow}
-import org.apache.spark.sql.catalyst.util.DateUtils
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
/**
* Data corresponding to one partition of a JDBCRDD.
@@ -383,10 +383,10 @@ private[sql] class JDBCRDD(
conversions(i) match {
case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos))
case DateConversion =>
- // DateUtils.fromJavaDate does not handle null value, so we need to check it.
+ // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
val dateVal = rs.getDate(pos)
if (dateVal != null) {
- mutableRow.setInt(i, DateUtils.fromJavaDate(dateVal))
+ mutableRow.setInt(i, DateTimeUtils.fromJavaDate(dateVal))
} else {
mutableRow.update(i, null)
}
@@ -421,7 +421,7 @@ private[sql] class JDBCRDD(
case TimestampConversion =>
val t = rs.getTimestamp(pos)
if (t != null) {
- mutableRow.setLong(i, DateUtils.fromJavaTimestamp(t))
+ mutableRow.setLong(i, DateTimeUtils.fromJavaTimestamp(t))
} else {
mutableRow.update(i, null)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
index 817e8a20b34de..6222addc9aa3a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
@@ -25,7 +25,7 @@ import com.fasterxml.jackson.core._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.DateUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.json.JacksonUtils.nextUntil
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -63,10 +63,10 @@ private[sql] object JacksonParser {
null
case (VALUE_STRING, DateType) =>
- DateUtils.millisToDays(DateUtils.stringToTime(parser.getText).getTime)
+ DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime)
case (VALUE_STRING, TimestampType) =>
- DateUtils.stringToTime(parser.getText).getTime * 10000L
+ DateTimeUtils.stringToTime(parser.getText).getTime * 10000L
case (VALUE_NUMBER_INT, TimestampType) =>
parser.getLongValue * 10000L
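
The factor of 10000L converts epoch milliseconds into the internal 100ns unit:
1 millisecond = 1,000,000 ns = 10,000 x 100ns. For example:

  val millis = 1437L                  // milliseconds since the epoch
  val hundredNanos = millis * 10000L  // 14,370,000 internal 100ns units
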
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index 44594c5080ff4..73d9520d6f53f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -28,7 +28,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.DateUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -393,8 +393,8 @@ private[sql] object JsonRDD extends Logging {
value match {
// only support string as date
case value: java.lang.String =>
- DateUtils.millisToDays(DateUtils.stringToTime(value).getTime)
- case value: java.sql.Date => DateUtils.fromJavaDate(value)
+ DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(value).getTime)
+ case value: java.sql.Date => DateTimeUtils.fromJavaDate(value)
}
}
@@ -402,7 +402,7 @@ private[sql] object JsonRDD extends Logging {
value match {
case value: java.lang.Integer => value.asInstanceOf[Int].toLong * 10000L
case value: java.lang.Long => value * 10000L
- case value: java.lang.String => DateUtils.stringToTime(value).getTime * 10000L
+ case value: java.lang.String => DateTimeUtils.stringToTime(value).getTime * 10000L
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index 4da5e96b82e3d..cf7aa44e4cd55 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -17,21 +17,19 @@
package org.apache.spark.sql.parquet
-import java.sql.Timestamp
-import java.util.{TimeZone, Calendar}
+import java.nio.ByteOrder
-import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap}
+import scala.collection.mutable.{ArrayBuffer, Buffer, HashMap}
-import jodd.datetime.JDateTime
+import org.apache.parquet.Preconditions
import org.apache.parquet.column.Dictionary
-import org.apache.parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter}
+import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter}
import org.apache.parquet.schema.MessageType
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.DateUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.parquet.CatalystConverter.FieldType
import org.apache.spark.sql.types._
-import org.apache.spark.sql.parquet.timestamp.NanoTime
import org.apache.spark.unsafe.types.UTF8String
/**
@@ -269,7 +267,12 @@ private[parquet] abstract class CatalystConverter extends GroupConverter {
* Read a Timestamp value from a Parquet Int96Value
*/
protected[parquet] def readTimestamp(value: Binary): Long = {
- DateUtils.fromJavaTimestamp(CatalystTimestampConverter.convertToTimestamp(value))
+ Preconditions.checkArgument(value.length() == 12, "Must be 12 bytes")
+ val buf = value.toByteBuffer
+ buf.order(ByteOrder.LITTLE_ENDIAN)
+ val timeOfDayNanos = buf.getLong
+ val julianDay = buf.getInt
+ DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos)
}
}
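
A standalone sketch of the Int96 layout the new readTimestamp relies on (no Parquet classes
needed): 12 bytes in little-endian order, 8 bytes of nanoseconds-within-day followed by 4 bytes of
Julian day number.

  import java.nio.{ByteBuffer, ByteOrder}

  def decodeInt96(bytes: Array[Byte]): (Int, Long) = {
    require(bytes.length == 12, "Int96 timestamps must be 12 bytes")
    val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
    val timeOfDayNanos = buf.getLong   // first 8 bytes
    val julianDay = buf.getInt         // last 4 bytes
    (julianDay, timeOfDayNanos)
  }
  // The internal 100ns timestamp is then DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos).
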
@@ -498,73 +501,6 @@ private[parquet] object CatalystArrayConverter {
val INITIAL_ARRAY_SIZE = 20
}
-private[parquet] object CatalystTimestampConverter {
- // TODO most part of this comes from Hive-0.14
- // Hive code might have some issues, so we need to keep an eye on it.
- // Also we use NanoTime and Int96Values from parquet-examples.
- // We utilize jodd to convert between NanoTime and Timestamp
- val parquetTsCalendar = new ThreadLocal[Calendar]
- def getCalendar: Calendar = {
- // this is a cache for the calendar instance.
- if (parquetTsCalendar.get == null) {
- parquetTsCalendar.set(Calendar.getInstance(TimeZone.getTimeZone("GMT")))
- }
- parquetTsCalendar.get
- }
- val NANOS_PER_SECOND: Long = 1000000000
- val SECONDS_PER_MINUTE: Long = 60
- val MINUTES_PER_HOUR: Long = 60
- val NANOS_PER_MILLI: Long = 1000000
-
- def convertToTimestamp(value: Binary): Timestamp = {
- val nt = NanoTime.fromBinary(value)
- val timeOfDayNanos = nt.getTimeOfDayNanos
- val julianDay = nt.getJulianDay
- val jDateTime = new JDateTime(julianDay.toDouble)
- val calendar = getCalendar
- calendar.set(Calendar.YEAR, jDateTime.getYear)
- calendar.set(Calendar.MONTH, jDateTime.getMonth - 1)
- calendar.set(Calendar.DAY_OF_MONTH, jDateTime.getDay)
-
- // written in command style
- var remainder = timeOfDayNanos
- calendar.set(
- Calendar.HOUR_OF_DAY,
- (remainder / (NANOS_PER_SECOND * SECONDS_PER_MINUTE * MINUTES_PER_HOUR)).toInt)
- remainder = remainder % (NANOS_PER_SECOND * SECONDS_PER_MINUTE * MINUTES_PER_HOUR)
- calendar.set(
- Calendar.MINUTE, (remainder / (NANOS_PER_SECOND * SECONDS_PER_MINUTE)).toInt)
- remainder = remainder % (NANOS_PER_SECOND * SECONDS_PER_MINUTE)
- calendar.set(Calendar.SECOND, (remainder / NANOS_PER_SECOND).toInt)
- val nanos = remainder % NANOS_PER_SECOND
- val ts = new Timestamp(calendar.getTimeInMillis)
- ts.setNanos(nanos.toInt)
- ts
- }
-
- def convertFromTimestamp(ts: Timestamp): Binary = {
- val calendar = getCalendar
- calendar.setTime(ts)
- val jDateTime = new JDateTime(calendar.get(Calendar.YEAR),
- calendar.get(Calendar.MONTH) + 1, calendar.get(Calendar.DAY_OF_MONTH))
- // Hive-0.14 didn't set hour before get day number, while the day number should
- // has something to do with hour, since julian day number grows at 12h GMT
- // here we just follow what hive does.
- val julianDay = jDateTime.getJulianDayNumber
-
- val hour = calendar.get(Calendar.HOUR_OF_DAY)
- val minute = calendar.get(Calendar.MINUTE)
- val second = calendar.get(Calendar.SECOND)
- val nanos = ts.getNanos
- // Hive-0.14 would use hours directly, that might be wrong, since the day starts
- // from 12h in Julian. here we just follow what hive does.
- val nanosOfDay = nanos + second * NANOS_PER_SECOND +
- minute * NANOS_PER_SECOND * SECONDS_PER_MINUTE +
- hour * NANOS_PER_SECOND * SECONDS_PER_MINUTE * MINUTES_PER_HOUR
- NanoTime(julianDay, nanosOfDay).toBinary
- }
-}
-
/**
* A `parquet.io.api.GroupConverter` that converts a single-element groups that
* match the characteristics of an array (see
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index a8775a2a8fd83..e65fa0030e179 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.parquet
+import java.nio.{ByteOrder, ByteBuffer}
import java.util.{HashMap => JHashMap}
import org.apache.hadoop.conf.Configuration
@@ -29,7 +30,7 @@ import org.apache.parquet.schema.MessageType
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.{Attribute, InternalRow}
-import org.apache.spark.sql.catalyst.util.DateUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -298,7 +299,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
}
// Scratch array used to write decimals as fixed-length binary
- private val scratchBytes = new Array[Byte](8)
+ private[this] val scratchBytes = new Array[Byte](8)
private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = {
val numBytes = ParquetTypesConverter.BYTES_FOR_PRECISION(precision)
@@ -313,10 +314,16 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes))
}
+  // Scratch array used to write timestamps as Int96 (fixed-length binary)
+ private[this] val int96buf = new Array[Byte](12)
+
private[parquet] def writeTimestamp(ts: Long): Unit = {
- val binaryNanoTime = CatalystTimestampConverter.convertFromTimestamp(
- DateUtils.toJavaTimestamp(ts))
- writer.addBinary(binaryNanoTime)
+ val (julianDay, timeOfDayNanos) = DateTimeUtils.toJulianDay(ts)
+ val buf = ByteBuffer.wrap(int96buf)
+ buf.order(ByteOrder.LITTLE_ENDIAN)
+ buf.putLong(timeOfDayNanos)
+ buf.putInt(julianDay)
+ writer.addBinary(Binary.fromByteArray(int96buf))
}
}
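
And the mirror image on the write path, as a sketch: encoding the internal 100ns timestamp's
Julian-day representation into the 12-byte little-endian Int96 value that Parquet expects.

  import java.nio.{ByteBuffer, ByteOrder}

  def encodeInt96(julianDay: Int, timeOfDayNanos: Long): Array[Byte] = {
    val bytes = new Array[Byte](12)
    val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
    buf.putLong(timeOfDayNanos)
    buf.putInt(julianDay)
    bytes
  }
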
@@ -360,7 +367,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
case FloatType => writer.addFloat(record.getFloat(index))
case BooleanType => writer.addBoolean(record.getBoolean(index))
case DateType => writer.addInteger(record.getInt(index))
- case TimestampType => writeTimestamp(record(index).asInstanceOf[Long])
+ case TimestampType => writeTimestamp(record.getLong(index))
case d: DecimalType =>
if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) {
sys.error(s"Unsupported datatype $d, cannot write to consumer")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index c9de45e0ddfbb..e049d54bf55dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.{SerializableConfiguration, Utils}
-import org.apache.spark.{Logging, SparkException, Partition => SparkPartition}
+import org.apache.spark.{Logging, Partition => SparkPartition, SparkException}
private[sql] class DefaultSource extends HadoopFsRelationProvider {
override def createRelation(
@@ -60,50 +60,21 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext
extends OutputWriter {
private val recordWriter: RecordWriter[Void, InternalRow] = {
- val conf = context.getConfiguration
val outputFormat = {
- // When appending new Parquet files to an existing Parquet file directory, to avoid
- // overwriting existing data files, we need to find out the max task ID encoded in these data
- // file names.
- // TODO Make this snippet a utility function for other data source developers
- val maxExistingTaskId = {
- // Note that `path` may point to a temporary location. Here we retrieve the real
- // destination path from the configuration
- val outputPath = new Path(conf.get("spark.sql.sources.output.path"))
- val fs = outputPath.getFileSystem(conf)
-
- if (fs.exists(outputPath)) {
- // Pattern used to match task ID in part file names, e.g.:
- //
- // part-r-00001.gz.parquet
- // ^~~~~
- val partFilePattern = """part-.-(\d{1,}).*""".r
-
- fs.listStatus(outputPath).map(_.getPath.getName).map {
- case partFilePattern(id) => id.toInt
- case name if name.startsWith("_") => 0
- case name if name.startsWith(".") => 0
- case name => throw new AnalysisException(
- s"Trying to write Parquet files to directory $outputPath, " +
- s"but found items with illegal name '$name'.")
- }.reduceOption(_ max _).getOrElse(0)
- } else {
- 0
- }
- }
-
new ParquetOutputFormat[InternalRow]() {
// Here we override `getDefaultWorkFile` for two reasons:
//
- // 1. To allow appending. We need to generate output file name based on the max available
- // task ID computed above.
+ // 1. To allow appending. We need to generate unique output file names to avoid
+ // overwriting existing files (either exist before the write job, or are just written
+ // by other tasks within the same write job).
//
// 2. To allow dynamic partitioning. Default `getDefaultWorkFile` uses
// `FileOutputCommitter.getWorkPath()`, which points to the base directory of all
// partitions in the case of dynamic partitioning.
override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
- val split = context.getTaskAttemptID.getTaskID.getId + maxExistingTaskId + 1
- new Path(path, f"part-r-$split%05d$extension")
+ val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID")
+ val split = context.getTaskAttemptID.getTaskID.getId
+ new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension")
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala
deleted file mode 100644
index 4d5ed211ad0c0..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.parquet.timestamp
-
-import java.nio.{ByteBuffer, ByteOrder}
-
-import org.apache.parquet.Preconditions
-import org.apache.parquet.io.api.{Binary, RecordConsumer}
-
-private[parquet] class NanoTime extends Serializable {
- private var julianDay = 0
- private var timeOfDayNanos = 0L
-
- def set(julianDay: Int, timeOfDayNanos: Long): this.type = {
- this.julianDay = julianDay
- this.timeOfDayNanos = timeOfDayNanos
- this
- }
-
- def getJulianDay: Int = julianDay
-
- def getTimeOfDayNanos: Long = timeOfDayNanos
-
- def toBinary: Binary = {
- val buf = ByteBuffer.allocate(12)
- buf.order(ByteOrder.LITTLE_ENDIAN)
- buf.putLong(timeOfDayNanos)
- buf.putInt(julianDay)
- buf.flip()
- Binary.fromByteBuffer(buf)
- }
-
- def writeValue(recordConsumer: RecordConsumer): Unit = {
- recordConsumer.addBinary(toBinary)
- }
-
- override def toString: String =
- "NanoTime{julianDay=" + julianDay + ", timeOfDayNanos=" + timeOfDayNanos + "}"
-}
-
-private[sql] object NanoTime {
- def fromBinary(bytes: Binary): NanoTime = {
- Preconditions.checkArgument(bytes.length() == 12, "Must be 12 bytes")
- val buf = bytes.toByteBuffer
- buf.order(ByteOrder.LITTLE_ENDIAN)
- val timeOfDayNanos = buf.getLong
- val julianDay = buf.getInt
- new NanoTime().set(julianDay, timeOfDayNanos)
- }
-
- def apply(julianDay: Int, timeOfDayNanos: Long): NanoTime = {
- new NanoTime().set(julianDay, timeOfDayNanos)
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
index c16bd9ae52c81..215e53c020849 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -17,14 +17,13 @@
package org.apache.spark.sql.sources
-import java.util.Date
+import java.util.{Date, UUID}
import scala.collection.mutable
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce._
-import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, FileOutputCommitter => MapReduceFileOutputCommitter}
-import org.apache.parquet.hadoop.util.ContextUtil
+import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter, FileOutputFormat}
import org.apache.spark._
import org.apache.spark.mapred.SparkHadoopMapRedUtil
@@ -59,6 +58,28 @@ private[sql] case class InsertIntoDataSource(
}
}
+/**
+ * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending.
+ * Writing to dynamic partitions is also supported. Each [[InsertIntoHadoopFsRelation]] issues a
+ * single write job, and owns a UUID that identifies this job. Each concrete implementation of
+ * [[HadoopFsRelation]] should use this UUID together with the task id to generate a unique file
+ * path for each task output file. This UUID is passed to the executor side via a property named
+ * `spark.sql.sources.writeJobUUID`.
+ *
+ * Two different writer containers, [[DefaultWriterContainer]] and
+ * [[DynamicPartitionWriterContainer]], are used to write to normal tables and tables with
+ * dynamic partitions respectively.
+ *
+ * The basic workflow of this command is:
+ *
+ *   1. Driver-side setup, including output committer initialization and data source specific
+ *      preparation work for the write job to be issued.
+ *   2. Issues a write job that consists of one or more executor-side tasks, each of which writes
+ *      all rows within an RDD partition.
+ *   3. If no exception is thrown in a task, commits that task; otherwise aborts it. If any
+ *      exception is thrown while committing a task, the task is also aborted.
+ *   4. If all tasks are committed, commits the job; otherwise aborts it. If any exception is
+ *      thrown while committing the job, the job is also aborted.
+ */
private[sql] case class InsertIntoHadoopFsRelation(
@transient relation: HadoopFsRelation,
@transient query: LogicalPlan,
@@ -261,7 +282,14 @@ private[sql] abstract class BaseWriterContainer(
with Logging
with Serializable {
- protected val serializableConf = new SerializableConfiguration(ContextUtil.getConfiguration(job))
+ protected val serializableConf = new SerializableConfiguration(job.getConfiguration)
+
+ // This UUID is used to avoid output file name collision between different appending write jobs.
+ // These jobs may belong to different SparkContext instances. Concrete data source implementations
+  // may use this UUID to generate unique file names (e.g., `part-r-<task-id>-<job-uuid>.parquet`).
+  // The reason why this ID identifies a job rather than a single task output file is that
+  // speculative tasks must generate the same output file name as the original task.
+ private val uniqueWriteJobId = UUID.randomUUID()
// This is only used on driver side.
@transient private val jobContext: JobContext = job
@@ -290,6 +318,11 @@ private[sql] abstract class BaseWriterContainer(
setupIDs(0, 0, 0)
setupConf()
+    // This UUID is sent to the executor side together with the serialized `Configuration` object
+    // within the `Job` instance. `OutputWriters` on the executor side should use this UUID to
+    // generate unique names for task output files.
+ job.getConfiguration.set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString)
+
// Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor
// clones the Configuration object passed in. If we initialize the TaskAttemptContext first,
// configurations made in prepareJobForWrite(job) are not populated into the TaskAttemptContext.
@@ -417,15 +450,16 @@ private[sql] class DefaultWriterContainer(
assert(writer != null, "OutputWriter instance should have been initialized")
writer.close()
super.commitTask()
- } catch {
- case cause: Throwable =>
- super.abortTask()
- throw new RuntimeException("Failed to commit task", cause)
+ } catch { case cause: Throwable =>
+ // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and will
+ // cause `abortTask()` to be invoked.
+ throw new RuntimeException("Failed to commit task", cause)
}
}
override def abortTask(): Unit = {
try {
+ // It's possible that the task fails before `writer` gets initialized
if (writer != null) {
writer.close()
}
@@ -469,21 +503,25 @@ private[sql] class DynamicPartitionWriterContainer(
})
}
- override def commitTask(): Unit = {
- try {
+ private def clearOutputWriters(): Unit = {
+ if (outputWriters.nonEmpty) {
outputWriters.values.foreach(_.close())
outputWriters.clear()
+ }
+ }
+
+ override def commitTask(): Unit = {
+ try {
+ clearOutputWriters()
super.commitTask()
} catch { case cause: Throwable =>
- super.abortTask()
throw new RuntimeException("Failed to commit task", cause)
}
}
override def abortTask(): Unit = {
try {
- outputWriters.values.foreach(_.close())
- outputWriters.clear()
+ clearOutputWriters()
} finally {
super.abortTask()
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 6165764632c29..e1c6c706242d2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql
import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.execution.joins.BroadcastHashJoin
import org.apache.spark.sql.functions._
class DataFrameJoinSuite extends QueryTest {
@@ -93,4 +94,20 @@ class DataFrameJoinSuite extends QueryTest {
left.join(right, left("key") === right("key")),
Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil)
}
+
+ test("broadcast join hint") {
+ val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value")
+ val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
+
+ // equijoin - should be converted into broadcast join
+ val plan1 = df1.join(broadcast(df2), "key").queryExecution.executedPlan
+ assert(plan1.collect { case p: BroadcastHashJoin => p }.size === 1)
+
+ // no join key -- should not be a broadcast join
+ val plan2 = df1.join(broadcast(df2)).queryExecution.executedPlan
+ assert(plan2.collect { case p: BroadcastHashJoin => p }.size === 0)
+
+ // planner should not crash without a join
+ broadcast(df1).queryExecution.executedPlan
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index ba1d020f22f11..47443a917b765 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -301,7 +301,7 @@ class DataFrameSuite extends QueryTest {
)
}
- test("call udf in SQLContext") {
+ test("deprecated callUdf in SQLContext") {
val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
val sqlctx = df.sqlContext
sqlctx.udf.register("simpleUdf", (v: Int) => v * v)
@@ -310,6 +310,15 @@ class DataFrameSuite extends QueryTest {
Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil)
}
+ test("callUDF in SQLContext") {
+ val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
+ val sqlctx = df.sqlContext
+ sqlctx.udf.register("simpleUDF", (v: Int) => v * v)
+ checkAnswer(
+ df.select($"id", callUDF("simpleUDF", $"value")),
+ Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil)
+ }
+
test("withColumn") {
val df = testData.toDF().withColumn("newCol", col("key") + 1)
checkAnswer(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 4441afd6bd811..73bc6c999164e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1367,9 +1367,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
test("SPARK-6145: special cases") {
sqlContext.read.json(sqlContext.sparkContext.makeRDD(
- """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t")
- checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1))
- checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1))
+ """{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t")
+ checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1))
+ checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1))
}
test("SPARK-6898: complete support for special chars in column names") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 520a862ea0838..207d7a352c7b3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql
import java.sql.Timestamp
-import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.test._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index c32d9f88dd6ee..8204a584179bb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -25,7 +25,7 @@ import org.scalactic.Tolerance._
import org.apache.spark.sql.{QueryTest, Row, SQLConf}
import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.catalyst.util.DateUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.json.InferSchema.compatibleType
import org.apache.spark.sql.sources.LogicalRelation
import org.apache.spark.sql.types._
@@ -76,26 +76,28 @@ class JsonSuite extends QueryTest with TestJsonData {
checkTypePromotion(
Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.Unlimited))
- checkTypePromotion(DateUtils.fromJavaTimestamp(new Timestamp(intNumber)),
+ checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber)),
enforceCorrectType(intNumber, TimestampType))
- checkTypePromotion(DateUtils.fromJavaTimestamp(new Timestamp(intNumber.toLong)),
+ checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber.toLong)),
enforceCorrectType(intNumber.toLong, TimestampType))
val strTime = "2014-09-30 12:34:56"
- checkTypePromotion(DateUtils.fromJavaTimestamp(Timestamp.valueOf(strTime)),
+ checkTypePromotion(DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(strTime)),
enforceCorrectType(strTime, TimestampType))
val strDate = "2014-10-15"
checkTypePromotion(
- DateUtils.fromJavaDate(Date.valueOf(strDate)), enforceCorrectType(strDate, DateType))
+ DateTimeUtils.fromJavaDate(Date.valueOf(strDate)), enforceCorrectType(strDate, DateType))
val ISO8601Time1 = "1970-01-01T01:00:01.0Z"
- checkTypePromotion(DateUtils.fromJavaTimestamp(new Timestamp(3601000)),
+ checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(3601000)),
enforceCorrectType(ISO8601Time1, TimestampType))
- checkTypePromotion(DateUtils.millisToDays(3601000), enforceCorrectType(ISO8601Time1, DateType))
+ checkTypePromotion(DateTimeUtils.millisToDays(3601000),
+ enforceCorrectType(ISO8601Time1, DateType))
val ISO8601Time2 = "1970-01-01T02:00:01-01:00"
- checkTypePromotion(DateUtils.fromJavaTimestamp(new Timestamp(10801000)),
+ checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(10801000)),
enforceCorrectType(ISO8601Time2, TimestampType))
- checkTypePromotion(DateUtils.millisToDays(10801000), enforceCorrectType(ISO8601Time2, DateType))
+ checkTypePromotion(DateTimeUtils.millisToDays(10801000),
+ enforceCorrectType(ISO8601Time2, DateType))
}
test("Get compatible type") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
index 284d99d4938d1..47a7be1c6a664 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
@@ -37,7 +37,7 @@ import org.scalatest.BeforeAndAfterAll
import org.apache.spark.SparkException
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.util.DateUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
// Write support class for nested groups: ParquetWriter initializes GroupWriteSupport
@@ -137,7 +137,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
def makeDateRDD(): DataFrame =
sqlContext.sparkContext
.parallelize(0 to 1000)
- .map(i => Tuple1(DateUtils.toJavaDate(i)))
+ .map(i => Tuple1(DateTimeUtils.toJavaDate(i)))
.toDF()
.select($"_1")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 48875773224c7..79eac930e54f7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.sql.sources
-import java.sql.{Timestamp, Date}
-
+import java.sql.{Date, Timestamp}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -84,8 +83,8 @@ case class AllDataTypesScan(
i.toDouble,
Decimal(new java.math.BigDecimal(i)),
Decimal(new java.math.BigDecimal(i)),
- DateUtils.fromJavaDate(new Date(1970, 1, 1)),
- DateUtils.fromJavaTimestamp(new Timestamp(20000 + i)),
+ DateTimeUtils.fromJavaDate(new Date(1970, 1, 1)),
+ DateTimeUtils.fromJavaTimestamp(new Timestamp(20000 + i)),
UTF8String.fromString(s"varchar_$i"),
Seq(i, i + 1),
Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))),
@@ -93,7 +92,7 @@ case class AllDataTypesScan(
Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)),
Row(i, i.toString),
Row(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")),
- InternalRow(Seq(DateUtils.fromJavaDate(new Date(1970, 1, i + 1))))))
+ InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1))))))
}
}
}
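
The DateUtils to DateTimeUtils rename touches every conversion between java.sql types and Catalyst's internal encodings (an Int day count for dates, a Long encoding for timestamps). A hedged round-trip sketch of the helpers the tests above now call:

```scala
import java.sql.{Date, Timestamp}
import org.apache.spark.sql.catalyst.util.DateTimeUtils

// Dates travel internally as days since the Unix epoch.
val days: Int = DateTimeUtils.fromJavaDate(Date.valueOf("2014-10-15"))
val date: Date = DateTimeUtils.toJavaDate(days)

// Timestamps travel internally as a Long encoding.
val ts: Long = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2014-09-30 12:34:56"))
val javaTs: Timestamp = DateTimeUtils.toJavaTimestamp(ts)

// millisToDays truncates an epoch-millis value to a day count, which is what
// the JsonSuite assertions above compare against.
val day: Int = DateTimeUtils.millisToDays(3601000L)
```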
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
index 934452fe579a1..31a49a3683338 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
@@ -526,8 +526,14 @@ abstract class HiveWindowFunctionQueryBaseSuite extends HiveComparisonTest with
| rows between 2 preceding and 2 following);
""".stripMargin, reset = false)
+ // collect_set() outputs its array in an arbitrary order, which causes different results
+ // when this test suite runs under Java 7 and Java 8.
+ // We change the original SQL query slightly so that the test passes under both JDKs.
createQueryTest("windowing.q -- 20. testSTATs",
"""
+ |select p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp
+ |from (
|select p_mfgr,p_name, p_size,
|stddev(p_retailprice) over w1 as sdev,
|stddev_pop(p_retailprice) over w1 as sdev_pop,
@@ -538,6 +544,8 @@ abstract class HiveWindowFunctionQueryBaseSuite extends HiveComparisonTest with
|from part
|window w1 as (distribute by p_mfgr sort by p_mfgr, p_name
| rows between 2 preceding and 2 following)
+ |) t lateral view explode(uniq_size) d as uniq_data
+ |order by p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp
""".stripMargin, reset = false)
createQueryTest("windowing.q -- 21. testDISTs",
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index d4f1ae8ee01d9..864c888ab073d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -25,7 +25,7 @@ import org.apache.hadoop.hive.serde2.{io => hiveIo}
import org.apache.hadoop.{io => hadoopIo}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.DateUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -273,7 +273,7 @@ private[hive] trait HiveInspectors {
System.arraycopy(writable.getBytes, 0, temp, 0, temp.length)
temp
case poi: WritableConstantDateObjectInspector =>
- DateUtils.fromJavaDate(poi.getWritableConstantValue.get())
+ DateTimeUtils.fromJavaDate(poi.getWritableConstantValue.get())
case mi: StandardConstantMapObjectInspector =>
// take the value from the map inspector object, rather than the input data
mi.getWritableConstantValue.map { case (k, v) =>
@@ -313,13 +313,13 @@ private[hive] trait HiveInspectors {
System.arraycopy(bw.getBytes(), 0, result, 0, bw.getLength())
result
case x: DateObjectInspector if x.preferWritable() =>
- DateUtils.fromJavaDate(x.getPrimitiveWritableObject(data).get())
- case x: DateObjectInspector => DateUtils.fromJavaDate(x.getPrimitiveJavaObject(data))
+ DateTimeUtils.fromJavaDate(x.getPrimitiveWritableObject(data).get())
+ case x: DateObjectInspector => DateTimeUtils.fromJavaDate(x.getPrimitiveJavaObject(data))
case x: TimestampObjectInspector if x.preferWritable() =>
val t = x.getPrimitiveWritableObject(data)
t.getSeconds * 10000000L + t.getNanos / 100
case ti: TimestampObjectInspector =>
- DateUtils.fromJavaTimestamp(ti.getPrimitiveJavaObject(data))
+ DateTimeUtils.fromJavaTimestamp(ti.getPrimitiveJavaObject(data))
case _ => pi.getPrimitiveJavaObject(data)
}
case li: ListObjectInspector =>
@@ -356,10 +356,10 @@ private[hive] trait HiveInspectors {
(o: Any) => HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal)
case _: JavaDateObjectInspector =>
- (o: Any) => DateUtils.toJavaDate(o.asInstanceOf[Int])
+ (o: Any) => DateTimeUtils.toJavaDate(o.asInstanceOf[Int])
case _: JavaTimestampObjectInspector =>
- (o: Any) => DateUtils.toJavaTimestamp(o.asInstanceOf[Long])
+ (o: Any) => DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long])
case soi: StandardStructObjectInspector =>
val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector))
@@ -468,9 +468,9 @@ private[hive] trait HiveInspectors {
case _: BinaryObjectInspector if x.preferWritable() => getBinaryWritable(a)
case _: BinaryObjectInspector => a.asInstanceOf[Array[Byte]]
case _: DateObjectInspector if x.preferWritable() => getDateWritable(a)
- case _: DateObjectInspector => DateUtils.toJavaDate(a.asInstanceOf[Int])
+ case _: DateObjectInspector => DateTimeUtils.toJavaDate(a.asInstanceOf[Int])
case _: TimestampObjectInspector if x.preferWritable() => getTimestampWritable(a)
- case _: TimestampObjectInspector => DateUtils.toJavaTimestamp(a.asInstanceOf[Long])
+ case _: TimestampObjectInspector => DateTimeUtils.toJavaTimestamp(a.asInstanceOf[Long])
}
case x: SettableStructObjectInspector =>
val fieldRefs = x.getAllStructFieldRefs
@@ -781,7 +781,7 @@ private[hive] trait HiveInspectors {
if (value == null) {
null
} else {
- new hiveIo.TimestampWritable(DateUtils.toJavaTimestamp(value.asInstanceOf[Long]))
+ new hiveIo.TimestampWritable(DateTimeUtils.toJavaTimestamp(value.asInstanceOf[Long]))
}
private def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable =
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index ca4b80b51b23f..7c4620952ba4b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -415,13 +415,6 @@ private[hive] object HiveQl {
throw new NotImplementedError(s"No parse rules for StructField:\n ${dumpTree(a).toString} ")
}
- protected def nameExpressions(exprs: Seq[Expression]): Seq[NamedExpression] = {
- exprs.zipWithIndex.map {
- case (ne: NamedExpression, _) => ne
- case (e, i) => Alias(e, s"_c$i")()
- }
- }
-
protected def extractDbNameTableName(tableNameParts: Node): (Option[String], String) = {
val (db, tableName) =
tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match {
@@ -942,7 +935,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
// (if there is a group by) or a script transformation.
val withProject: LogicalPlan = transformation.getOrElse {
val selectExpressions =
- nameExpressions(select.getChildren.flatMap(selExprNodeToExpr).toSeq)
+ select.getChildren.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_)).toSeq
Seq(
groupByClause.map(e => e match {
case Token("TOK_GROUPBY", children) =>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index 439f39bafc926..00e61e35d4354 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -29,11 +29,11 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters,
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf}
-import org.apache.spark.{Logging}
+import org.apache.spark.Logging
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.DateUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.util.{SerializableConfiguration, Utils}
/**
@@ -362,10 +362,10 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging {
row.update(ordinal, HiveShim.toCatalystDecimal(oi, value))
case oi: TimestampObjectInspector =>
(value: Any, row: MutableRow, ordinal: Int) =>
- row.setLong(ordinal, DateUtils.fromJavaTimestamp(oi.getPrimitiveJavaObject(value)))
+ row.setLong(ordinal, DateTimeUtils.fromJavaTimestamp(oi.getPrimitiveJavaObject(value)))
case oi: DateObjectInspector =>
(value: Any, row: MutableRow, ordinal: Int) =>
- row.setInt(ordinal, DateUtils.fromJavaDate(oi.getPrimitiveJavaObject(value)))
+ row.setInt(ordinal, DateTimeUtils.fromJavaDate(oi.getPrimitiveJavaObject(value)))
case oi: BinaryObjectInspector =>
(value: Any, row: MutableRow, ordinal: Int) =>
row.update(ordinal, oi.getPrimitiveJavaObject(value))
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
index 8b928861fcc70..ab75b12e2a2e7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
@@ -34,7 +34,7 @@ import org.apache.hadoop.hive.common.FileUtils
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.sql.Row
import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter}
-import org.apache.spark.sql.catalyst.util.DateUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc}
import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableJobConf
@@ -201,7 +201,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer(
def convertToHiveRawString(col: String, value: Any): String = {
val raw = String.valueOf(value)
schema(col).dataType match {
- case DateType => DateUtils.toString(raw.toInt)
+ case DateType => DateTimeUtils.toString(raw.toInt)
case _: DecimalType => BigDecimal(raw).toString()
case _ => raw
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala
index 1e51173a19882..e3ab9442b4821 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala
@@ -27,13 +27,13 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql.hive.HiveMetastoreTypes
import org.apache.spark.sql.types.StructType
-private[orc] object OrcFileOperator extends Logging{
+private[orc] object OrcFileOperator extends Logging {
def getFileReader(pathStr: String, config: Option[Configuration] = None ): Reader = {
val conf = config.getOrElse(new Configuration)
val fspath = new Path(pathStr)
val fs = fspath.getFileSystem(conf)
val orcFiles = listOrcFiles(pathStr, conf)
-
+ logDebug(s"Creating ORC Reader from ${orcFiles.head}")
// TODO Need to consider all files when schema evolution is taken into account.
OrcFile.createReader(fs, orcFiles.head)
}
@@ -42,6 +42,7 @@ private[orc] object OrcFileOperator extends Logging{
val reader = getFileReader(path, conf)
val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector]
val schema = readerInspector.getTypeName
+ logDebug(s"Reading schema from file $path, got Hive schema string: $schema")
HiveMetastoreTypes.toDataType(schema).asInstanceOf[StructType]
}
@@ -52,14 +53,14 @@ private[orc] object OrcFileOperator extends Logging{
def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = {
val origPath = new Path(pathStr)
val fs = origPath.getFileSystem(conf)
- val path = origPath.makeQualified(fs)
+ val path = origPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath)
.filterNot(_.isDir)
.map(_.getPath)
.filterNot(_.getName.startsWith("_"))
.filterNot(_.getName.startsWith("."))
- if (paths == null || paths.size == 0) {
+ if (paths == null || paths.isEmpty) {
throw new IllegalArgumentException(
s"orcFileOperator: path $path does not have valid orc files matching the pattern")
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
index dbce39f21d271..705f48f1cd9f0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
@@ -31,6 +31,7 @@ import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, Reco
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
+import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.rdd.{HadoopRDD, RDD}
@@ -39,7 +40,6 @@ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreType
import org.apache.spark.sql.sources.{Filter, _}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Row, SQLContext}
-import org.apache.spark.{Logging}
import org.apache.spark.util.SerializableConfiguration
/* Implicit conversions */
@@ -105,8 +105,9 @@ private[orc] class OrcOutputWriter(
recordWriterInstantiated = true
val conf = context.getConfiguration
+ val uniqueWriteJobId = conf.get("spark.sql.sources.writeJobUUID")
val partition = context.getTaskAttemptID.getTaskID.getId
- val filename = f"part-r-$partition%05d-${System.currentTimeMillis}%015d.orc"
+ val filename = f"part-r-$partition%05d-$uniqueWriteJobId.orc"
new OrcOutputFormat().getRecordWriter(
new Path(path, filename).getFileSystem(conf),
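
Both the ORC writer above and the SimpleTextRelation test writer further down switch the file-name suffix from a local timestamp (or a fresh UUID) to the job-wide `spark.sql.sources.writeJobUUID` value, so concurrent or speculative attempts of the same write job no longer collide on names (SPARK-8406). A small hedged sketch of that naming scheme; the helper name is illustrative only:

```scala
import org.apache.hadoop.conf.Configuration

// The write path publishes a per-job UUID in the Hadoop configuration before
// any task starts; every output writer reads it back when building file names.
def orcPartFileName(conf: Configuration, taskPartitionId: Int): String = {
  val uniqueWriteJobId = conf.get("spark.sql.sources.writeJobUUID")
  f"part-r-$taskPartitionId%05d-$uniqueWriteJobId.orc"
}
```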
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index f901bd8171508..ea325cc93cb85 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -49,7 +49,7 @@ import scala.collection.JavaConversions._
object TestHive
extends TestHiveContext(
new SparkContext(
- System.getProperty("spark.sql.test.master", "local[2]"),
+ System.getProperty("spark.sql.test.master", "local[32]"),
"TestSQLContext",
new SparkConf()
.set("spark.sql.test", "")
diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-6dfcd7925fb267699c4bf82737d4609 b/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-6dfcd7925fb267699c4bf82737d4609
new file mode 100644
index 0000000000000..7e5fceeddeeeb
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-6dfcd7925fb267699c4bf82737d4609
@@ -0,0 +1,97 @@
+Manufacturer#1 almond antique burnished rose metallic 2 258.10677784349235 258.10677784349235 2 66619.10876874991 0.811328754177887 2801.7074999999995
+Manufacturer#1 almond antique burnished rose metallic 2 258.10677784349235 258.10677784349235 6 66619.10876874991 0.811328754177887 2801.7074999999995
+Manufacturer#1 almond antique burnished rose metallic 2 258.10677784349235 258.10677784349235 34 66619.10876874991 0.811328754177887 2801.7074999999995
+Manufacturer#1 almond antique burnished rose metallic 2 273.70217881648074 273.70217881648074 2 74912.8826888888 1.0 4128.782222222221
+Manufacturer#1 almond antique burnished rose metallic 2 273.70217881648074 273.70217881648074 34 74912.8826888888 1.0 4128.782222222221
+Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 2 53315.51002399992 0.695639377397664 2210.7864
+Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 6 53315.51002399992 0.695639377397664 2210.7864
+Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 28 53315.51002399992 0.695639377397664 2210.7864
+Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 34 53315.51002399992 0.695639377397664 2210.7864
+Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 2 41099.896184 0.630785977101214 2009.9536000000007
+Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 6 41099.896184 0.630785977101214 2009.9536000000007
+Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 28 41099.896184 0.630785977101214 2009.9536000000007
+Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 34 41099.896184 0.630785977101214 2009.9536000000007
+Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 42 41099.896184 0.630785977101214 2009.9536000000007
+Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 6 14788.129118750014 0.2036684720435979 331.1337500000004
+Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 28 14788.129118750014 0.2036684720435979 331.1337500000004
+Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 34 14788.129118750014 0.2036684720435979 331.1337500000004
+Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 42 14788.129118750014 0.2036684720435979 331.1337500000004
+Manufacturer#1 almond aquamarine pink moccasin thistle 42 96.5751586416853 96.5751586416853 6 9326.761266666683 -1.4442181184933883E-4 -0.20666666666708502
+Manufacturer#1 almond aquamarine pink moccasin thistle 42 96.5751586416853 96.5751586416853 28 9326.761266666683 -1.4442181184933883E-4 -0.20666666666708502
+Manufacturer#1 almond aquamarine pink moccasin thistle 42 96.5751586416853 96.5751586416853 42 9326.761266666683 -1.4442181184933883E-4 -0.20666666666708502
+Manufacturer#2 almond antique violet chocolate turquoise 14 142.2363169751898 142.2363169751898 2 20231.169866666663 -0.49369526554523185 -1113.7466666666658
+Manufacturer#2 almond antique violet chocolate turquoise 14 142.2363169751898 142.2363169751898 14 20231.169866666663 -0.49369526554523185 -1113.7466666666658
+Manufacturer#2 almond antique violet chocolate turquoise 14 142.2363169751898 142.2363169751898 40 20231.169866666663 -0.49369526554523185 -1113.7466666666658
+Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 2 18978.662075 -0.5205630897335946 -1004.4812499999995
+Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 14 18978.662075 -0.5205630897335946 -1004.4812499999995
+Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 25 18978.662075 -0.5205630897335946 -1004.4812499999995
+Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 40 18978.662075 -0.5205630897335946 -1004.4812499999995
+Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 2 16910.329504000005 -0.46908967495720255 -766.1791999999995
+Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 14 16910.329504000005 -0.46908967495720255 -766.1791999999995
+Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 18 16910.329504000005 -0.46908967495720255 -766.1791999999995
+Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 25 16910.329504000005 -0.46908967495720255 -766.1791999999995
+Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 40 16910.329504000005 -0.46908967495720255 -766.1791999999995
+Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 2 18374.07627499999 -0.6091405874714462 -1128.1787499999987
+Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 18 18374.07627499999 -0.6091405874714462 -1128.1787499999987
+Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 25 18374.07627499999 -0.6091405874714462 -1128.1787499999987
+Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 40 18374.07627499999 -0.6091405874714462 -1128.1787499999987
+Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 156.44019460768044 156.44019460768044 2 24473.534488888927 -0.9571686373491608 -1441.4466666666676
+Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 156.44019460768044 156.44019460768044 18 24473.534488888927 -0.9571686373491608 -1441.4466666666676
+Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 156.44019460768044 156.44019460768044 25 24473.534488888927 -0.9571686373491608 -1441.4466666666676
+Manufacturer#3 almond antique chartreuse khaki white 17 196.7742266885805 196.7742266885805 14 38720.09628888887 0.5557168646224995 224.6944444444446
+Manufacturer#3 almond antique chartreuse khaki white 17 196.7742266885805 196.7742266885805 17 38720.09628888887 0.5557168646224995 224.6944444444446
+Manufacturer#3 almond antique chartreuse khaki white 17 196.7742266885805 196.7742266885805 19 38720.09628888887 0.5557168646224995 224.6944444444446
+Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 1 75702.81305 -0.6720833036576083 -1296.9000000000003
+Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 14 75702.81305 -0.6720833036576083 -1296.9000000000003
+Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 17 75702.81305 -0.6720833036576083 -1296.9000000000003
+Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 19 75702.81305 -0.6720833036576083 -1296.9000000000003
+Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 1 67722.117896 -0.5703526513979519 -2129.0664
+Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 14 67722.117896 -0.5703526513979519 -2129.0664
+Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 17 67722.117896 -0.5703526513979519 -2129.0664
+Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 19 67722.117896 -0.5703526513979519 -2129.0664
+Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 45 67722.117896 -0.5703526513979519 -2129.0664
+Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 1 76128.53331875012 -0.577476899644802 -2547.7868749999993
+Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 14 76128.53331875012 -0.577476899644802 -2547.7868749999993
+Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 19 76128.53331875012 -0.577476899644802 -2547.7868749999993
+Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 45 76128.53331875012 -0.577476899644802 -2547.7868749999993
+Manufacturer#3 almond antique olive coral navajo 45 260.5815918713796 260.5815918713796 1 67902.76602222225 -0.8710736366736884 -4099.731111111111
+Manufacturer#3 almond antique olive coral navajo 45 260.5815918713796 260.5815918713796 19 67902.76602222225 -0.8710736366736884 -4099.731111111111
+Manufacturer#3 almond antique olive coral navajo 45 260.5815918713796 260.5815918713796 45 67902.76602222225 -0.8710736366736884 -4099.731111111111
+Manufacturer#4 almond antique gainsboro frosted violet 10 170.13011889596618 170.13011889596618 10 28944.25735555559 -0.6656975320098423 -1347.4777777777779
+Manufacturer#4 almond antique gainsboro frosted violet 10 170.13011889596618 170.13011889596618 27 28944.25735555559 -0.6656975320098423 -1347.4777777777779
+Manufacturer#4 almond antique gainsboro frosted violet 10 170.13011889596618 170.13011889596618 39 28944.25735555559 -0.6656975320098423 -1347.4777777777779
+Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 7 58693.95151875002 -0.8051852719193339 -2537.328125
+Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 10 58693.95151875002 -0.8051852719193339 -2537.328125
+Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 27 58693.95151875002 -0.8051852719193339 -2537.328125
+Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 39 58693.95151875002 -0.8051852719193339 -2537.328125
+Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 7 54802.817784000035 -0.6046935574240581 -1719.8079999999995
+Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 10 54802.817784000035 -0.6046935574240581 -1719.8079999999995
+Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 12 54802.817784000035 -0.6046935574240581 -1719.8079999999995
+Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 27 54802.817784000035 -0.6046935574240581 -1719.8079999999995
+Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 39 54802.817784000035 -0.6046935574240581 -1719.8079999999995
+Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 7 61174.24181875003 -0.5508665654707869 -1719.0368749999975
+Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 12 61174.24181875003 -0.5508665654707869 -1719.0368749999975
+Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 27 61174.24181875003 -0.5508665654707869 -1719.0368749999975
+Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 39 61174.24181875003 -0.5508665654707869 -1719.0368749999975
+Manufacturer#4 almond azure aquamarine papaya violet 12 283.3344330566893 283.3344330566893 7 80278.40095555557 -0.7755740084632333 -1867.4888888888881
+Manufacturer#4 almond azure aquamarine papaya violet 12 283.3344330566893 283.3344330566893 12 80278.40095555557 -0.7755740084632333 -1867.4888888888881
+Manufacturer#4 almond azure aquamarine papaya violet 12 283.3344330566893 283.3344330566893 27 80278.40095555557 -0.7755740084632333 -1867.4888888888881
+Manufacturer#5 almond antique blue firebrick mint 31 83.69879024746363 83.69879024746363 2 7005.487488888913 0.39004303087285047 418.9233333333353
+Manufacturer#5 almond antique blue firebrick mint 31 83.69879024746363 83.69879024746363 6 7005.487488888913 0.39004303087285047 418.9233333333353
+Manufacturer#5 almond antique blue firebrick mint 31 83.69879024746363 83.69879024746363 31 7005.487488888913 0.39004303087285047 418.9233333333353
+Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 2 100286.53662500004 -0.713612911776183 -4090.853749999999
+Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 6 100286.53662500004 -0.713612911776183 -4090.853749999999
+Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 31 100286.53662500004 -0.713612911776183 -4090.853749999999
+Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 46 100286.53662500004 -0.713612911776183 -4090.853749999999
+Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 2 81456.04997600002 -0.712858514567818 -3297.2011999999986
+Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 6 81456.04997600002 -0.712858514567818 -3297.2011999999986
+Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 23 81456.04997600002 -0.712858514567818 -3297.2011999999986
+Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 31 81456.04997600002 -0.712858514567818 -3297.2011999999986
+Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 46 81456.04997600002 -0.712858514567818 -3297.2011999999986
+Manufacturer#5 almond aquamarine dodger light gainsboro 46 285.43749038756283 285.43749038756283 2 81474.56091875004 -0.984128787153391 -4871.028125000002
+Manufacturer#5 almond aquamarine dodger light gainsboro 46 285.43749038756283 285.43749038756283 6 81474.56091875004 -0.984128787153391 -4871.028125000002
+Manufacturer#5 almond aquamarine dodger light gainsboro 46 285.43749038756283 285.43749038756283 23 81474.56091875004 -0.984128787153391 -4871.028125000002
+Manufacturer#5 almond aquamarine dodger light gainsboro 46 285.43749038756283 285.43749038756283 46 81474.56091875004 -0.984128787153391 -4871.028125000002
+Manufacturer#5 almond azure blanched chiffon midnight 23 315.9225931564038 315.9225931564038 2 99807.08486666664 -0.9978877469246936 -5664.856666666666
+Manufacturer#5 almond azure blanched chiffon midnight 23 315.9225931564038 315.9225931564038 23 99807.08486666664 -0.9978877469246936 -5664.856666666666
+Manufacturer#5 almond azure blanched chiffon midnight 23 315.9225931564038 315.9225931564038 46 99807.08486666664 -0.9978877469246936 -5664.856666666666
diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-da0e0cca69e42118a96b8609b8fa5838 b/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-da0e0cca69e42118a96b8609b8fa5838
deleted file mode 100644
index 1f7e8a5d67036..0000000000000
--- a/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-da0e0cca69e42118a96b8609b8fa5838
+++ /dev/null
@@ -1,26 +0,0 @@
-Manufacturer#1 almond antique burnished rose metallic 2 273.70217881648074 273.70217881648074 [34,2] 74912.8826888888 1.0 4128.782222222221
-Manufacturer#1 almond antique burnished rose metallic 2 258.10677784349235 258.10677784349235 [34,2,6] 66619.10876874991 0.811328754177887 2801.7074999999995
-Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 [34,2,6,28] 53315.51002399992 0.695639377397664 2210.7864
-Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 [34,2,6,42,28] 41099.896184 0.630785977101214 2009.9536000000007
-Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 [34,6,42,28] 14788.129118750014 0.2036684720435979 331.1337500000004
-Manufacturer#1 almond aquamarine pink moccasin thistle 42 96.5751586416853 96.5751586416853 [6,42,28] 9326.761266666683 -1.4442181184933883E-4 -0.20666666666708502
-Manufacturer#2 almond antique violet chocolate turquoise 14 142.2363169751898 142.2363169751898 [2,40,14] 20231.169866666663 -0.49369526554523185 -1113.7466666666658
-Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 [2,25,40,14] 18978.662075 -0.5205630897335946 -1004.4812499999995
-Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 [2,18,25,40,14] 16910.329504000005 -0.46908967495720255 -766.1791999999995
-Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 [2,18,25,40] 18374.07627499999 -0.6091405874714462 -1128.1787499999987
-Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 156.44019460768044 156.44019460768044 [2,18,25] 24473.534488888927 -0.9571686373491608 -1441.4466666666676
-Manufacturer#3 almond antique chartreuse khaki white 17 196.7742266885805 196.7742266885805 [17,19,14] 38720.09628888887 0.5557168646224995 224.6944444444446
-Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 [17,1,19,14] 75702.81305 -0.6720833036576083 -1296.9000000000003
-Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 [17,1,19,14,45] 67722.117896 -0.5703526513979519 -2129.0664
-Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 [1,19,14,45] 76128.53331875012 -0.577476899644802 -2547.7868749999993
-Manufacturer#3 almond antique olive coral navajo 45 260.5815918713796 260.5815918713796 [1,19,45] 67902.76602222225 -0.8710736366736884 -4099.731111111111
-Manufacturer#4 almond antique gainsboro frosted violet 10 170.13011889596618 170.13011889596618 [39,27,10] 28944.25735555559 -0.6656975320098423 -1347.4777777777779
-Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 [39,7,27,10] 58693.95151875002 -0.8051852719193339 -2537.328125
-Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 [39,7,27,10,12] 54802.817784000035 -0.6046935574240581 -1719.8079999999995
-Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 [39,7,27,12] 61174.24181875003 -0.5508665654707869 -1719.0368749999975
-Manufacturer#4 almond azure aquamarine papaya violet 12 283.3344330566893 283.3344330566893 [7,27,12] 80278.40095555557 -0.7755740084632333 -1867.4888888888881
-Manufacturer#5 almond antique blue firebrick mint 31 83.69879024746363 83.69879024746363 [2,6,31] 7005.487488888913 0.39004303087285047 418.9233333333353
-Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 [2,6,46,31] 100286.53662500004 -0.713612911776183 -4090.853749999999
-Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 [2,23,6,46,31] 81456.04997600002 -0.712858514567818 -3297.2011999999986
-Manufacturer#5 almond aquamarine dodger light gainsboro 46 285.43749038756283 285.43749038756283 [2,23,6,46] 81474.56091875004 -0.984128787153391 -4871.028125000002
-Manufacturer#5 almond azure blanched chiffon midnight 23 315.9225931564038 315.9225931564038 [2,23,46] 99807.08486666664 -0.9978877469246936 -5664.856666666666
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
index 82e08caf46457..a0cdd0db42d65 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
@@ -43,8 +43,14 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll {
orcTableDir.mkdir()
import org.apache.spark.sql.hive.test.TestHive.implicits._
+ // Originally we were using a 10-row RDD for testing. However, when default parallelism is
+ // greater than 10 (e.g., running on a node with 32 cores), this RDD contains empty partitions,
+ // which result in empty ORC files. Unfortunately, ORC doesn't handle empty files properly,
+ // which causes build failures on Jenkins (which happens to have 32 cores). Please refer to
+ // SPARK-8501 for more details. To work around this issue before fixing SPARK-8501, we simply
+ // increase the number of rows in this RDD to avoid empty partitions.
sparkContext
- .makeRDD(1 to 10)
+ .makeRDD(1 to 100)
.map(i => OrcData(i, s"part-$i"))
.toDF()
.registerTempTable(s"orc_temp_table")
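
The empty-partition situation described in the comment above can be illustrated with a quick check; `sparkContext` here is assumed to be the test context already in scope:

```scala
// With default parallelism above 10 (e.g. 32 slices on a 32-core node),
// makeRDD(1 to 10) leaves some partitions empty, and each empty partition
// would previously have produced an empty ORC file.
val small = sparkContext.makeRDD(1 to 10)
val emptyParts = small.mapPartitions(it => Iterator(it.isEmpty)).filter(identity).count()
// Using 1 to 100, as the test does now, keeps every partition non-empty.
```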
@@ -70,35 +76,35 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll {
}
test("create temporary orc table") {
- checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_source"), Row(10))
+ checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_source"), Row(100))
checkAnswer(
sql("SELECT * FROM normal_orc_source"),
- (1 to 10).map(i => Row(i, s"part-$i")))
+ (1 to 100).map(i => Row(i, s"part-$i")))
checkAnswer(
sql("SELECT * FROM normal_orc_source where intField > 5"),
- (6 to 10).map(i => Row(i, s"part-$i")))
+ (6 to 100).map(i => Row(i, s"part-$i")))
checkAnswer(
sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"),
- (1 to 10).map(i => Row(1, s"part-$i")))
+ (1 to 100).map(i => Row(1, s"part-$i")))
}
test("create temporary orc table as") {
- checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_as_source"), Row(10))
+ checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_as_source"), Row(100))
checkAnswer(
sql("SELECT * FROM normal_orc_source"),
- (1 to 10).map(i => Row(i, s"part-$i")))
+ (1 to 100).map(i => Row(i, s"part-$i")))
checkAnswer(
sql("SELECT * FROM normal_orc_source WHERE intField > 5"),
- (6 to 10).map(i => Row(i, s"part-$i")))
+ (6 to 100).map(i => Row(i, s"part-$i")))
checkAnswer(
sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"),
- (1 to 10).map(i => Row(1, s"part-$i")))
+ (1 to 100).map(i => Row(1, s"part-$i")))
}
test("appending insert") {
@@ -106,7 +112,7 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(
sql("SELECT * FROM normal_orc_source"),
- (1 to 5).map(i => Row(i, s"part-$i")) ++ (6 to 10).flatMap { i =>
+ (1 to 5).map(i => Row(i, s"part-$i")) ++ (6 to 100).flatMap { i =>
Seq.fill(2)(Row(i, s"part-$i"))
})
}
@@ -119,7 +125,7 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(
sql("SELECT * FROM normal_orc_as_source"),
- (6 to 10).map(i => Row(i, s"part-$i")))
+ (6 to 100).map(i => Row(i, s"part-$i")))
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 0f959b3d0b86d..5d7cd16c129cd 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -53,9 +53,10 @@ class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullW
numberFormat.setGroupingUsed(false)
override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
+ val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID")
val split = context.getTaskAttemptID.getTaskID.getId
val name = FileOutputFormat.getOutputName(context)
- new Path(outputFile, s"$name-${numberFormat.format(split)}-${UUID.randomUUID()}")
+ new Path(outputFile, s"$name-${numberFormat.format(split)}-$uniqueWriteJobId")
}
}
@@ -156,6 +157,7 @@ class CommitFailureTestRelation(
context: TaskAttemptContext): OutputWriter = {
new SimpleTextOutputWriter(path, context) {
override def close(): Unit = {
+ super.close()
sys.error("Intentional task commitment failure for testing purpose.")
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
index 76469d7a3d6a5..e0d8277a8ed3f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
@@ -35,7 +35,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
import sqlContext.sql
import sqlContext.implicits._
- val dataSourceName = classOf[SimpleTextSource].getCanonicalName
+ val dataSourceName: String
val dataSchema =
StructType(
@@ -470,6 +470,33 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
checkAnswer(sqlContext.table("t"), df.select('b, 'c, 'a).collect())
}
}
+
+ // NOTE: This test case is not fully deterministic. On nodes with relatively few cores
+ // (4 or even 1), the data loss issue is hard to reproduce. On nodes with 8 or more cores,
+ // however, it can be reproduced consistently. Fortunately our Jenkins builder meets this
+ // requirement. We probably want to move this test case to spark-integration-tests or spark-perf
+ // later.
+ test("SPARK-8406: Avoids name collision while writing Parquet files") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ sqlContext
+ .range(10000)
+ .repartition(250)
+ .write
+ .mode(SaveMode.Overwrite)
+ .format(dataSourceName)
+ .save(path)
+
+ assertResult(10000) {
+ sqlContext
+ .read
+ .format(dataSourceName)
+ .option("dataSchema", StructType(StructField("id", LongType) :: Nil).json)
+ .load(path)
+ .count()
+ }
+ }
+ }
}
class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
@@ -502,15 +529,17 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
}
class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils {
- import TestHive.implicits._
-
override val sqlContext = TestHive
+ // When committing a task, `CommitFailureTestSource` throws an exception for testing purposes.
val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName
test("SPARK-7684: commitTask() failure should fallback to abortTask()") {
withTempPath { file =>
- val df = (1 to 3).map(i => i -> s"val_$i").toDF("a", "b")
+ // Here we coalesce the DataFrame to a single partition so that only one task is issued. This
+ // prevents the race condition that occurs when FileOutputCommitter tries to remove the
+ // `_temporary` directory while committing/aborting the job. See SPARK-8513 for more details.
+ val df = sqlContext.range(0, 10).coalesce(1)
intercept[SparkException] {
df.write.format(dataSourceName).save(file.getCanonicalPath)
}