diff --git a/.gitignore b/.gitignore
index d162fa9cca994..d54d21b802be8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -63,6 +63,8 @@ ec2/lib/
 rat-results.txt
 scalastyle.txt
 scalastyle-output.xml
+R-unit-tests.log
+R/unit-tests.out
 # For Hive
 metastore_db/
diff --git a/.rat-excludes b/.rat-excludes
index 8c61e67a0c7d1..8aca5a7f7a967 100644
--- a/.rat-excludes
+++ b/.rat-excludes
@@ -67,3 +67,5 @@ logs
 .*scalastyle-output.xml
 .*dependency-reduced-pom.xml
 known_translations
+DESCRIPTION
+NAMESPACE
diff --git a/R/.gitignore b/R/.gitignore
new file mode 100644
index 0000000000000..9a5889ba28b2a
--- /dev/null
+++ b/R/.gitignore
@@ -0,0 +1,6 @@
+*.o
+*.so
+*.Rd
+lib
+pkg/man
+pkg/html
diff --git a/R/DOCUMENTATION.md b/R/DOCUMENTATION.md
new file mode 100644
index 0000000000000..931d01549b265
--- /dev/null
+++ b/R/DOCUMENTATION.md
@@ -0,0 +1,12 @@
+# SparkR Documentation
+
+SparkR documentation is generated from in-source comments annotated with
+`roxygen2`. After making changes to the documentation, you can regenerate the man pages
+by running the following from an R console in the SparkR home directory:
+
+    library(devtools)
+    devtools::document(pkg="./pkg", roclets=c("rd"))
+
+You can verify your changes by running
+
+    R CMD check pkg/
diff --git a/R/README.md b/R/README.md
new file mode 100644
index 0000000000000..a6970e39b55f3
--- /dev/null
+++ b/R/README.md
@@ -0,0 +1,67 @@
+# R on Spark
+
+SparkR is an R package that provides a lightweight frontend to use Spark from R.
+
+### SparkR development
+
+#### Build Spark
+
+Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-Psparkr` profile to build the R package. For example, to use the default Hadoop versions you can run
+```
+  build/mvn -DskipTests -Psparkr package
+```
+
+#### Running sparkR
+
+You can start using SparkR by launching the SparkR shell with
+
+    ./bin/sparkR
+
+The `sparkR` script automatically creates a SparkContext, running Spark in
+local mode by default. To specify the Spark master of a cluster for the automatically created
+SparkContext, you can run
+
+    ./bin/sparkR --master "local[2]"
+
+To set other options, such as driver memory or executor memory, you can pass [spark-submit](http://spark.apache.org/docs/latest/submitting-applications.html) arguments to `./bin/sparkR`.
+
+#### Using SparkR from RStudio
+
+If you wish to use SparkR from RStudio or other R frontends, you will need to set some environment variables that point SparkR to your Spark installation. For example
+```
+# Set this to where Spark is installed
+Sys.setenv(SPARK_HOME="/Users/shivaram/spark")
+# This line loads SparkR from the installed directory
+.libPaths(c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib"), .libPaths()))
+library(SparkR)
+sc <- sparkR.init(master="local")
+```
+
+#### Making changes to SparkR
+
+The [instructions](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) for making contributions to Spark also apply to SparkR.
+If you only make R file changes (i.e., no Scala changes), you can simply re-install the R package using `R/install-dev.sh` and test your changes.
+Once you have made your changes, please include unit tests for them and run existing unit tests using the `run-tests.sh` script as described below.
+
+#### Generating documentation
+
+The SparkR documentation (Rd files and HTML files) is not part of the source repository. To generate it, you can run the script `R/create-docs.sh`.
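+Internally the script first regenerates the Rd files and then knits them into HTML. The
+Rd-generation step on its own can also be run from an R console in the `R/` directory,
+roughly as follows:
+```
+library(devtools)
+devtools::document(pkg="./pkg", roclets=c("rd"))
+```
+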
This script uses `devtools` and `knitr` to generate the docs and these packages need to be installed on the machine before using the script. + +### Examples, Unit tests + +SparkR comes with several sample programs in the `examples/src/main/r` directory. +To run one of them, use `./bin/sparkR `. For example: + + ./bin/sparkR examples/src/main/r/pi.R local[2] + +You can also run the unit-tests for SparkR by running (you need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first): + + R -e 'install.packages("testthat", repos="http://cran.us.r-project.org")' + ./R/run-tests.sh + +### Running on YARN +The `./bin/spark-submit` and `./bin/sparkR` can also be used to submit jobs to YARN clusters. You will need to set YARN conf dir before doing so. For example on CDH you can run +``` +export YARN_CONF_DIR=/etc/hadoop/conf +./bin/spark-submit --master yarn examples/src/main/r/pi.R 4 +``` diff --git a/R/WINDOWS.md b/R/WINDOWS.md new file mode 100644 index 0000000000000..3f889c0ca3d1e --- /dev/null +++ b/R/WINDOWS.md @@ -0,0 +1,13 @@ +## Building SparkR on Windows + +To build SparkR on Windows, the following steps are required + +1. Install R (>= 3.1) and [Rtools](http://cran.r-project.org/bin/windows/Rtools/). Make sure to +include Rtools and R in `PATH`. +2. Install +[JDK7](http://www.oracle.com/technetwork/java/javase/downloads/jdk7-downloads-1880260.html) and set +`JAVA_HOME` in the system environment variables. +3. Download and install [Maven](http://maven.apache.org/download.html). Also include the `bin` +directory in Maven in `PATH`. +4. Set `MAVEN_OPTS` as described in [Building Spark](http://spark.apache.org/docs/latest/building-spark.html). +5. Open a command shell (`cmd`) in the Spark directory and run `mvn -DskipTests -Psparkr package` diff --git a/R/create-docs.sh b/R/create-docs.sh new file mode 100755 index 0000000000000..4194172a2e115 --- /dev/null +++ b/R/create-docs.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Script to create API docs for SparkR +# This requires `devtools` and `knitr` to be installed on the machine. 
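+# If these packages are missing, one way to install them (the CRAN mirror below is
+# only an example) is:
+#   Rscript -e 'install.packages(c("devtools", "knitr"), repos="http://cran.us.r-project.org")'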
+ +# After running this script the html docs can be found in +# $SPARK_HOME/R/pkg/html + +# Figure out where the script is +export FWDIR="$(cd "`dirname "$0"`"; pwd)" +pushd $FWDIR + +# Generate Rd file +Rscript -e 'library(devtools); devtools::document(pkg="./pkg", roclets=c("rd"))' + +# Install the package +./install-dev.sh + +# Now create HTML files + +# knit_rd puts html in current working directory +mkdir -p pkg/html +pushd pkg/html + +Rscript -e 'library(SparkR, lib.loc="../../lib"); library(knitr); knit_rd("SparkR")' + +popd + +popd diff --git a/R/install-dev.bat b/R/install-dev.bat new file mode 100644 index 0000000000000..008a5c668bc45 --- /dev/null +++ b/R/install-dev.bat @@ -0,0 +1,27 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem Install development version of SparkR +rem + +set SPARK_HOME=%~dp0.. + +MKDIR %SPARK_HOME%\R\lib + +R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\ diff --git a/R/install-dev.sh b/R/install-dev.sh new file mode 100755 index 0000000000000..55ed6f4be1a4a --- /dev/null +++ b/R/install-dev.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This scripts packages the SparkR source files (R and C files) and +# creates a package that can be loaded in R. The package is by default installed to +# $FWDIR/lib and the package can be loaded by using the following command in R: +# +# library(SparkR, lib.loc="$FWDIR/lib") +# +# NOTE(shivaram): Right now we use $SPARK_HOME/R/lib to be the installation directory +# to load the SparkR package on the worker nodes. + + +FWDIR="$(cd `dirname $0`; pwd)" +LIB_DIR="$FWDIR/lib" + +mkdir -p $LIB_DIR + +# Install R +R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ diff --git a/R/log4j.properties b/R/log4j.properties new file mode 100644 index 0000000000000..701adb2a3da1d --- /dev/null +++ b/R/log4j.properties @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=R-unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.eclipse.jetty=WARN +org.eclipse.jetty.LEVEL=WARN diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION new file mode 100644 index 0000000000000..052f68c6c24e2 --- /dev/null +++ b/R/pkg/DESCRIPTION @@ -0,0 +1,35 @@ +Package: SparkR +Type: Package +Title: R frontend for Spark +Version: 1.4.0 +Date: 2013-09-09 +Author: The Apache Software Foundation +Maintainer: Shivaram Venkataraman +Imports: + methods +Depends: + R (>= 3.0), + methods, +Suggests: + testthat +Description: R frontend for Spark +License: Apache License (== 2.0) +Collate: + 'generics.R' + 'jobj.R' + 'RDD.R' + 'pairRDD.R' + 'SQLTypes.R' + 'column.R' + 'group.R' + 'DataFrame.R' + 'SQLContext.R' + 'backend.R' + 'broadcast.R' + 'client.R' + 'context.R' + 'deserialize.R' + 'serialize.R' + 'sparkR.R' + 'utils.R' + 'zzz.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE new file mode 100644 index 0000000000000..a354cdce74afa --- /dev/null +++ b/R/pkg/NAMESPACE @@ -0,0 +1,182 @@ +#exportPattern("^[[:alpha:]]+") +exportClasses("RDD") +exportClasses("Broadcast") +exportMethods( + "aggregateByKey", + "aggregateRDD", + "cache", + "checkpoint", + "coalesce", + "cogroup", + "collect", + "collectAsMap", + "collectPartition", + "combineByKey", + "count", + "countByKey", + "countByValue", + "distinct", + "Filter", + "filterRDD", + "first", + "flatMap", + "flatMapValues", + "fold", + "foldByKey", + "foreach", + "foreachPartition", + "fullOuterJoin", + "glom", + "groupByKey", + "join", + "keyBy", + "keys", + "length", + "lapply", + "lapplyPartition", + "lapplyPartitionsWithIndex", + "leftOuterJoin", + "lookup", + "map", + "mapPartitions", + "mapPartitionsWithIndex", + "mapValues", + "maximum", + "minimum", + "numPartitions", + "partitionBy", + "persist", + "pipeRDD", + "reduce", + "reduceByKey", + "reduceByKeyLocally", + "repartition", + "rightOuterJoin", + "sampleRDD", + "saveAsTextFile", + "saveAsObjectFile", + "sortBy", + "sortByKey", + "sumRDD", + "take", + "takeOrdered", + "takeSample", + "top", + "unionRDD", + "unpersist", + "value", + "values", + "zipRDD", + "zipWithIndex", + "zipWithUniqueId" + ) + +# S3 methods exported +export( + "textFile", + "objectFile", + "parallelize", + "hashCode", + "includePackage", + "broadcast", + "setBroadcastValue", + "setCheckpointDir" + ) +export("sparkR.init") +export("sparkR.stop") +export("print.jobj") +useDynLib(SparkR, stringHashCode) +importFrom(methods, setGeneric, 
setMethod, setOldClass) + +# SparkRSQL + +exportClasses("DataFrame") + +exportMethods("columns", + "distinct", + "dtypes", + "explain", + "filter", + "groupBy", + "head", + "insertInto", + "intersect", + "isLocal", + "limit", + "orderBy", + "names", + "printSchema", + "registerTempTable", + "repartition", + "sampleDF", + "saveAsParquetFile", + "saveAsTable", + "saveDF", + "schema", + "select", + "selectExpr", + "show", + "showDF", + "sortDF", + "subtract", + "toJSON", + "toRDD", + "unionAll", + "where", + "withColumn", + "withColumnRenamed") + +exportClasses("Column") + +exportMethods("abs", + "alias", + "approxCountDistinct", + "asc", + "avg", + "cast", + "contains", + "countDistinct", + "desc", + "endsWith", + "getField", + "getItem", + "isNotNull", + "isNull", + "last", + "like", + "lower", + "max", + "mean", + "min", + "rlike", + "sqrt", + "startsWith", + "substr", + "sum", + "sumDistinct", + "upper") + +exportClasses("GroupedData") +exportMethods("agg") + +export("sparkRSQL.init", + "sparkRHive.init") + +export("cacheTable", + "clearCache", + "createDataFrame", + "createExternalTable", + "dropTempTable", + "jsonFile", + "jsonRDD", + "loadDF", + "parquetFile", + "sql", + "table", + "tableNames", + "tables", + "toDF", + "uncacheTable") + +export("print.structType", + "print.structField") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R new file mode 100644 index 0000000000000..044fdb4d01223 --- /dev/null +++ b/R/pkg/R/DataFrame.R @@ -0,0 +1,1270 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# DataFrame.R - DataFrame class and methods implemented in S4 OO classes + +#' @include generics.R jobj.R SQLTypes.R RDD.R pairRDD.R column.R group.R +NULL + +setOldClass("jobj") + +#' @title S4 class that represents a DataFrame +#' @description DataFrames can be created using functions like +#' \code{jsonFile}, \code{table} etc. 
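+#' For example, given a SQLContext \code{sqlCtx} (created with \code{sparkRSQL.init}),
+#' \code{df <- jsonFile(sqlCtx, "path/to/file.json")} produces a DataFrame backed by
+#' a Scala DataFrame on the JVM; the path here is purely illustrative.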
+#' @rdname DataFrame +#' @seealso jsonFile, table +#' +#' @param env An R environment that stores bookkeeping states of the DataFrame +#' @param sdf A Java object reference to the backing Scala DataFrame +#' @export +setClass("DataFrame", + slots = list(env = "environment", + sdf = "jobj")) + +setMethod("initialize", "DataFrame", function(.Object, sdf, isCached) { + .Object@env <- new.env() + .Object@env$isCached <- isCached + + .Object@sdf <- sdf + .Object +}) + +#' @rdname DataFrame +#' @export +dataFrame <- function(sdf, isCached = FALSE) { + new("DataFrame", sdf, isCached) +} + +############################ DataFrame Methods ############################################## + +#' Print Schema of a DataFrame +#' +#' Prints out the schema in tree format +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname printSchema +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' printSchema(df) +#'} +setMethod("printSchema", + signature(x = "DataFrame"), + function(x) { + schemaString <- callJMethod(schema(x)$jobj, "treeString") + cat(schemaString) + }) + +#' Get schema object +#' +#' Returns the schema of this DataFrame as a structType object. +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname schema +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' dfSchema <- schema(df) +#'} +setMethod("schema", + signature(x = "DataFrame"), + function(x) { + structType(callJMethod(x@sdf, "schema")) + }) + +#' Explain +#' +#' Print the logical and physical Catalyst plans to the console for debugging. +#' +#' @param x A SparkSQL DataFrame +#' @param extended Logical. If extended is False, explain() only prints the physical plan. +#' @rdname explain +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' explain(df, TRUE) +#'} +setMethod("explain", + signature(x = "DataFrame"), + function(x, extended = FALSE) { + queryExec <- callJMethod(x@sdf, "queryExecution") + if (extended) { + cat(callJMethod(queryExec, "toString")) + } else { + execPlan <- callJMethod(queryExec, "executedPlan") + cat(callJMethod(execPlan, "toString")) + } + }) + +#' isLocal +#' +#' Returns True if the `collect` and `take` methods can be run locally +#' (without any Spark executors). +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname isLocal +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' isLocal(df) +#'} +setMethod("isLocal", + signature(x = "DataFrame"), + function(x) { + callJMethod(x@sdf, "isLocal") + }) + +#' ShowDF +#' +#' Print the first numRows rows of a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' @param numRows The number of rows to print. Defaults to 20. 
+#' +#' @rdname showDF +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' showDF(df) +#'} +setMethod("showDF", + signature(x = "DataFrame"), + function(x, numRows = 20) { + cat(callJMethod(x@sdf, "showString", numToInt(numRows)), "\n") + }) + +#' show +#' +#' Print the DataFrame column names and types +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname show +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' show(df) +#'} +setMethod("show", "DataFrame", + function(object) { + cols <- lapply(dtypes(object), function(l) { + paste(l, collapse = ":") + }) + s <- paste(cols, collapse = ", ") + cat(paste("DataFrame[", s, "]\n", sep = "")) + }) + +#' DataTypes +#' +#' Return all column names and their data types as a list +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname dtypes +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' dtypes(df) +#'} +setMethod("dtypes", + signature(x = "DataFrame"), + function(x) { + lapply(schema(x)$fields(), function(f) { + c(f$name(), f$dataType.simpleString()) + }) + }) + +#' Column names +#' +#' Return all column names as a list +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname columns +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' columns(df) +#'} +setMethod("columns", + signature(x = "DataFrame"), + function(x) { + sapply(schema(x)$fields(), function(f) { + f$name() + }) + }) + +#' @rdname columns +#' @export +setMethod("names", + signature(x = "DataFrame"), + function(x) { + columns(x) + }) + +#' Register Temporary Table +#' +#' Registers a DataFrame as a Temporary Table in the SQLContext +#' +#' @param x A SparkSQL DataFrame +#' @param tableName A character vector containing the name of the table +#' +#' @rdname registerTempTable +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "json_df") +#' new_df <- sql(sqlCtx, "SELECT * FROM json_df") +#'} +setMethod("registerTempTable", + signature(x = "DataFrame", tableName = "character"), + function(x, tableName) { + callJMethod(x@sdf, "registerTempTable", tableName) + }) + +#' insertInto +#' +#' Insert the contents of a DataFrame into a table registered in the current SQL Context. +#' +#' @param x A SparkSQL DataFrame +#' @param tableName A character vector containing the name of the table +#' @param overwrite A logical argument indicating whether or not to overwrite +#' the existing rows in the table. +#' +#' @rdname insertInto +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- loadDF(sqlCtx, path, "parquet") +#' df2 <- loadDF(sqlCtx, path2, "parquet") +#' registerTempTable(df, "table1") +#' insertInto(df2, "table1", overwrite = TRUE) +#'} +setMethod("insertInto", + signature(x = "DataFrame", tableName = "character"), + function(x, tableName, overwrite = FALSE) { + callJMethod(x@sdf, "insertInto", tableName, overwrite) + }) + +#' Cache +#' +#' Persist with the default storage level (MEMORY_ONLY). 
+#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname cache-methods +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' cache(df) +#'} +setMethod("cache", + signature(x = "DataFrame"), + function(x) { + cached <- callJMethod(x@sdf, "cache") + x@env$isCached <- TRUE + x + }) + +#' Persist +#' +#' Persist this DataFrame with the specified storage level. For details of the +#' supported storage levels, refer to +#' http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. +#' +#' @param x The DataFrame to persist +#' @rdname persist +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' persist(df, "MEMORY_AND_DISK") +#'} +setMethod("persist", + signature(x = "DataFrame", newLevel = "character"), + function(x, newLevel) { + callJMethod(x@sdf, "persist", getStorageLevel(newLevel)) + x@env$isCached <- TRUE + x + }) + +#' Unpersist +#' +#' Mark this DataFrame as non-persistent, and remove all blocks for it from memory and +#' disk. +#' +#' @param x The DataFrame to unpersist +#' @param blocking Whether to block until all blocks are deleted +#' @rdname unpersist-methods +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' persist(df, "MEMORY_AND_DISK") +#' unpersist(df) +#'} +setMethod("unpersist", + signature(x = "DataFrame"), + function(x, blocking = TRUE) { + callJMethod(x@sdf, "unpersist", blocking) + x@env$isCached <- FALSE + x + }) + +#' Repartition +#' +#' Return a new DataFrame that has exactly numPartitions partitions. +#' +#' @param x A SparkSQL DataFrame +#' @param numPartitions The number of partitions to use. +#' @rdname repartition +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newDF <- repartition(df, 2L) +#'} +setMethod("repartition", + signature(x = "DataFrame", numPartitions = "numeric"), + function(x, numPartitions) { + sdf <- callJMethod(x@sdf, "repartition", numToInt(numPartitions)) + dataFrame(sdf) + }) + +#' toJSON +#' +#' Convert the rows of a DataFrame into JSON objects and return an RDD where +#' each element contains a JSON string. +#' +#' @param x A SparkSQL DataFrame +#' @return A StringRRDD of JSON objects +#' @rdname tojson +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newRDD <- toJSON(df) +#'} +setMethod("toJSON", + signature(x = "DataFrame"), + function(x) { + rdd <- callJMethod(x@sdf, "toJSON") + jrdd <- callJMethod(rdd, "toJavaRDD") + RDD(jrdd, serializedMode = "string") + }) + +#' saveAsParquetFile +#' +#' Save the contents of a DataFrame as a Parquet file, preserving the schema. Files written out +#' with this method can be read back in as a DataFrame using parquetFile(). 
+#' +#' @param x A SparkSQL DataFrame +#' @param path The directory where the file is saved +#' @rdname saveAsParquetFile +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' saveAsParquetFile(df, "/tmp/sparkr-tmp/") +#'} +setMethod("saveAsParquetFile", + signature(x = "DataFrame", path = "character"), + function(x, path) { + invisible(callJMethod(x@sdf, "saveAsParquetFile", path)) + }) + +#' Distinct +#' +#' Return a new DataFrame containing the distinct rows in this DataFrame. +#' +#' @param x A SparkSQL DataFrame +#' @rdname distinct +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' distinctDF <- distinct(df) +#'} +setMethod("distinct", + signature(x = "DataFrame"), + function(x) { + sdf <- callJMethod(x@sdf, "distinct") + dataFrame(sdf) + }) + +#' SampleDF +#' +#' Return a sampled subset of this DataFrame using a random seed. +#' +#' @param x A SparkSQL DataFrame +#' @param withReplacement Sampling with replacement or not +#' @param fraction The (rough) sample target fraction +#' @rdname sampleDF +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' collect(sampleDF(df, FALSE, 0.5)) +#' collect(sampleDF(df, TRUE, 0.5)) +#'} +setMethod("sampleDF", + # TODO : Figure out how to send integer as java.lang.Long to JVM so + # we can send seed as an argument through callJMethod + signature(x = "DataFrame", withReplacement = "logical", + fraction = "numeric"), + function(x, withReplacement, fraction) { + if (fraction < 0.0) stop(cat("Negative fraction value:", fraction)) + sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction) + dataFrame(sdf) + }) + +#' Count +#' +#' Returns the number of rows in a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname count +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' count(df) +#' } +setMethod("count", + signature(x = "DataFrame"), + function(x) { + callJMethod(x@sdf, "count") + }) + +#' Collects all the elements of a Spark DataFrame and coerces them into an R data.frame. +#' +#' @param x A SparkSQL DataFrame +#' @param stringsAsFactors (Optional) A logical indicating whether or not string columns +#' should be converted to factors. FALSE by default. + +#' @rdname collect-methods +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' collected <- collect(df) +#' firstName <- collected[[1]]$name +#' } +setMethod("collect", + signature(x = "DataFrame"), + function(x, stringsAsFactors = FALSE) { + # listCols is a list of raw vectors, one per column + listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf) + cols <- lapply(listCols, function(col) { + objRaw <- rawConnection(col) + numRows <- readInt(objRaw) + col <- readCol(objRaw, numRows) + close(objRaw) + col + }) + names(cols) <- columns(x) + do.call(cbind.data.frame, list(cols, stringsAsFactors = stringsAsFactors)) + }) + +#' Limit +#' +#' Limit the resulting DataFrame to the number of rows specified. 
+#' +#' @param x A SparkSQL DataFrame +#' @param num The number of rows to return +#' @return A new DataFrame containing the number of rows specified. +#' +#' @rdname limit +#' @export +#' @examples +#' \dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' limitedDF <- limit(df, 10) +#' } +setMethod("limit", + signature(x = "DataFrame", num = "numeric"), + function(x, num) { + res <- callJMethod(x@sdf, "limit", as.integer(num)) + dataFrame(res) + }) + +# Take the first NUM rows of a DataFrame and return a the results as a data.frame + +#' @rdname take +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' take(df, 2) +#' } +setMethod("take", + signature(x = "DataFrame", num = "numeric"), + function(x, num) { + limited <- limit(x, num) + collect(limited) + }) + +#' Head +#' +#' Return the first NUM rows of a DataFrame as a data.frame. If NUM is NULL, +#' then head() returns the first 6 rows in keeping with the current data.frame +#' convention in R. +#' +#' @param x A SparkSQL DataFrame +#' @param num The number of rows to return. Default is 6. +#' @return A data.frame +#' +#' @rdname head +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' head(df) +#' } +setMethod("head", + signature(x = "DataFrame"), + function(x, num = 6L) { + # Default num is 6L in keeping with R's data.frame convention + take(x, num) + }) + +#' Return the first row of a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname first +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' first(df) +#' } +setMethod("first", + signature(x = "DataFrame"), + function(x) { + take(x, 1) + }) + +#' toRDD() +#' +#' Converts a Spark DataFrame to an RDD while preserving column names. +#' +#' @param x A Spark DataFrame +#' +#' @rdname DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' rdd <- toRDD(df) +#' } +setMethod("toRDD", + signature(x = "DataFrame"), + function(x) { + jrdd <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToRowRDD", x@sdf) + colNames <- callJMethod(x@sdf, "columns") + rdd <- RDD(jrdd, serializedMode = "row") + lapply(rdd, function(row) { + names(row) <- colNames + row + }) + }) + +#' GroupBy +#' +#' Groups the DataFrame using the specified columns, so we can run aggregation on them. +#' +#' @param x a DataFrame +#' @return a GroupedData +#' @seealso GroupedData +#' @rdname DataFrame +#' @export +#' @examples +#' \dontrun{ +#' # Compute the average for all numeric columns grouped by department. +#' avg(groupBy(df, "department")) +#' +#' # Compute the max age and average salary, grouped by department and gender. +#' agg(groupBy(df, "department", "gender"), salary="avg", "age" -> "max") +#' } +setMethod("groupBy", + signature(x = "DataFrame"), + function(x, ...) { + cols <- list(...) 
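+            # Dispatch on the type of the grouping arguments: plain column names are
+            # passed to the String-based groupBy overload on the Scala side, while
+            # Column objects are unwrapped to their Java references and passed as a Seq.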
+ if (length(cols) >= 1 && class(cols[[1]]) == "character") { + sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], listToSeq(cols[-1])) + } else { + jcol <- lapply(cols, function(c) { c@jc }) + sgd <- callJMethod(x@sdf, "groupBy", listToSeq(jcol)) + } + groupedData(sgd) + }) + +#' Agg +#' +#' Compute aggregates by specifying a list of columns +#' +#' @rdname DataFrame +#' @export +setMethod("agg", + signature(x = "DataFrame"), + function(x, ...) { + agg(groupBy(x), ...) + }) + + +############################## RDD Map Functions ################################## +# All of the following functions mirror the existing RDD map functions, # +# but allow for use with DataFrames by first converting to an RRDD before calling # +# the requested map function. # +################################################################################### + +#' @rdname lapply +setMethod("lapply", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + rdd <- toRDD(X) + lapply(rdd, FUN) + }) + +#' @rdname lapply +setMethod("map", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + lapply(X, FUN) + }) + +#' @rdname flatMap +setMethod("flatMap", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + rdd <- toRDD(X) + flatMap(rdd, FUN) + }) + +#' @rdname lapplyPartition +setMethod("lapplyPartition", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + rdd <- toRDD(X) + lapplyPartition(rdd, FUN) + }) + +#' @rdname lapplyPartition +setMethod("mapPartitions", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + lapplyPartition(X, FUN) + }) + +#' @rdname foreach +setMethod("foreach", + signature(x = "DataFrame", func = "function"), + function(x, func) { + rdd <- toRDD(x) + foreach(rdd, func) + }) + +#' @rdname foreach +setMethod("foreachPartition", + signature(x = "DataFrame", func = "function"), + function(x, func) { + rdd <- toRDD(x) + foreachPartition(rdd, func) + }) + + +############################## SELECT ################################## + +getColumn <- function(x, c) { + column(callJMethod(x@sdf, "col", c)) +} + +#' @rdname select +setMethod("$", signature(x = "DataFrame"), + function(x, name) { + getColumn(x, name) + }) + +setMethod("$<-", signature(x = "DataFrame"), + function(x, name, value) { + stopifnot(class(value) == "Column") + cols <- columns(x) + if (name %in% cols) { + cols <- lapply(cols, function(c) { + if (c == name) { + alias(value, name) + } else { + col(c) + } + }) + nx <- select(x, cols) + } else { + nx <- withColumn(x, name, value) + } + x@sdf <- nx@sdf + x + }) + +#' @rdname select +setMethod("[[", signature(x = "DataFrame"), + function(x, i) { + if (is.numeric(i)) { + cols <- columns(x) + i <- cols[[i]] + } + getColumn(x, i) + }) + +#' @rdname select +setMethod("[", signature(x = "DataFrame", i = "missing"), + function(x, i, j, ...) { + if (is.numeric(j)) { + cols <- columns(x) + j <- cols[j] + } + if (length(j) > 1) { + j <- as.list(j) + } + select(x, j) + }) + +#' Select +#' +#' Selects a set of columns with names or Column expressions. 
+#' @param x A DataFrame +#' @param col A list of columns or single Column or name +#' @return A new DataFrame with selected columns +#' @export +#' @rdname select +#' @examples +#' \dontrun{ +#' select(df, "*") +#' select(df, "col1", "col2") +#' select(df, df$name, df$age + 1) +#' select(df, c("col1", "col2")) +#' select(df, list(df$name, df$age + 1)) +#' # Columns can also be selected using `[[` and `[` +#' df[[2]] == df[["age"]] +#' df[,2] == df[,"age"] +#' # Similar to R data frames columns can also be selected using `$` +#' df$age +#' } +setMethod("select", signature(x = "DataFrame", col = "character"), + function(x, col, ...) { + sdf <- callJMethod(x@sdf, "select", col, toSeq(...)) + dataFrame(sdf) + }) + +#' @rdname select +#' @export +setMethod("select", signature(x = "DataFrame", col = "Column"), + function(x, col, ...) { + jcols <- lapply(list(col, ...), function(c) { + c@jc + }) + sdf <- callJMethod(x@sdf, "select", listToSeq(jcols)) + dataFrame(sdf) + }) + +#' @rdname select +#' @export +setMethod("select", + signature(x = "DataFrame", col = "list"), + function(x, col) { + cols <- lapply(col, function(c) { + if (class(c)== "Column") { + c@jc + } else { + col(c)@jc + } + }) + sdf <- callJMethod(x@sdf, "select", listToSeq(cols)) + dataFrame(sdf) + }) + +#' SelectExpr +#' +#' Select from a DataFrame using a set of SQL expressions. +#' +#' @param x A DataFrame to be selected from. +#' @param expr A string containing a SQL expression +#' @param ... Additional expressions +#' @return A DataFrame +#' @rdname selectExpr +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' selectExpr(df, "col1", "(col2 * 5) as newCol") +#' } +setMethod("selectExpr", + signature(x = "DataFrame", expr = "character"), + function(x, expr, ...) { + exprList <- list(expr, ...) + sdf <- callJMethod(x@sdf, "selectExpr", listToSeq(exprList)) + dataFrame(sdf) + }) + +#' WithColumn +#' +#' Return a new DataFrame with the specified column added. +#' +#' @param x A DataFrame +#' @param colName A string containing the name of the new column. +#' @param col A Column expression. +#' @return A DataFrame with the new column added. +#' @rdname withColumn +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newDF <- withColumn(df, "newCol", df$col1 * 5) +#' } +setMethod("withColumn", + signature(x = "DataFrame", colName = "character", col = "Column"), + function(x, colName, col) { + select(x, x$"*", alias(col, colName)) + }) + +#' WithColumnRenamed +#' +#' Rename an existing column in a DataFrame. +#' +#' @param x A DataFrame +#' @param existingCol The name of the column you want to change. +#' @param newCol The new column name. +#' @return A DataFrame with the column name changed. 
+#' @rdname withColumnRenamed +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newDF <- withColumnRenamed(df, "col1", "newCol1") +#' } +setMethod("withColumnRenamed", + signature(x = "DataFrame", existingCol = "character", newCol = "character"), + function(x, existingCol, newCol) { + cols <- lapply(columns(x), function(c) { + if (c == existingCol) { + alias(col(c), newCol) + } else { + col(c) + } + }) + select(x, cols) + }) + +setClassUnion("characterOrColumn", c("character", "Column")) + +#' SortDF +#' +#' Sort a DataFrame by the specified column(s). +#' +#' @param x A DataFrame to be sorted. +#' @param col Either a Column object or character vector indicating the field to sort on +#' @param ... Additional sorting fields +#' @return A DataFrame where all elements are sorted. +#' @rdname sortDF +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' sortDF(df, df$col1) +#' sortDF(df, "col1") +#' sortDF(df, asc(df$col1), desc(abs(df$col2))) +#' } +setMethod("sortDF", + signature(x = "DataFrame", col = "characterOrColumn"), + function(x, col, ...) { + if (class(col) == "character") { + sdf <- callJMethod(x@sdf, "sort", col, toSeq(...)) + } else if (class(col) == "Column") { + jcols <- lapply(list(col, ...), function(c) { + c@jc + }) + sdf <- callJMethod(x@sdf, "sort", listToSeq(jcols)) + } + dataFrame(sdf) + }) + +#' @rdname sortDF +#' @export +setMethod("orderBy", + signature(x = "DataFrame", col = "characterOrColumn"), + function(x, col) { + sortDF(x, col) + }) + +#' Filter +#' +#' Filter the rows of a DataFrame according to a given condition. +#' +#' @param x A DataFrame to be sorted. +#' @param condition The condition to sort on. This may either be a Column expression +#' or a string containing a SQL statement +#' @return A DataFrame containing only the rows that meet the condition. +#' @rdname filter +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' filter(df, "col1 > 0") +#' filter(df, df$col2 != "abcdefg") +#' } +setMethod("filter", + signature(x = "DataFrame", condition = "characterOrColumn"), + function(x, condition) { + if (class(condition) == "Column") { + condition <- condition@jc + } + sdf <- callJMethod(x@sdf, "filter", condition) + dataFrame(sdf) + }) + +#' @rdname filter +#' @export +setMethod("where", + signature(x = "DataFrame", condition = "characterOrColumn"), + function(x, condition) { + filter(x, condition) + }) + +#' Join +#' +#' Join two DataFrames based on the given join expression. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a +#' Column expression. If joinExpr is omitted, join() wil perform a Cartesian join +#' @param joinType The type of join to perform. The following join types are available: +#' 'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'. The default joinType is "inner". +#' @return A DataFrame containing the result of the join operation. 
+#' @rdname join +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' join(df1, df2) # Performs a Cartesian +#' join(df1, df2, df1$col1 == df2$col2) # Performs an inner join based on expression +#' join(df1, df2, df1$col1 == df2$col2, "right_outer") +#' } +setMethod("join", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y, joinExpr = NULL, joinType = NULL) { + if (is.null(joinExpr)) { + sdf <- callJMethod(x@sdf, "join", y@sdf) + } else { + if (class(joinExpr) != "Column") stop("joinExpr must be a Column") + if (is.null(joinType)) { + sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc) + } else { + if (joinType %in% c("inner", "outer", "left_outer", "right_outer", "semijoin")) { + sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc, joinType) + } else { + stop("joinType must be one of the following types: ", + "'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'") + } + } + } + dataFrame(sdf) + }) + +#' UnionAll +#' +#' Return a new DataFrame containing the union of rows in this DataFrame +#' and another DataFrame. This is equivalent to `UNION ALL` in SQL. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @return A DataFrame containing the result of the union. +#' @rdname unionAll +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' unioned <- unionAll(df, df2) +#' } +setMethod("unionAll", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y) { + unioned <- callJMethod(x@sdf, "unionAll", y@sdf) + dataFrame(unioned) + }) + +#' Intersect +#' +#' Return a new DataFrame containing rows only in both this DataFrame +#' and another DataFrame. This is equivalent to `INTERSECT` in SQL. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @return A DataFrame containing the result of the intersect. +#' @rdname intersect +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' intersectDF <- intersect(df, df2) +#' } +setMethod("intersect", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y) { + intersected <- callJMethod(x@sdf, "intersect", y@sdf) + dataFrame(intersected) + }) + +#' Subtract +#' +#' Return a new DataFrame containing rows in this DataFrame +#' but not in another DataFrame. This is equivalent to `EXCEPT` in SQL. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @return A DataFrame containing the result of the subtract operation. +#' @rdname subtract +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' subtractDF <- subtract(df, df2) +#' } +setMethod("subtract", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y) { + subtracted <- callJMethod(x@sdf, "except", y@sdf) + dataFrame(subtracted) + }) + +#' Save the contents of the DataFrame to a data source +#' +#' The data source is specified by the `source` and a set of options (...). +#' If `source` is not specified, the default data source configured by +#' spark.sql.sources.default will be used. +#' +#' Additionally, mode is used to specify the behavior of the save operation when +#' data already exists in the data source. 
There are four modes: +#' append: Contents of this DataFrame are expected to be appended to existing data. +#' overwrite: Existing data is expected to be overwritten by the contents of +# this DataFrame. +#' error: An exception is expected to be thrown. +#' ignore: The save operation is expected to not save the contents of the DataFrame +# and to not change the existing data. +#' +#' @param df A SparkSQL DataFrame +#' @param path A name for the table +#' @param source A name for external data source +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' +#' +#' @rdname saveAsTable +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' saveAsTable(df, "myfile") +#' } +setMethod("saveDF", + signature(df = "DataFrame", path = 'character', source = 'character', + mode = 'character'), + function(df, path = NULL, source = NULL, mode = "append", ...){ + if (is.null(source)) { + sqlCtx <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlCtx, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") + } + allModes <- c("append", "overwrite", "error", "ignore") + if (!(mode %in% allModes)) { + stop('mode should be one of "append", "overwrite", "error", "ignore"') + } + jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) + options <- varargsToEnv(...) + if (!is.null(path)) { + options[['path']] = path + } + callJMethod(df@sdf, "save", source, jmode, options) + }) + + +#' saveAsTable +#' +#' Save the contents of the DataFrame to a data source as a table +#' +#' The data source is specified by the `source` and a set of options (...). +#' If `source` is not specified, the default data source configured by +#' spark.sql.sources.default will be used. +#' +#' Additionally, mode is used to specify the behavior of the save operation when +#' data already exists in the data source. There are four modes: +#' append: Contents of this DataFrame are expected to be appended to existing data. +#' overwrite: Existing data is expected to be overwritten by the contents of +# this DataFrame. +#' error: An exception is expected to be thrown. +#' ignore: The save operation is expected to not save the contents of the DataFrame +# and to not change the existing data. +#' +#' @param df A SparkSQL DataFrame +#' @param tableName A name for the table +#' @param source A name for external data source +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' +#' +#' @rdname saveAsTable +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' saveAsTable(df, "myfile") +#' } +setMethod("saveAsTable", + signature(df = "DataFrame", tableName = 'character', source = 'character', + mode = 'character'), + function(df, tableName, source = NULL, mode="append", ...){ + if (is.null(source)) { + sqlCtx <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlCtx, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") + } + allModes <- c("append", "overwrite", "error", "ignore") + if (!(mode %in% allModes)) { + stop('mode should be one of "append", "overwrite", "error", "ignore"') + } + jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) + options <- varargsToEnv(...) 
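+            # Any remaining named arguments are collected as data source options and
+            # handed to the JVM-side DataFrame.saveAsTable along with the save mode.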
+ callJMethod(df@sdf, "saveAsTable", tableName, source, jmode, options) + }) + diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R new file mode 100644 index 0000000000000..820027ef67e3b --- /dev/null +++ b/R/pkg/R/RDD.R @@ -0,0 +1,1531 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# RDD in R implemented in S4 OO system. + +setOldClass("jobj") + +#' @title S4 class that represents an RDD +#' @description RDD can be created using functions like +#' \code{parallelize}, \code{textFile} etc. +#' @rdname RDD +#' @seealso parallelize, textFile +#' +#' @slot env An R environment that stores bookkeeping states of the RDD +#' @slot jrdd Java object reference to the backing JavaRDD +#' to an RDD +#' @export +setClass("RDD", + slots = list(env = "environment", + jrdd = "jobj")) + +setClass("PipelinedRDD", + slots = list(prev = "RDD", + func = "function", + prev_jrdd = "jobj"), + contains = "RDD") + +setMethod("initialize", "RDD", function(.Object, jrdd, serializedMode, + isCached, isCheckpointed) { + # Check that RDD constructor is using the correct version of serializedMode + stopifnot(class(serializedMode) == "character") + stopifnot(serializedMode %in% c("byte", "string", "row")) + # RDD has three serialization types: + # byte: The RDD stores data serialized in R. + # string: The RDD stores data as strings. + # row: The RDD stores the serialized rows of a DataFrame. + + # We use an environment to store mutable states inside an RDD object. + # Note that R's call-by-value semantics makes modifying slots inside an + # object (passed as an argument into a function, such as cache()) difficult: + # i.e. one needs to make a copy of the RDD object and sets the new slot value + # there. + + # The slots are inheritable from superclass. Here, both `env' and `jrdd' are + # inherited from RDD, but only the former is used. 
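+  # Initialize the mutable state: the cache/checkpoint flags and the serialization
+  # mode of the data backing this RDD.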
+ .Object@env <- new.env() + .Object@env$isCached <- isCached + .Object@env$isCheckpointed <- isCheckpointed + .Object@env$serializedMode <- serializedMode + + .Object@jrdd <- jrdd + .Object +}) + +setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) { + .Object@env <- new.env() + .Object@env$isCached <- FALSE + .Object@env$isCheckpointed <- FALSE + .Object@env$jrdd_val <- jrdd_val + if (!is.null(jrdd_val)) { + # This tracks the serialization mode for jrdd_val + .Object@env$serializedMode <- prev@env$serializedMode + } + + .Object@prev <- prev + + isPipelinable <- function(rdd) { + e <- rdd@env + !(e$isCached || e$isCheckpointed) + } + + if (!inherits(prev, "PipelinedRDD") || !isPipelinable(prev)) { + # This transformation is the first in its stage: + .Object@func <- cleanClosure(func) + .Object@prev_jrdd <- getJRDD(prev) + .Object@env$prev_serializedMode <- prev@env$serializedMode + # NOTE: We use prev_serializedMode to track the serialization mode of prev_JRDD + # prev_serializedMode is used during the delayed computation of JRDD in getJRDD + } else { + pipelinedFunc <- function(split, iterator) { + func(split, prev@func(split, iterator)) + } + .Object@func <- cleanClosure(pipelinedFunc) + .Object@prev_jrdd <- prev@prev_jrdd # maintain the pipeline + # Get the serialization mode of the parent RDD + .Object@env$prev_serializedMode <- prev@env$prev_serializedMode + } + + .Object +}) + +#' @rdname RDD +#' @export +#' +#' @param jrdd Java object reference to the backing JavaRDD +#' @param serializedMode Use "byte" if the RDD stores data serialized in R, "string" if the RDD +#' stores strings, and "row" if the RDD stores the rows of a DataFrame +#' @param isCached TRUE if the RDD is cached +#' @param isCheckpointed TRUE if the RDD has been checkpointed +RDD <- function(jrdd, serializedMode = "byte", isCached = FALSE, + isCheckpointed = FALSE) { + new("RDD", jrdd, serializedMode, isCached, isCheckpointed) +} + +PipelinedRDD <- function(prev, func) { + new("PipelinedRDD", prev, func, NULL) +} + +# Return the serialization mode for an RDD. +setGeneric("getSerializedMode", function(rdd, ...) { standardGeneric("getSerializedMode") }) +# For normal RDDs we can directly read the serializedMode +setMethod("getSerializedMode", signature(rdd = "RDD"), function(rdd) rdd@env$serializedMode ) +# For pipelined RDDs if jrdd_val is set then serializedMode should exist +# if not we return the defaultSerialization mode of "byte" as we don't know the serialization +# mode at this point in time. +setMethod("getSerializedMode", signature(rdd = "PipelinedRDD"), + function(rdd) { + if (!is.null(rdd@env$jrdd_val)) { + return(rdd@env$serializedMode) + } else { + return("byte") + } + }) + +# The jrdd accessor function. 
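+# For a plain RDD this simply returns the stored JavaRDD reference. For a
+# PipelinedRDD the JavaRDD is created lazily: the first call constructs an RRDD
+# (or StringRRDD) on the JVM from the pipelined function and caches it in jrdd_val.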
+setMethod("getJRDD", signature(rdd = "RDD"), function(rdd) rdd@jrdd ) +setMethod("getJRDD", signature(rdd = "PipelinedRDD"), + function(rdd, serializedMode = "byte") { + if (!is.null(rdd@env$jrdd_val)) { + return(rdd@env$jrdd_val) + } + + packageNamesArr <- serialize(.sparkREnv[[".packages"]], + connection = NULL) + + broadcastArr <- lapply(ls(.broadcastNames), + function(name) { get(name, .broadcastNames) }) + + serializedFuncArr <- serialize(rdd@func, connection = NULL) + + prev_jrdd <- rdd@prev_jrdd + + if (serializedMode == "string") { + rddRef <- newJObject("org.apache.spark.api.r.StringRRDD", + callJMethod(prev_jrdd, "rdd"), + serializedFuncArr, + rdd@env$prev_serializedMode, + packageNamesArr, + as.character(.sparkREnv[["libname"]]), + broadcastArr, + callJMethod(prev_jrdd, "classTag")) + } else { + rddRef <- newJObject("org.apache.spark.api.r.RRDD", + callJMethod(prev_jrdd, "rdd"), + serializedFuncArr, + rdd@env$prev_serializedMode, + serializedMode, + packageNamesArr, + as.character(.sparkREnv[["libname"]]), + broadcastArr, + callJMethod(prev_jrdd, "classTag")) + } + # Save the serialization flag after we create a RRDD + rdd@env$serializedMode <- serializedMode + rdd@env$jrdd_val <- callJMethod(rddRef, "asJavaRDD") # rddRef$asJavaRDD() + rdd@env$jrdd_val + }) + +setValidity("RDD", + function(object) { + jrdd <- getJRDD(object) + cls <- callJMethod(jrdd, "getClass") + className <- callJMethod(cls, "getName") + if (grep("spark.api.java.*RDD*", className) == 1) { + TRUE + } else { + paste("Invalid RDD class ", className) + } + }) + + +############ Actions and Transformations ############ + +#' Persist an RDD +#' +#' Persist this RDD with the default storage level (MEMORY_ONLY). +#' +#' @param x The RDD to cache +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' cache(rdd) +#'} +#' @rdname cache-methods +#' @aliases cache,RDD-method +setMethod("cache", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "cache") + x@env$isCached <- TRUE + x + }) + +#' Persist an RDD +#' +#' Persist this RDD with the specified storage level. For details of the +#' supported storage levels, refer to +#' http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. +#' +#' @param x The RDD to persist +#' @param newLevel The new storage level to be assigned +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' persist(rdd, "MEMORY_AND_DISK") +#'} +#' @rdname persist +#' @aliases persist,RDD-method +setMethod("persist", + signature(x = "RDD", newLevel = "character"), + function(x, newLevel) { + callJMethod(getJRDD(x), "persist", getStorageLevel(newLevel)) + x@env$isCached <- TRUE + x + }) + +#' Unpersist an RDD +#' +#' Mark the RDD as non-persistent, and remove all blocks for it from memory and +#' disk. +#' +#' @param x The RDD to unpersist +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' cache(rdd) # rdd@@env$isCached == TRUE +#' unpersist(rdd) # rdd@@env$isCached == FALSE +#'} +#' @rdname unpersist-methods +#' @aliases unpersist,RDD-method +setMethod("unpersist", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "unpersist") + x@env$isCached <- FALSE + x + }) + +#' Checkpoint an RDD +#' +#' Mark this RDD for checkpointing. It will be saved to a file inside the +#' checkpoint directory set with setCheckpointDir() and all references to its +#' parent RDDs will be removed. 
This function must be called before any job has +#' been executed on this RDD. It is strongly recommended that this RDD is +#' persisted in memory, otherwise saving it on a file will require recomputation. +#' +#' @param x The RDD to checkpoint +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' setCheckpointDir(sc, "checkpoint") +#' rdd <- parallelize(sc, 1:10, 2L) +#' checkpoint(rdd) +#'} +#' @rdname checkpoint-methods +#' @aliases checkpoint,RDD-method +setMethod("checkpoint", + signature(x = "RDD"), + function(x) { + jrdd <- getJRDD(x) + callJMethod(jrdd, "checkpoint") + x@env$isCheckpointed <- TRUE + x + }) + +#' Gets the number of partitions of an RDD +#' +#' @param x A RDD. +#' @return the number of partitions of rdd as an integer. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' numPartitions(rdd) # 2L +#'} +#' @rdname numPartitions +#' @aliases numPartitions,RDD-method +setMethod("numPartitions", + signature(x = "RDD"), + function(x) { + jrdd <- getJRDD(x) + partitions <- callJMethod(jrdd, "splits") + callJMethod(partitions, "size") + }) + +#' Collect elements of an RDD +#' +#' @description +#' \code{collect} returns a list that contains all of the elements in this RDD. +#' +#' @param x The RDD to collect +#' @param ... Other optional arguments to collect +#' @param flatten FALSE if the list should not flattened +#' @return a list containing elements in the RDD +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' collect(rdd) # list from 1 to 10 +#' collectPartition(rdd, 0L) # list from 1 to 5 +#'} +#' @rdname collect-methods +#' @aliases collect,RDD-method +setMethod("collect", + signature(x = "RDD"), + function(x, flatten = TRUE) { + # Assumes a pairwise RDD is backed by a JavaPairRDD. + collected <- callJMethod(getJRDD(x), "collect") + convertJListToRList(collected, flatten, + serializedMode = getSerializedMode(x)) + }) + + +#' @description +#' \code{collectPartition} returns a list that contains all of the elements +#' in the specified partition of the RDD. +#' @param partitionId the partition to collect (starts from 0) +#' @rdname collect-methods +#' @aliases collectPartition,integer,RDD-method +setMethod("collectPartition", + signature(x = "RDD", partitionId = "integer"), + function(x, partitionId) { + jPartitionsList <- callJMethod(getJRDD(x), + "collectPartitions", + as.list(as.integer(partitionId))) + + jList <- jPartitionsList[[1]] + convertJListToRList(jList, flatten = TRUE, + serializedMode = getSerializedMode(x)) + }) + +#' @description +#' \code{collectAsMap} returns a named list as a map that contains all of the elements +#' in a key-value pair RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4)), 2L) +#' collectAsMap(rdd) # list(`1` = 2, `3` = 4) +#'} +#' @rdname collect-methods +#' @aliases collectAsMap,RDD-method +setMethod("collectAsMap", + signature(x = "RDD"), + function(x) { + pairList <- collect(x) + map <- new.env() + lapply(pairList, function(i) { assign(as.character(i[[1]]), i[[2]], envir = map) }) + as.list(map) + }) + +#' Return the number of elements in the RDD. +#' +#' @param x The RDD to count +#' @return number of elements in the RDD. 
+#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' count(rdd) # 10 +#' length(rdd) # Same as count +#'} +#' @rdname count +#' @aliases count,RDD-method +setMethod("count", + signature(x = "RDD"), + function(x) { + countPartition <- function(part) { + as.integer(length(part)) + } + valsRDD <- lapplyPartition(x, countPartition) + vals <- collect(valsRDD) + sum(as.integer(vals)) + }) + +#' Return the number of elements in the RDD +#' @export +#' @rdname count +setMethod("length", + signature(x = "RDD"), + function(x) { + count(x) + }) + +#' Return the count of each unique value in this RDD as a list of +#' (value, count) pairs. +#' +#' Same as countByValue in Spark. +#' +#' @param x The RDD to count +#' @return list of (value, count) pairs, where count is number of each unique +#' value in rdd. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, c(1,2,3,2,1)) +#' countByValue(rdd) # (1,2L), (2,2L), (3,1L) +#'} +#' @rdname countByValue +#' @aliases countByValue,RDD-method +setMethod("countByValue", + signature(x = "RDD"), + function(x) { + ones <- lapply(x, function(item) { list(item, 1L) }) + collect(reduceByKey(ones, `+`, numPartitions(x))) + }) + +#' Apply a function to all elements +#' +#' This function creates a new RDD by applying the given transformation to all +#' elements of the given RDD +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each element +#' @return a new RDD created by the transformation. +#' @rdname lapply +#' @aliases lapply +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' multiplyByTwo <- lapply(rdd, function(x) { x * 2 }) +#' collect(multiplyByTwo) # 2,4,6... +#'} +setMethod("lapply", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + func <- function(split, iterator) { + lapply(iterator, FUN) + } + lapplyPartitionsWithIndex(X, func) + }) + +#' @rdname lapply +#' @aliases map,RDD,function-method +setMethod("map", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapply(X, FUN) + }) + +#' Flatten results after apply a function to all elements +#' +#' This function return a new RDD by first applying a function to all +#' elements of this RDD, and then flattening the results. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each element +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' multiplyByTwo <- flatMap(rdd, function(x) { list(x*2, x*10) }) +#' collect(multiplyByTwo) # 2,20,4,40,6,60... +#'} +#' @rdname flatMap +#' @aliases flatMap,RDD,function-method +setMethod("flatMap", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + partitionFunc <- function(part) { + unlist( + lapply(part, FUN), + recursive = F + ) + } + lapplyPartition(X, partitionFunc) + }) + +#' Apply a function to each partition of an RDD +#' +#' Return a new RDD by applying a function to each partition of this RDD. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each partition. +#' @return a new RDD created by the transformation. 
+#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' partitionSum <- lapplyPartition(rdd, function(part) { Reduce("+", part) }) +#' collect(partitionSum) # 15, 40 +#'} +#' @rdname lapplyPartition +#' @aliases lapplyPartition,RDD,function-method +setMethod("lapplyPartition", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapplyPartitionsWithIndex(X, function(s, part) { FUN(part) }) + }) + +#' mapPartitions is the same as lapplyPartition. +#' +#' @rdname lapplyPartition +#' @aliases mapPartitions,RDD,function-method +setMethod("mapPartitions", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapplyPartition(X, FUN) + }) + +#' Return a new RDD by applying a function to each partition of this RDD, while +#' tracking the index of the original partition. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each partition; takes the partition +#' index and a list of elements in the particular partition. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 5L) +#' prod <- lapplyPartitionsWithIndex(rdd, function(split, part) { +#' split * Reduce("+", part) }) +#' collect(prod, flatten = FALSE) # 0, 7, 22, 45, 76 +#'} +#' @rdname lapplyPartitionsWithIndex +#' @aliases lapplyPartitionsWithIndex,RDD,function-method +setMethod("lapplyPartitionsWithIndex", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + PipelinedRDD(X, FUN) + }) + +#' @rdname lapplyPartitionsWithIndex +#' @aliases mapPartitionsWithIndex,RDD,function-method +setMethod("mapPartitionsWithIndex", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapplyPartitionsWithIndex(X, FUN) + }) + +#' This function returns a new RDD containing only the elements that satisfy +#' a predicate (i.e. returning TRUE in a given logical function). +#' The same as `filter()' in Spark. +#' +#' @param x The RDD to be filtered. +#' @param f A unary predicate function. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' unlist(collect(filterRDD(rdd, function (x) { x < 3 }))) # c(1, 2) +#'} +#' @rdname filterRDD +#' @aliases filterRDD,RDD,function-method +setMethod("filterRDD", + signature(x = "RDD", f = "function"), + function(x, f) { + filter.func <- function(part) { + Filter(f, part) + } + lapplyPartition(x, filter.func) + }) + +#' @rdname filterRDD +#' @aliases Filter +setMethod("Filter", + signature(f = "function", x = "RDD"), + function(f, x) { + filterRDD(x, f) + }) + +#' Reduce across elements of an RDD. +#' +#' This function reduces the elements of this RDD using the +#' specified commutative and associative binary operator. +#' +#' @param x The RDD to reduce +#' @param func Commutative and associative function to apply on elements +#' of the RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' reduce(rdd, "+") # 55 +#'} +#' @rdname reduce +#' @aliases reduce,RDD,ANY-method +setMethod("reduce", + signature(x = "RDD", func = "ANY"), + function(x, func) { + + reducePartition <- function(part) { + Reduce(func, part) + } + + partitionList <- collect(lapplyPartition(x, reducePartition), + flatten = FALSE) + Reduce(func, partitionList) + }) + +#' Get the maximum element of an RDD. 
+#' +#' @param x The RDD to get the maximum element from +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' maximum(rdd) # 10 +#'} +#' @rdname maximum +#' @aliases maximum,RDD +setMethod("maximum", + signature(x = "RDD"), + function(x) { + reduce(x, max) + }) + +#' Get the minimum element of an RDD. +#' +#' @param x The RDD to get the minimum element from +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' minimum(rdd) # 1 +#'} +#' @rdname minimum +#' @aliases minimum,RDD +setMethod("minimum", + signature(x = "RDD"), + function(x) { + reduce(x, min) + }) + +#' Add up the elements in an RDD. +#' +#' @param x The RDD to add up the elements in +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' sumRDD(rdd) # 55 +#'} +#' @rdname sumRDD +#' @aliases sumRDD,RDD +setMethod("sumRDD", + signature(x = "RDD"), + function(x) { + reduce(x, "+") + }) + +#' Applies a function to all elements in an RDD, and force evaluation. +#' +#' @param x The RDD to apply the function +#' @param func The function to be applied. +#' @return invisible NULL. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' foreach(rdd, function(x) { save(x, file=...) }) +#'} +#' @rdname foreach +#' @aliases foreach,RDD,function-method +setMethod("foreach", + signature(x = "RDD", func = "function"), + function(x, func) { + partition.func <- function(x) { + lapply(x, func) + NULL + } + invisible(collect(mapPartitions(x, partition.func))) + }) + +#' Applies a function to each partition in an RDD, and force evaluation. +#' +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' foreachPartition(rdd, function(part) { save(part, file=...); NULL }) +#'} +#' @rdname foreach +#' @aliases foreachPartition,RDD,function-method +setMethod("foreachPartition", + signature(x = "RDD", func = "function"), + function(x, func) { + invisible(collect(mapPartitions(x, func))) + }) + +#' Take elements from an RDD. +#' +#' This function takes the first NUM elements in the RDD and +#' returns them in a list. +#' +#' @param x The RDD to take elements from +#' @param num Number of elements to take +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' take(rdd, 2L) # list(1, 2) +#'} +#' @rdname take +#' @aliases take,RDD,numeric-method +setMethod("take", + signature(x = "RDD", num = "numeric"), + function(x, num) { + resList <- list() + index <- -1 + jrdd <- getJRDD(x) + numPartitions <- numPartitions(x) + + # TODO(shivaram): Collect more than one partition based on size + # estimates similar to the scala version of `take`. + while (TRUE) { + index <- index + 1 + + if (length(resList) >= num || index >= numPartitions) + break + + # a JList of byte arrays + partitionArr <- callJMethod(jrdd, "collectPartitions", as.list(as.integer(index))) + partition <- partitionArr[[1]] + + size <- num - length(resList) + # elems is capped to have at most `size` elements + elems <- convertJListToRList(partition, + flatten = TRUE, + logicalUpperBound = size, + serializedMode = getSerializedMode(x)) + # TODO: Check if this append is O(n^2)? 
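+      # Each iteration of this loop pulls one more partition from the JVM and appends
+      # only as many of its elements as are still needed to reach `num`; the loop stops
+      # once `num` elements have been gathered or every partition has been scanned.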
+ resList <- append(resList, elems) + } + resList + }) + +#' First +#' +#' Return the first element of an RDD +#' +#' @rdname first +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' first(rdd) +#' } +setMethod("first", + signature(x = "RDD"), + function(x) { + take(x, 1)[[1]] + }) + +#' Removes the duplicates from RDD. +#' +#' This function returns a new RDD containing the distinct elements in the +#' given RDD. The same as `distinct()' in Spark. +#' +#' @param x The RDD to remove duplicates from. +#' @param numPartitions Number of partitions to create. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, c(1,2,2,3,3,3)) +#' sort(unlist(collect(distinct(rdd)))) # c(1, 2, 3) +#'} +#' @rdname distinct +#' @aliases distinct,RDD-method +setMethod("distinct", + signature(x = "RDD"), + function(x, numPartitions = SparkR::numPartitions(x)) { + identical.mapped <- lapply(x, function(x) { list(x, NULL) }) + reduced <- reduceByKey(identical.mapped, + function(x, y) { x }, + numPartitions) + resRDD <- lapply(reduced, function(x) { x[[1]] }) + resRDD + }) + +#' Return an RDD that is a sampled subset of the given RDD. +#' +#' The same as `sample()' in Spark. (We rename it due to signature +#' inconsistencies with the `sample()' function in R's base package.) +#' +#' @param x The RDD to sample elements from +#' @param withReplacement Sampling with replacement or not +#' @param fraction The (rough) sample target fraction +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) # ensure each num is in its own split +#' collect(sampleRDD(rdd, FALSE, 0.5, 1618L)) # ~5 distinct elements +#' collect(sampleRDD(rdd, TRUE, 0.5, 9L)) # ~5 elements possibly with duplicates +#'} +#' @rdname sampleRDD +#' @aliases sampleRDD,RDD +setMethod("sampleRDD", + signature(x = "RDD", withReplacement = "logical", + fraction = "numeric", seed = "integer"), + function(x, withReplacement, fraction, seed) { + + # The sampler: takes a partition and returns its sampled version. + samplingFunc <- function(split, part) { + set.seed(seed) + res <- vector("list", length(part)) + len <- 0 + + # Discards some random values to ensure each partition has a + # different random seed. + runif(split) + + for (elem in part) { + if (withReplacement) { + count <- rpois(1, fraction) + if (count > 0) { + res[(len + 1):(len + count)] <- rep(list(elem), count) + len <- len + count + } + } else { + if (runif(1) < fraction) { + len <- len + 1 + res[[len]] <- elem + } + } + } + + # TODO(zongheng): look into the performance of the current + # implementation. Look into some iterator package? Note that + # Scala avoids many calls to creating an empty list and PySpark + # similarly achieves this using `yield'. + if (len > 0) + res[1:len] + else + list() + } + + lapplyPartitionsWithIndex(x, samplingFunc) + }) + +#' Return a list of the elements that are a sampled subset of the given RDD. 
+#' +#' @param x The RDD to sample elements from +#' @param withReplacement Sampling with replacement or not +#' @param num Number of elements to return +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:100) +#' # exactly 5 elements sampled, which may not be distinct +#' takeSample(rdd, TRUE, 5L, 1618L) +#' # exactly 5 distinct elements sampled +#' takeSample(rdd, FALSE, 5L, 16181618L) +#'} +#' @rdname takeSample +#' @aliases takeSample,RDD +setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", + num = "integer", seed = "integer"), + function(x, withReplacement, num, seed) { + # This function is ported from RDD.scala. + fraction <- 0.0 + total <- 0 + multiplier <- 3.0 + initialCount <- count(x) + maxSelected <- 0 + MAXINT <- .Machine$integer.max + + if (num < 0) + stop(paste("Negative number of elements requested")) + + if (initialCount > MAXINT - 1) { + maxSelected <- MAXINT - 1 + } else { + maxSelected <- initialCount + } + + if (num > initialCount && !withReplacement) { + total <- maxSelected + fraction <- multiplier * (maxSelected + 1) / initialCount + } else { + total <- num + fraction <- multiplier * (num + 1) / initialCount + } + + set.seed(seed) + samples <- collect(sampleRDD(x, withReplacement, fraction, + as.integer(ceiling(runif(1, + -MAXINT, + MAXINT))))) + # If the first sample didn't turn out large enough, keep trying to + # take samples; this shouldn't happen often because we use a big + # multiplier for thei initial size + while (length(samples) < total) + samples <- collect(sampleRDD(x, withReplacement, fraction, + as.integer(ceiling(runif(1, + -MAXINT, + MAXINT))))) + + # TODO(zongheng): investigate if this call is an in-place shuffle? + sample(samples)[1:total] + }) + +#' Creates tuples of the elements in this RDD by applying a function. +#' +#' @param x The RDD. +#' @param func The function to be applied. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3)) +#' collect(keyBy(rdd, function(x) { x*x })) # list(list(1, 1), list(4, 2), list(9, 3)) +#'} +#' @rdname keyBy +#' @aliases keyBy,RDD +setMethod("keyBy", + signature(x = "RDD", func = "function"), + function(x, func) { + apply.func <- function(x) { + list(func(x), x) + } + lapply(x, apply.func) + }) + +#' Return a new RDD that has exactly numPartitions partitions. +#' Can increase or decrease the level of parallelism in this RDD. Internally, +#' this uses a shuffle to redistribute data. +#' If you are decreasing the number of partitions in this RDD, consider using +#' coalesce, which can avoid performing a shuffle. +#' +#' @param x The RDD. +#' @param numPartitions Number of partitions to create. +#' @seealso coalesce +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5, 6, 7), 4L) +#' numPartitions(rdd) # 4 +#' numPartitions(repartition(rdd, 2L)) # 2 +#'} +#' @rdname repartition +#' @aliases repartition,RDD +setMethod("repartition", + signature(x = "RDD", numPartitions = "numeric"), + function(x, numPartitions) { + coalesce(x, numToInt(numPartitions), TRUE) + }) + +#' Return a new RDD that is reduced into numPartitions partitions. +#' +#' @param x The RDD. +#' @param numPartitions Number of partitions to create. 
+#' @seealso repartition +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5), 3L) +#' numPartitions(rdd) # 3 +#' numPartitions(coalesce(rdd, 1L)) # 1 +#'} +#' @rdname coalesce +#' @aliases coalesce,RDD +setMethod("coalesce", + signature(x = "RDD", numPartitions = "numeric"), + function(x, numPartitions, shuffle = FALSE) { + numPartitions <- numToInt(numPartitions) + if (shuffle || numPartitions > SparkR::numPartitions(x)) { + func <- function(s, part) { + set.seed(s) # split as seed + start <- as.integer(sample(numPartitions, 1) - 1) + lapply(seq_along(part), + function(i) { + pos <- (start + i) %% numPartitions + list(pos, part[[i]]) + }) + } + shuffled <- lapplyPartitionsWithIndex(x, func) + repartitioned <- partitionBy(shuffled, numPartitions) + values(repartitioned) + } else { + jrdd <- callJMethod(getJRDD(x), "coalesce", numPartitions, shuffle) + RDD(jrdd) + } + }) + +#' Save this RDD as a SequenceFile of serialized objects. +#' +#' @param x The RDD to save +#' @param path The directory where the file is saved +#' @seealso objectFile +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' saveAsObjectFile(rdd, "/tmp/sparkR-tmp") +#'} +#' @rdname saveAsObjectFile +#' @aliases saveAsObjectFile,RDD +setMethod("saveAsObjectFile", + signature(x = "RDD", path = "character"), + function(x, path) { + # If serializedMode == "string" we need to serialize the data before saving it since + # objectFile() assumes serializedMode == "byte". + if (getSerializedMode(x) != "byte") { + x <- serializeToBytes(x) + } + # Return nothing + invisible(callJMethod(getJRDD(x), "saveAsObjectFile", path)) + }) + +#' Save this RDD as a text file, using string representations of elements. +#' +#' @param x The RDD to save +#' @param path The directory where the splits of the text file are saved +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' saveAsTextFile(rdd, "/tmp/sparkR-tmp") +#'} +#' @rdname saveAsTextFile +#' @aliases saveAsTextFile,RDD +setMethod("saveAsTextFile", + signature(x = "RDD", path = "character"), + function(x, path) { + func <- function(str) { + toString(str) + } + stringRdd <- lapply(x, func) + # Return nothing + invisible( + callJMethod(getJRDD(stringRdd, serializedMode = "string"), "saveAsTextFile", path)) + }) + +#' Sort an RDD by the given key function. +#' +#' @param x An RDD to be sorted. +#' @param func A function used to compute the sort key for each element. +#' @param ascending A flag to indicate whether the sorting is ascending or descending. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where all elements are sorted. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(3, 2, 1)) +#' collect(sortBy(rdd, function(x) { x })) # list (1, 2, 3) +#'} +#' @rdname sortBy +#' @aliases sortBy,RDD,RDD-method +setMethod("sortBy", + signature(x = "RDD", func = "function"), + function(x, func, ascending = TRUE, numPartitions = SparkR::numPartitions(x)) { + values(sortByKey(keyBy(x, func), ascending, numPartitions)) + }) + +# Helper function to get first N elements from an RDD in the specified order. +# Param: +# x An RDD. +# num Number of elements to return. +# ascending A flag to indicate whether the sorting is ascending or descending. +# Return: +# A list of the first N elements from the RDD in the specified order. 
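+#
+# A sketch of how the helper behaves (assuming a SparkContext `sc`): partitionFunc below
+# keeps only each partition's own `num` best elements, and reduceFunc then merges those
+# per-partition candidates, so only a small bounded set of elements reaches the driver.
+#
+#   rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7), 3L)
+#   takeOrderedElem(rdd, 3L)         # list(1, 2, 3)
+#   takeOrderedElem(rdd, 3L, FALSE)  # list(10, 9, 7)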
+# +takeOrderedElem <- function(x, num, ascending = TRUE) { + if (num <= 0L) { + return(list()) + } + + partitionFunc <- function(part) { + if (num < length(part)) { + # R limitation: order works only on primitive types! + ord <- order(unlist(part, recursive = FALSE), decreasing = !ascending) + list(part[ord[1:num]]) + } else { + list(part) + } + } + + reduceFunc <- function(elems, part) { + newElems <- append(elems, part) + # R limitation: order works only on primitive types! + ord <- order(unlist(newElems, recursive = FALSE), decreasing = !ascending) + newElems[ord[1:num]] + } + + newRdd <- mapPartitions(x, partitionFunc) + reduce(newRdd, reduceFunc) +} + +#' Returns the first N elements from an RDD in ascending order. +#' +#' @param x An RDD. +#' @param num Number of elements to return. +#' @return The first N elements from the RDD in ascending order. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) +#' takeOrdered(rdd, 6L) # list(1, 2, 3, 4, 5, 6) +#'} +#' @rdname takeOrdered +#' @aliases takeOrdered,RDD,RDD-method +setMethod("takeOrdered", + signature(x = "RDD", num = "integer"), + function(x, num) { + takeOrderedElem(x, num) + }) + +#' Returns the top N elements from an RDD. +#' +#' @param x An RDD. +#' @param num Number of elements to return. +#' @return The top N elements from the RDD. +#' @rdname top +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) +#' top(rdd, 6L) # list(10, 9, 7, 6, 5, 4) +#'} +#' @rdname top +#' @aliases top,RDD,RDD-method +setMethod("top", + signature(x = "RDD", num = "integer"), + function(x, num) { + takeOrderedElem(x, num, FALSE) + }) + +#' Fold an RDD using a given associative function and a neutral "zero value". +#' +#' Aggregate the elements of each partition, and then the results for all the +#' partitions, using a given associative function and a neutral "zero value". +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param op An associative function for the folding operation. +#' @return The folding result. +#' @rdname fold +#' @seealso reduce +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5)) +#' fold(rdd, 0, "+") # 15 +#'} +#' @rdname fold +#' @aliases fold,RDD,RDD-method +setMethod("fold", + signature(x = "RDD", zeroValue = "ANY", op = "ANY"), + function(x, zeroValue, op) { + aggregateRDD(x, zeroValue, op, op) + }) + +#' Aggregate an RDD using the given combine functions and a neutral "zero value". +#' +#' Aggregate the elements of each partition, and then the results for all the +#' partitions, using given combine functions and a neutral "zero value". +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param seqOp A function to aggregate the RDD elements. It may return a different +#' result type from the type of the RDD elements. +#' @param combOp A function to aggregate results of seqOp. +#' @return The aggregation result. 
+#' @rdname aggregateRDD +#' @seealso reduce +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4)) +#' zeroValue <- list(0, 0) +#' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } +#' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } +#' aggregateRDD(rdd, zeroValue, seqOp, combOp) # list(10, 4) +#'} +#' @rdname aggregateRDD +#' @aliases aggregateRDD,RDD,RDD-method +setMethod("aggregateRDD", + signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", combOp = "ANY"), + function(x, zeroValue, seqOp, combOp) { + partitionFunc <- function(part) { + Reduce(seqOp, part, zeroValue) + } + + partitionList <- collect(lapplyPartition(x, partitionFunc), + flatten = FALSE) + Reduce(combOp, partitionList, zeroValue) + }) + +#' Pipes elements to a forked external process. +#' +#' The same as 'pipe()' in Spark. +#' +#' @param x The RDD whose elements are piped to the forked external process. +#' @param command The command to fork an external process. +#' @param env A named list to set environment variables of the external process. +#' @return A new RDD created by piping all elements to a forked external process. +#' @rdname pipeRDD +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' collect(pipeRDD(rdd, "more") +#' Output: c("1", "2", ..., "10") +#'} +#' @rdname pipeRDD +#' @aliases pipeRDD,RDD,character-method +setMethod("pipeRDD", + signature(x = "RDD", command = "character"), + function(x, command, env = list()) { + func <- function(part) { + trim.trailing.func <- function(x) { + sub("[\r\n]*$", "", toString(x)) + } + input <- unlist(lapply(part, trim.trailing.func)) + res <- system2(command, stdout = TRUE, input = input, env = env) + lapply(res, trim.trailing.func) + } + lapplyPartition(x, func) + }) + +# TODO: Consider caching the name in the RDD's environment +#' Return an RDD's name. +#' +#' @param x The RDD whose name is returned. +#' @rdname name +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1,2,3)) +#' name(rdd) # NULL (if not set before) +#'} +#' @rdname name +#' @aliases name,RDD +setMethod("name", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "name") + }) + +#' Set an RDD's name. +#' +#' @param x The RDD whose name is to be set. +#' @param name The RDD name to be set. +#' @return a new RDD renamed. +#' @rdname setName +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1,2,3)) +#' setName(rdd, "myRDD") +#' name(rdd) # "myRDD" +#'} +#' @rdname setName +#' @aliases setName,RDD +setMethod("setName", + signature(x = "RDD", name = "character"), + function(x, name) { + callJMethod(getJRDD(x), "setName", name) + x + }) + +#' Zip an RDD with generated unique Long IDs. +#' +#' Items in the kth partition will get ids k, n+k, 2*n+k, ..., where +#' n is the number of partitions. So there may exist gaps, but this +#' method won't trigger a spark job, which is different from +#' zipWithIndex. +#' +#' @param x An RDD to be zipped. +#' @return An RDD with zipped items. 
+#' @seealso zipWithIndex +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) +#' collect(zipWithUniqueId(rdd)) +#' # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) +#'} +#' @rdname zipWithUniqueId +#' @aliases zipWithUniqueId,RDD +setMethod("zipWithUniqueId", + signature(x = "RDD"), + function(x) { + n <- numPartitions(x) + + partitionFunc <- function(split, part) { + mapply( + function(item, index) { + list(item, (index - 1) * n + split) + }, + part, + seq_along(part), + SIMPLIFY = FALSE) + } + + lapplyPartitionsWithIndex(x, partitionFunc) + }) + +#' Zip an RDD with its element indices. +#' +#' The ordering is first based on the partition index and then the +#' ordering of items within each partition. So the first item in +#' the first partition gets index 0, and the last item in the last +#' partition receives the largest index. +#' +#' This method needs to trigger a Spark job when this RDD contains +#' more than one partition. +#' +#' @param x An RDD to be zipped. +#' @return An RDD with zipped items. +#' @seealso zipWithUniqueId +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) +#' collect(zipWithIndex(rdd)) +#' # list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) +#'} +#' @rdname zipWithIndex +#' @aliases zipWithIndex,RDD +setMethod("zipWithIndex", + signature(x = "RDD"), + function(x) { + n <- numPartitions(x) + if (n > 1) { + nums <- collect(lapplyPartition(x, + function(part) { + list(length(part)) + })) + startIndices <- Reduce("+", nums, accumulate = TRUE) + } + + partitionFunc <- function(split, part) { + if (split == 0) { + startIndex <- 0 + } else { + startIndex <- startIndices[[split]] + } + + mapply( + function(item, index) { + list(item, index - 1 + startIndex) + }, + part, + seq_along(part), + SIMPLIFY = FALSE) + } + + lapplyPartitionsWithIndex(x, partitionFunc) + }) + +#' Coalesce all elements within each partition of an RDD into a list. +#' +#' @param x An RDD. +#' @return An RDD created by coalescing all elements within +#' each partition into a list. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, as.list(1:4), 2L) +#' collect(glom(rdd)) +#' # list(list(1, 2), list(3, 4)) +#'} +#' @rdname glom +#' @aliases glom,RDD +setMethod("glom", + signature(x = "RDD"), + function(x) { + partitionFunc <- function(part) { + list(part) + } + + lapplyPartition(x, partitionFunc) + }) + +############ Binary Functions ############# + +#' Return the union RDD of two RDDs. +#' The same as union() in Spark. +#' +#' @param x An RDD. +#' @param y An RDD. +#' @return a new RDD created by performing the simple union (witout removing +#' duplicates) of two input RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' unionRDD(rdd, rdd) # 1, 2, 3, 1, 2, 3 +#'} +#' @rdname unionRDD +#' @aliases unionRDD,RDD,RDD-method +setMethod("unionRDD", + signature(x = "RDD", y = "RDD"), + function(x, y) { + if (getSerializedMode(x) == getSerializedMode(y)) { + jrdd <- callJMethod(getJRDD(x), "union", getJRDD(y)) + union.rdd <- RDD(jrdd, getSerializedMode(x)) + } else { + # One of the RDDs is not serialized, we need to serialize it first. 
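+              # For example, an RDD read with textFile() uses "string" mode while one
+              # created with parallelize() uses "byte" mode; re-serializing the non-"byte"
+              # side lets the JVM union() operate on a single consistent encoding.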
+ if (getSerializedMode(x) != "byte") x <- serializeToBytes(x) + if (getSerializedMode(y) != "byte") y <- serializeToBytes(y) + jrdd <- callJMethod(getJRDD(x), "union", getJRDD(y)) + union.rdd <- RDD(jrdd, "byte") + } + union.rdd + }) + +#' Zip an RDD with another RDD. +#' +#' Zips this RDD with another one, returning key-value pairs with the +#' first element in each RDD second element in each RDD, etc. Assumes +#' that the two RDDs have the same number of partitions and the same +#' number of elements in each partition (e.g. one was made through +#' a map on the other). +#' +#' @param x An RDD to be zipped. +#' @param other Another RDD to be zipped. +#' @return An RDD zipped from the two RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, 0:4) +#' rdd2 <- parallelize(sc, 1000:1004) +#' collect(zipRDD(rdd1, rdd2)) +#' # list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)) +#'} +#' @rdname zipRDD +#' @aliases zipRDD,RDD +setMethod("zipRDD", + signature(x = "RDD", other = "RDD"), + function(x, other) { + n1 <- numPartitions(x) + n2 <- numPartitions(other) + if (n1 != n2) { + stop("Can only zip RDDs which have the same number of partitions.") + } + + if (getSerializedMode(x) != getSerializedMode(other) || + getSerializedMode(x) == "byte") { + # Append the number of elements in each partition to that partition so that we can later + # check if corresponding partitions of both RDDs have the same number of elements. + # + # Note that this appending also serves the purpose of reserialization, because even if + # any RDD is serialized, we need to reserialize it to make sure its partitions are encoded + # as a single byte array. For example, partitions of an RDD generated from partitionBy() + # may be encoded as multiple byte arrays. + appendLength <- function(part) { + part[[length(part) + 1]] <- length(part) + 1 + part + } + x <- lapplyPartition(x, appendLength) + other <- lapplyPartition(other, appendLength) + } + + zippedJRDD <- callJMethod(getJRDD(x), "zip", getJRDD(other)) + # The zippedRDD's elements are of scala Tuple2 type. The serialized + # flag Here is used for the elements inside the tuples. + serializerMode <- getSerializedMode(x) + zippedRDD <- RDD(zippedJRDD, serializerMode) + + partitionFunc <- function(split, part) { + len <- length(part) + if (len > 0) { + if (serializerMode == "byte") { + lengthOfValues <- part[[len]] + lengthOfKeys <- part[[len - lengthOfValues]] + stopifnot(len == lengthOfKeys + lengthOfValues) + + # check if corresponding partitions of both RDDs have the same number of elements. + if (lengthOfKeys != lengthOfValues) { + stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.") + } + + if (lengthOfKeys > 1) { + keys <- part[1 : (lengthOfKeys - 1)] + values <- part[(lengthOfKeys + 1) : (len - 1)] + } else { + keys <- list() + values <- list() + } + } else { + # Keys, values must have same length here, because this has + # been validated inside the JavaRDD.zip() function. 
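+                     # In the non-"byte" modes the deserialized partition arrives as a flat
+                     # list that alternates key, value, key, value, ...; recycling
+                     # c(TRUE, FALSE) picks the odd positions as keys and c(FALSE, TRUE)
+                     # the even positions as values.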
+ keys <- part[c(TRUE, FALSE)] + values <- part[c(FALSE, TRUE)] + } + mapply( + function(k, v) { + list(k, v) + }, + keys, + values, + SIMPLIFY = FALSE, + USE.NAMES = FALSE) + } else { + part + } + } + + PipelinedRDD(zippedRDD, partitionFunc) + }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R new file mode 100644 index 0000000000000..930ada22f4c38 --- /dev/null +++ b/R/pkg/R/SQLContext.R @@ -0,0 +1,520 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# SQLcontext.R: SQLContext-driven functions + +#' infer the SQL type +infer_type <- function(x) { + if (is.null(x)) { + stop("can not infer type from NULL") + } + + # class of POSIXlt is c("POSIXlt" "POSIXt") + type <- switch(class(x)[[1]], + integer = "integer", + character = "string", + logical = "boolean", + double = "double", + numeric = "double", + raw = "binary", + list = "array", + environment = "map", + Date = "date", + POSIXlt = "timestamp", + POSIXct = "timestamp", + stop(paste("Unsupported type for DataFrame:", class(x)))) + + if (type == "map") { + stopifnot(length(x) > 0) + key <- ls(x)[[1]] + list(type = "map", + keyType = "string", + valueType = infer_type(get(key, x)), + valueContainsNull = TRUE) + } else if (type == "array") { + stopifnot(length(x) > 0) + names <- names(x) + if (is.null(names)) { + list(type = "array", elementType = infer_type(x[[1]]), containsNull = TRUE) + } else { + # StructType + types <- lapply(x, infer_type) + fields <- lapply(1:length(x), function(i) { + list(name = names[[i]], type = types[[i]], nullable = TRUE) + }) + list(type = "struct", fields = fields) + } + } else if (length(x) > 1) { + list(type = "array", elementType = type, containsNull = TRUE) + } else { + type + } +} + +#' dump the schema into JSON string +tojson <- function(x) { + if (is.list(x)) { + names <- names(x) + if (!is.null(names)) { + items <- lapply(names, function(n) { + safe_n <- gsub('"', '\\"', n) + paste(tojson(safe_n), ':', tojson(x[[n]]), sep = '') + }) + d <- paste(items, collapse = ', ') + paste('{', d, '}', sep = '') + } else { + l <- paste(lapply(x, tojson), collapse = ', ') + paste('[', l, ']', sep = '') + } + } else if (is.character(x)) { + paste('"', x, '"', sep = '') + } else if (is.logical(x)) { + if (x) "true" else "false" + } else { + stop(paste("unexpected type:", class(x))) + } +} + +#' Create a DataFrame from an RDD +#' +#' Converts an RDD to a DataFrame by infer the types. 
+#'
+#' @param sqlCtx A SQLContext
+#' @param data An RDD or list or data.frame
+#' @param schema a list of column names or named list (StructType), optional
+#' @return a DataFrame
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x)))
+#' df <- createDataFrame(sqlCtx, rdd)
+#' }
+
+# TODO(davies): support sampling and infer type from NA
+createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) {
+  if (is.data.frame(data)) {
+    # get the names of columns, they will be put into RDD
+    schema <- names(data)
+    n <- nrow(data)
+    m <- ncol(data)
+    # get rid of factor type
+    dropFactor <- function(x) {
+      if (is.factor(x)) {
+        as.character(x)
+      } else {
+        x
+      }
+    }
+    data <- lapply(1:n, function(i) {
+      lapply(1:m, function(j) { dropFactor(data[i,j]) })
+    })
+  }
+  if (is.list(data)) {
+    sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlCtx)
+    rdd <- parallelize(sc, data)
+  } else if (inherits(data, "RDD")) {
+    rdd <- data
+  } else {
+    stop(paste("unexpected type:", class(data)))
+  }
+
+  if (is.null(schema) || is.null(names(schema))) {
+    row <- first(rdd)
+    names <- if (is.null(schema)) {
+      names(row)
+    } else {
+      as.list(schema)
+    }
+    if (is.null(names)) {
+      names <- lapply(1:length(row), function(x) {
+        paste("_", as.character(x), sep = "")
+      })
+    }
+
+    # Spark SQL does not support '.' in column names, so replace it with '_'
+    # TODO(davies): remove this once SPARK-2775 is fixed
+    names <- lapply(names, function(n) {
+      nn <- gsub("[.]", "_", n)
+      if (nn != n) {
+        warning(paste("Use", nn, "instead of", n, "as column name"))
+      }
+      nn
+    })
+
+    types <- lapply(row, infer_type)
+    fields <- lapply(1:length(row), function(i) {
+      list(name = names[[i]], type = types[[i]], nullable = TRUE)
+    })
+    schema <- list(type = "struct", fields = fields)
+  }
+
+  stopifnot(class(schema) == "list")
+  stopifnot(schema$type == "struct")
+  stopifnot(class(schema$fields) == "list")
+  schemaString <- tojson(schema)
+
+  jrdd <- getJRDD(lapply(rdd, function(x) x), "row")
+  srdd <- callJMethod(jrdd, "rdd")
+  sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF",
+                     srdd, schemaString, sqlCtx)
+  dataFrame(sdf)
+}
+
+#' toDF
+#'
+#' Converts an RDD to a DataFrame by inferring the types.
+#'
+#' @param x An RDD
+#'
+#' @rdname DataFrame
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x)))
+#' df <- toDF(rdd)
+#' }
+
+setGeneric("toDF", function(x, ...) { standardGeneric("toDF") })
+
+setMethod("toDF", signature(x = "RDD"),
+          function(x, ...) {
+            sqlCtx <- if (exists(".sparkRHivesc", envir = .sparkREnv)) {
+              get(".sparkRHivesc", envir = .sparkREnv)
+            } else if (exists(".sparkRSQLsc", envir = .sparkREnv)) {
+              get(".sparkRSQLsc", envir = .sparkREnv)
+            } else {
+              stop("no SQL context available")
+            }
+            createDataFrame(sqlCtx, x, ...)
+          })
+
+#' Create a DataFrame from a JSON file.
+#'
+#' Loads a JSON file (one object per line), returning the result as a DataFrame.
+#' It goes through the entire dataset once to determine the schema.
+#'
+#' @param sqlCtx SQLContext to use
+#' @param path Path of file to read. A vector of multiple paths is allowed.
+#' @return DataFrame
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' }
+
+jsonFile <- function(sqlCtx, path) {
+  # Allow the user to have a more flexible definition of the text file path
+  path <- normalizePath(path)
+  # Convert a string vector of paths to a string containing comma separated paths
+  path <- paste(path, collapse = ",")
+  sdf <- callJMethod(sqlCtx, "jsonFile", path)
+  dataFrame(sdf)
+}
+
+
+#' JSON RDD
+#'
+#' Loads an RDD storing one JSON object per string as a DataFrame.
+#'
+#' @param sqlCtx SQLContext to use
+#' @param rdd An RDD of JSON string
+#' @param schema A StructType object to use as schema
+#' @param samplingRatio The ratio of sampling used to infer the schema
+#' @return A DataFrame
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' rdd <- textFile(sc, "path/to/json")
+#' df <- jsonRDD(sqlCtx, rdd)
+#' }
+
+# TODO: support schema
+jsonRDD <- function(sqlCtx, rdd, schema = NULL, samplingRatio = 1.0) {
+  rdd <- serializeToString(rdd)
+  if (is.null(schema)) {
+    sdf <- callJMethod(sqlCtx, "jsonRDD", callJMethod(getJRDD(rdd), "rdd"), samplingRatio)
+    dataFrame(sdf)
+  } else {
+    stop("not implemented")
+  }
+}
+
+
+#' Create a DataFrame from a Parquet file.
+#'
+#' Loads a Parquet file, returning the result as a DataFrame.
+#'
+#' @param sqlCtx SQLContext to use
+#' @param ... Path(s) of parquet file(s) to read.
+#' @return DataFrame
+#' @export
+
+# TODO: Implement saveAsParquetFile and write examples for both
+parquetFile <- function(sqlCtx, ...) {
+  # Allow the user to have a more flexible definition of the text file path
+  paths <- lapply(list(...), normalizePath)
+  sdf <- callJMethod(sqlCtx, "parquetFile", paths)
+  dataFrame(sdf)
+}
+
+#' SQL Query
+#'
+#' Executes a SQL query using Spark, returning the result as a DataFrame.
+#'
+#' @param sqlCtx SQLContext to use
+#' @param sqlQuery A character vector containing the SQL query
+#' @return DataFrame
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' registerTempTable(df, "table")
+#' new_df <- sql(sqlCtx, "SELECT * FROM table")
+#' }
+
+sql <- function(sqlCtx, sqlQuery) {
+  sdf <- callJMethod(sqlCtx, "sql", sqlQuery)
+  dataFrame(sdf)
+}
+
+
+#' Create a DataFrame from a SparkSQL Table
+#'
+#' Returns the specified Table as a DataFrame. The Table must have already been registered
+#' in the SQLContext.
+#'
+#' @param sqlCtx SQLContext to use
+#' @param tableName The SparkSQL Table to convert to a DataFrame.
+#' @return DataFrame
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' registerTempTable(df, "table")
+#' new_df <- table(sqlCtx, "table")
+#' }
+
+table <- function(sqlCtx, tableName) {
+  sdf <- callJMethod(sqlCtx, "table", tableName)
+  dataFrame(sdf)
+}
+
+
+#' Tables
+#'
+#' Returns a DataFrame containing names of tables in the given database.
+#' +#' @param sqlCtx SQLContext to use +#' @param databaseName name of the database +#' @return a DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' tables(sqlCtx, "hive") +#' } + +tables <- function(sqlCtx, databaseName = NULL) { + jdf <- if (is.null(databaseName)) { + callJMethod(sqlCtx, "tables") + } else { + callJMethod(sqlCtx, "tables", databaseName) + } + dataFrame(jdf) +} + + +#' Table Names +#' +#' Returns the names of tables in the given database as an array. +#' +#' @param sqlCtx SQLContext to use +#' @param databaseName name of the database +#' @return a list of table names +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' tableNames(sqlCtx, "hive") +#' } + +tableNames <- function(sqlCtx, databaseName = NULL) { + if (is.null(databaseName)) { + callJMethod(sqlCtx, "tableNames") + } else { + callJMethod(sqlCtx, "tableNames", databaseName) + } +} + + +#' Cache Table +#' +#' Caches the specified table in-memory. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName The name of the table being cached +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "table") +#' cacheTable(sqlCtx, "table") +#' } + +cacheTable <- function(sqlCtx, tableName) { + callJMethod(sqlCtx, "cacheTable", tableName) +} + +#' Uncache Table +#' +#' Removes the specified table from the in-memory cache. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName The name of the table being uncached +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "table") +#' uncacheTable(sqlCtx, "table") +#' } + +uncacheTable <- function(sqlCtx, tableName) { + callJMethod(sqlCtx, "uncacheTable", tableName) +} + +#' Clear Cache +#' +#' Removes all cached tables from the in-memory cache. +#' +#' @param sqlCtx SQLContext to use +#' @examples +#' \dontrun{ +#' clearCache(sqlCtx) +#' } + +clearCache <- function(sqlCtx) { + callJMethod(sqlCtx, "clearCache") +} + +#' Drop Temporary Table +#' +#' Drops the temporary table with the given table name in the catalog. +#' If the table has been cached/persisted before, it's also unpersisted. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName The name of the SparkSQL table to be dropped. +#' @examples +#' \dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- loadDF(sqlCtx, path, "parquet") +#' registerTempTable(df, "table") +#' dropTempTable(sqlCtx, "table") +#' } + +dropTempTable <- function(sqlCtx, tableName) { + if (class(tableName) != "character") { + stop("tableName must be a string.") + } + callJMethod(sqlCtx, "dropTempTable", tableName) +} + +#' Load an DataFrame +#' +#' Returns the dataset in a data source as a DataFrame +#' +#' The data source is specified by the `source` and a set of options(...). +#' If `source` is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. 
+#' +#' @param sqlCtx SQLContext to use +#' @param path The path of files to load +#' @param source the name of external data source +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- load(sqlCtx, "path/to/file.json", source = "json") +#' } + +loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) { + options <- varargsToEnv(...) + if (!is.null(path)) { + options[['path']] <- path + } + sdf <- callJMethod(sqlCtx, "load", source, options) + dataFrame(sdf) +} + +#' Create an external table +#' +#' Creates an external table based on the dataset in a data source, +#' Returns the DataFrame associated with the external table. +#' +#' The data source is specified by the `source` and a set of options(...). +#' If `source` is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName A name of the table +#' @param path The path of files to load +#' @param source the name of external data source +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- sparkRSQL.createExternalTable(sqlCtx, "myjson", path="path/to/json", source="json") +#' } + +createExternalTable <- function(sqlCtx, tableName, path = NULL, source = NULL, ...) { + options <- varargsToEnv(...) + if (!is.null(path)) { + options[['path']] <- path + } + sdf <- callJMethod(sqlCtx, "createExternalTable", tableName, source, options) + dataFrame(sdf) +} diff --git a/R/pkg/R/SQLTypes.R b/R/pkg/R/SQLTypes.R new file mode 100644 index 0000000000000..962fba5b3cf03 --- /dev/null +++ b/R/pkg/R/SQLTypes.R @@ -0,0 +1,64 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility functions for handling SparkSQL DataTypes. + +# Handler for StructType +structType <- function(st) { + obj <- structure(new.env(parent = emptyenv()), class = "structType") + obj$jobj <- st + obj$fields <- function() { lapply(callJMethod(st, "fields"), structField) } + obj +} + +#' Print a Spark StructType. +#' +#' This function prints the contents of a StructType returned from the +#' SparkR JVM backend. +#' +#' @param x A StructType object +#' @param ... further arguments passed to or from other methods +print.structType <- function(x, ...) 
{ + fieldsList <- lapply(x$fields(), function(i) { i$print() }) + print(fieldsList) +} + +# Handler for StructField +structField <- function(sf) { + obj <- structure(new.env(parent = emptyenv()), class = "structField") + obj$jobj <- sf + obj$name <- function() { callJMethod(sf, "name") } + obj$dataType <- function() { callJMethod(sf, "dataType") } + obj$dataType.toString <- function() { callJMethod(obj$dataType(), "toString") } + obj$dataType.simpleString <- function() { callJMethod(obj$dataType(), "simpleString") } + obj$nullable <- function() { callJMethod(sf, "nullable") } + obj$print <- function() { paste("StructField(", + paste(obj$name(), obj$dataType.toString(), obj$nullable(), sep = ", "), + ")", sep = "") } + obj +} + +#' Print a Spark StructField. +#' +#' This function prints the contents of a StructField returned from the +#' SparkR JVM backend. +#' +#' @param x A StructField object +#' @param ... further arguments passed to or from other methods +print.structField <- function(x, ...) { + cat(x$print()) +} diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R new file mode 100644 index 0000000000000..2fb6fae55f28c --- /dev/null +++ b/R/pkg/R/backend.R @@ -0,0 +1,115 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Methods to call into SparkRBackend. + + +# Returns TRUE if object is an instance of given class +isInstanceOf <- function(jobj, className) { + stopifnot(class(jobj) == "jobj") + cls <- callJStatic("java.lang.Class", "forName", className) + callJMethod(cls, "isInstance", jobj) +} + +# Call a Java method named methodName on the object +# specified by objId. objId should be a "jobj" returned +# from the SparkRBackend. +callJMethod <- function(objId, methodName, ...) { + stopifnot(class(objId) == "jobj") + if (!isValidJobj(objId)) { + stop("Invalid jobj ", objId$id, + ". If SparkR was restarted, Spark operations need to be re-executed.") + } + invokeJava(isStatic = FALSE, objId$id, methodName, ...) +} + +# Call a static method on a specified className +callJStatic <- function(className, methodName, ...) { + invokeJava(isStatic = TRUE, className, methodName, ...) +} + +# Create a new object of the specified class name +newJObject <- function(className, ...) { + invokeJava(isStatic = TRUE, className, methodName = "", ...) +} + +# Remove an object from the SparkR backend. This is done +# automatically when a jobj is garbage collected. +removeJObject <- function(objId) { + invokeJava(isStatic = TRUE, "SparkRHandler", "rm", objId) +} + +isRemoveMethod <- function(isStatic, objId, methodName) { + isStatic == TRUE && objId == "SparkRHandler" && methodName == "rm" +} + +# Invoke a Java method on the SparkR backend. Users +# should typically use one of the higher level methods like +# callJMethod, callJStatic etc. instead of using this. 
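+#
+# A small sketch of the higher level wrappers, assuming a connected backend and a
+# jobj `df` previously returned by it (e.g. referencing a Java DataFrame):
+#
+#   n    <- callJMethod(df, "count")                                 # df.count()
+#   home <- callJStatic("java.lang.System", "getenv", "SPARK_HOME")  # static method call
+#   pair <- newJObject("scala.Tuple2", "a", 1L)                      # new Tuple2("a", 1)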
+# +# isStatic - TRUE if the method to be called is static +# objId - String that refers to the object on which method is invoked +# Should be a jobj id for non-static methods and the classname +# for static methods +# methodName - name of method to be invoked +invokeJava <- function(isStatic, objId, methodName, ...) { + if (!exists(".sparkRCon", .sparkREnv)) { + stop("No connection to backend found. Please re-run sparkR.init") + } + + # If this isn't a removeJObject call + if (!isRemoveMethod(isStatic, objId, methodName)) { + objsToRemove <- ls(.toRemoveJobjs) + if (length(objsToRemove) > 0) { + sapply(objsToRemove, + function(e) { + removeJObject(e) + }) + rm(list = objsToRemove, envir = .toRemoveJobjs) + } + } + + + rc <- rawConnection(raw(0), "r+") + + writeBoolean(rc, isStatic) + writeString(rc, objId) + writeString(rc, methodName) + + args <- list(...) + writeInt(rc, length(args)) + writeArgs(rc, args) + + # Construct the whole request message to send it once, + # avoiding write-write-read pattern in case of Nagle's algorithm. + # Refer to http://en.wikipedia.org/wiki/Nagle%27s_algorithm for the details. + bytesToSend <- rawConnectionValue(rc) + close(rc) + rc <- rawConnection(raw(0), "r+") + writeInt(rc, length(bytesToSend)) + writeBin(bytesToSend, rc) + requestMessage <- rawConnectionValue(rc) + close(rc) + + conn <- get(".sparkRCon", .sparkREnv) + writeBin(requestMessage, conn) + + # TODO: check the status code to output error information + returnStatus <- readInt(conn) + stopifnot(returnStatus == 0) + readObject(conn) +} diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R new file mode 100644 index 0000000000000..583fa2e7fdcfd --- /dev/null +++ b/R/pkg/R/broadcast.R @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# S4 class representing Broadcast variables + +# Hidden environment that holds values for broadcast variables +# This will not be serialized / shipped by default +.broadcastNames <- new.env() +.broadcastValues <- new.env() +.broadcastIdToName <- new.env() + +#' @title S4 class that represents a Broadcast variable +#' @description Broadcast variables can be created using the broadcast +#' function from a \code{SparkContext}. 
+#' @rdname broadcast-class +#' @seealso broadcast +#' +#' @param id Id of the backing Spark broadcast variable +#' @export +setClass("Broadcast", slots = list(id = "character")) + +#' @rdname broadcast-class +#' @param value Value of the broadcast variable +#' @param jBroadcastRef reference to the backing Java broadcast object +#' @param objName name of broadcasted object +#' @export +Broadcast <- function(id, value, jBroadcastRef, objName) { + .broadcastValues[[id]] <- value + .broadcastNames[[as.character(objName)]] <- jBroadcastRef + .broadcastIdToName[[id]] <- as.character(objName) + new("Broadcast", id = id) +} + +#' @description +#' \code{value} can be used to get the value of a broadcast variable inside +#' a distributed function. +#' +#' @param bcast The broadcast variable to get +#' @rdname broadcast +#' @aliases value,Broadcast-method +setMethod("value", + signature(bcast = "Broadcast"), + function(bcast) { + if (exists(bcast@id, envir = .broadcastValues)) { + get(bcast@id, envir = .broadcastValues) + } else { + NULL + } + }) + +#' Internal function to set values of a broadcast variable. +#' +#' This function is used internally by Spark to set the value of a broadcast +#' variable on workers. Not intended for use outside the package. +#' +#' @rdname broadcast-internal +#' @seealso broadcast, value + +#' @param bcastId The id of broadcast variable to set +#' @param value The value to be set +#' @export +setBroadcastValue <- function(bcastId, value) { + bcastIdStr <- as.character(bcastId) + .broadcastValues[[bcastIdStr]] <- value +} + +#' Helper function to clear the list of broadcast variables we know about +#' Should be called when the SparkR JVM backend is shutdown +clearBroadcastVariables <- function() { + bcasts <- ls(.broadcastNames) + rm(list = bcasts, envir = .broadcastNames) +} diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R new file mode 100644 index 0000000000000..1281c41213e32 --- /dev/null +++ b/R/pkg/R/client.R @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# Client code to connect to SparkRBackend + +# Creates a SparkR client connection object +# if one doesn't already exist +connectBackend <- function(hostname, port, timeout = 6000) { + if (exists(".sparkRcon", envir = .sparkREnv)) { + if (isOpen(.sparkREnv[[".sparkRCon"]])) { + cat("SparkRBackend client connection already exists\n") + return(get(".sparkRcon", envir = .sparkREnv)) + } + } + + con <- socketConnection(host = hostname, port = port, server = FALSE, + blocking = TRUE, open = "wb", timeout = timeout) + + assign(".sparkRCon", con, envir = .sparkREnv) + con +} + +launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts) { + if (.Platform$OS.type == "unix") { + sparkSubmitBinName = "spark-submit" + } else { + sparkSubmitBinName = "spark-submit.cmd" + } + + if (sparkHome != "") { + sparkSubmitBin <- file.path(sparkHome, "bin", sparkSubmitBinName) + } else { + sparkSubmitBin <- sparkSubmitBinName + } + + if (jars != "") { + jars <- paste("--jars", jars) + } + + combinedArgs <- paste(jars, sparkSubmitOpts, args, sep = " ") + cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n") + invisible(system2(sparkSubmitBin, combinedArgs, wait = F)) +} diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R new file mode 100644 index 0000000000000..b282001d8b6b5 --- /dev/null +++ b/R/pkg/R/column.R @@ -0,0 +1,199 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Column Class + +#' @include generics.R jobj.R SQLTypes.R +NULL + +setOldClass("jobj") + +#' @title S4 class that represents a DataFrame column +#' @description The column class supports unary, binary operations on DataFrame columns + +#' @rdname column +#' +#' @param jc reference to JVM DataFrame column +#' @export +setClass("Column", + slots = list(jc = "jobj")) + +setMethod("initialize", "Column", function(.Object, jc) { + .Object@jc <- jc + .Object +}) + +column <- function(jc) { + new("Column", jc) +} + +col <- function(x) { + column(callJStatic("org.apache.spark.sql.functions", "col", x)) +} + +#' @rdname show +setMethod("show", "Column", + function(object) { + cat("Column", callJMethod(object@jc, "toString"), "\n") + }) + +operators <- list( + "+" = "plus", "-" = "minus", "*" = "multiply", "/" = "divide", "%%" = "mod", + "==" = "equalTo", ">" = "gt", "<" = "lt", "!=" = "notEqual", "<=" = "leq", ">=" = "geq", + # we can not override `&&` and `||`, so use `&` and `|` instead + "&" = "and", "|" = "or" #, "!" 
= "unary_$bang" +) +column_functions1 <- c("asc", "desc", "isNull", "isNotNull") +column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains") +functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt", + "first", "last", "lower", "upper", "sumDistinct") + +createOperator <- function(op) { + setMethod(op, + signature(e1 = "Column"), + function(e1, e2) { + jc <- if (missing(e2)) { + if (op == "-") { + callJMethod(e1@jc, "unary_$minus") + } else { + callJMethod(e1@jc, operators[[op]]) + } + } else { + if (class(e2) == "Column") { + e2 <- e2@jc + } + callJMethod(e1@jc, operators[[op]], e2) + } + column(jc) + }) +} + +createColumnFunction1 <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x) { + column(callJMethod(x@jc, name)) + }) +} + +createColumnFunction2 <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x, data) { + if (class(data) == "Column") { + data <- data@jc + } + jc <- callJMethod(x@jc, name, data) + column(jc) + }) +} + +createStaticFunction <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc) + column(jc) + }) +} + +createMethods <- function() { + for (op in names(operators)) { + createOperator(op) + } + for (name in column_functions1) { + createColumnFunction1(name) + } + for (name in column_functions2) { + createColumnFunction2(name) + } + for (x in functions) { + createStaticFunction(x) + } +} + +createMethods() + +#' alias +#' +#' Set a new name for a column +setMethod("alias", + signature(object = "Column"), + function(object, data) { + if (is.character(data)) { + column(callJMethod(object@jc, "as", data)) + } else { + stop("data should be character") + } + }) + +#' An expression that returns a substring. +#' +#' @param start starting position +#' @param stop ending position +setMethod("substr", signature(x = "Column"), + function(x, start, stop) { + jc <- callJMethod(x@jc, "substr", as.integer(start - 1), as.integer(stop - start + 1)) + column(jc) + }) + +#' Casts the column to a different data type. +#' @examples +#' \dontrun{ +#' cast(df$age, "string") +#' cast(df$name, list(type="array", elementType="byte", containsNull = TRUE)) +#' } +setMethod("cast", + signature(x = "Column"), + function(x, dataType) { + if (is.character(dataType)) { + column(callJMethod(x@jc, "cast", dataType)) + } else if (is.list(dataType)) { + json <- tojson(dataType) + jdataType <- callJStatic("org.apache.spark.sql.types.DataType", "fromJson", json) + column(callJMethod(x@jc, "cast", jdataType)) + } else { + stop("dataType should be character or list") + } + }) + +#' Approx Count Distinct +#' +#' Returns the approximate number of distinct items in a group. +#' +setMethod("approxCountDistinct", + signature(x = "Column"), + function(x, rsd = 0.95) { + jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) + column(jc) + }) + +#' Count Distinct +#' +#' returns the number of distinct items in a group. +#' +setMethod("countDistinct", + signature(x = "Column"), + function(x, ...) 
{ + jcol <- lapply(list(...), function (x) { + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, + listToSeq(jcol)) + column(jc) + }) + diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R new file mode 100644 index 0000000000000..ebbb8fba1052d --- /dev/null +++ b/R/pkg/R/context.R @@ -0,0 +1,225 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# context.R: SparkContext driven functions + +getMinSplits <- function(sc, minSplits) { + if (is.null(minSplits)) { + defaultParallelism <- callJMethod(sc, "defaultParallelism") + minSplits <- min(defaultParallelism, 2) + } + as.integer(minSplits) +} + +#' Create an RDD from a text file. +#' +#' This function reads a text file from HDFS, a local file system (available on all +#' nodes), or any Hadoop-supported file system URI, and creates an +#' RDD of strings from it. +#' +#' @param sc SparkContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param minSplits Minimum number of splits to be created. If NULL, the default +#' value is chosen based on available parallelism. +#' @return RDD where each item is of type \code{character} +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' lines <- textFile(sc, "myfile.txt") +#'} +textFile <- function(sc, path, minSplits = NULL) { + # Allow the user to have a more flexible definiton of the text file path + path <- suppressWarnings(normalizePath(path)) + #' Convert a string vector of paths to a string containing comma separated paths + path <- paste(path, collapse = ",") + + jrdd <- callJMethod(sc, "textFile", path, getMinSplits(sc, minSplits)) + # jrdd is of type JavaRDD[String] + RDD(jrdd, "string") +} + +#' Load an RDD saved as a SequenceFile containing serialized objects. +#' +#' The file to be loaded should be one that was previously generated by calling +#' saveAsObjectFile() of the RDD class. +#' +#' @param sc SparkContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param minSplits Minimum number of splits to be created. If NULL, the default +#' value is chosen based on available parallelism. +#' @return RDD containing serialized R objects. +#' @seealso saveAsObjectFile +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- objectFile(sc, "myfile") +#'} +objectFile <- function(sc, path, minSplits = NULL) { + # Allow the user to have a more flexible definiton of the text file path + path <- suppressWarnings(normalizePath(path)) + #' Convert a string vector of paths to a string containing comma separated paths + path <- paste(path, collapse = ",") + + jrdd <- callJMethod(sc, "objectFile", path, getMinSplits(sc, minSplits)) + # Assume the RDD contains serialized R objects. 
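+  # ("byte" marks the serialization mode: each element is an R object stored in
+  # its serialize()d form and converted back with unserialize() when read.)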
+ RDD(jrdd, "byte") +} + +#' Create an RDD from a homogeneous list or vector. +#' +#' This function creates an RDD from a local homogeneous list in R. The elements +#' in the list are split into \code{numSlices} slices and distributed to nodes +#' in the cluster. +#' +#' @param sc SparkContext to use +#' @param coll collection to parallelize +#' @param numSlices number of partitions to create in the RDD +#' @return an RDD created from this collection +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2) +#' # The RDD should contain 10 elements +#' length(rdd) +#'} +parallelize <- function(sc, coll, numSlices = 1) { + # TODO: bound/safeguard numSlices + # TODO: unit tests for if the split works for all primitives + # TODO: support matrix, data frame, etc + if ((!is.list(coll) && !is.vector(coll)) || is.data.frame(coll)) { + if (is.data.frame(coll)) { + message(paste("context.R: A data frame is parallelized by columns.")) + } else { + if (is.matrix(coll)) { + message(paste("context.R: A matrix is parallelized by elements.")) + } else { + message(paste("context.R: parallelize() currently only supports lists and vectors.", + "Calling as.list() to coerce coll into a list.")) + } + } + coll <- as.list(coll) + } + + if (numSlices > length(coll)) + numSlices <- length(coll) + + sliceLen <- ceiling(length(coll) / numSlices) + slices <- split(coll, rep(1:(numSlices + 1), each = sliceLen)[1:length(coll)]) + + # Serialize each slice: obtain a list of raws, or a list of lists (slices) of + # 2-tuples of raws + serializedSlices <- lapply(slices, serialize, connection = NULL) + + jrdd <- callJStatic("org.apache.spark.api.r.RRDD", + "createRDDFromArray", sc, serializedSlices) + + RDD(jrdd, "byte") +} + +#' Include this specified package on all workers +#' +#' This function can be used to include a package on all workers before the +#' user's code is executed. This is useful in scenarios where other R package +#' functions are used in a function passed to functions like \code{lapply}. +#' NOTE: The package is assumed to be installed on every node in the Spark +#' cluster. +#' +#' @param sc SparkContext to use +#' @param pkg Package name +#' +#' @export +#' @examples +#'\dontrun{ +#' library(Matrix) +#' +#' sc <- sparkR.init() +#' # Include the matrix library we will be using +#' includePackage(sc, Matrix) +#' +#' generateSparse <- function(x) { +#' sparseMatrix(i=c(1, 2, 3), j=c(1, 2, 3), x=c(1, 2, 3)) +#' } +#' +#' rdd <- lapplyPartition(parallelize(sc, 1:2, 2L), generateSparse) +#' collect(rdd) +#'} +includePackage <- function(sc, pkg) { + pkg <- as.character(substitute(pkg)) + if (exists(".packages", .sparkREnv)) { + packages <- .sparkREnv$.packages + } else { + packages <- list() + } + packages <- c(packages, pkg) + .sparkREnv$.packages <- packages +} + +#' @title Broadcast a variable to all workers +#' +#' @description +#' Broadcast a read-only variable to the cluster, returning a \code{Broadcast} +#' object for reading it in distributed functions. 
+#' +#' @param sc Spark Context to use +#' @param object Object to be broadcast +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:2, 2L) +#' +#' # Large Matrix object that we want to broadcast +#' randomMat <- matrix(nrow=100, ncol=10, data=rnorm(1000)) +#' randomMatBr <- broadcast(sc, randomMat) +#' +#' # Use the broadcast variable inside the function +#' useBroadcast <- function(x) { +#' sum(value(randomMatBr) * x) +#' } +#' sumRDD <- lapply(rdd, useBroadcast) +#'} +broadcast <- function(sc, object) { + objName <- as.character(substitute(object)) + serializedObj <- serialize(object, connection = NULL) + + jBroadcast <- callJMethod(sc, "broadcast", serializedObj) + id <- as.character(callJMethod(jBroadcast, "id")) + + Broadcast(id, object, jBroadcast, objName) +} + +#' @title Set the checkpoint directory +#' +#' Set the directory under which RDDs are going to be checkpointed. The +#' directory must be a HDFS path if running on a cluster. +#' +#' @param sc Spark Context to use +#' @param dirName Directory path +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' setCheckpointDir(sc, "~/checkpoint") +#' rdd <- parallelize(sc, 1:2, 2L) +#' checkpoint(rdd) +#'} +setCheckpointDir <- function(sc, dirName) { + invisible(callJMethod(sc, "setCheckpointDir", suppressWarnings(normalizePath(dirName)))) +} diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R new file mode 100644 index 0000000000000..257b435607ce8 --- /dev/null +++ b/R/pkg/R/deserialize.R @@ -0,0 +1,184 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility functions to deserialize objects from Java. 
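+#
+# Each value arrives as a one-byte type tag followed by its payload
+# (multi-byte numbers are big-endian). A minimal sketch of decoding a single
+# Int locally, outside the backend protocol:
+#   con <- rawConnection(c(charToRaw("i"), writeBin(42L, raw(), endian = "big")), "r")
+#   readObject(con)   # 42
+#   close(con)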
+ +# Type mapping from Java to R +# +# void -> NULL +# Int -> integer +# String -> character +# Boolean -> logical +# Double -> double +# Long -> double +# Array[Byte] -> raw +# Date -> Date +# Time -> POSIXct +# +# Array[T] -> list() +# Object -> jobj + +readObject <- function(con) { + # Read type first + type <- readType(con) + readTypedObject(con, type) +} + +readTypedObject <- function(con, type) { + switch (type, + "i" = readInt(con), + "c" = readString(con), + "b" = readBoolean(con), + "d" = readDouble(con), + "r" = readRaw(con), + "D" = readDate(con), + "t" = readTime(con), + "l" = readList(con), + "n" = NULL, + "j" = getJobj(readString(con)), + stop(paste("Unsupported type for deserialization", type))) +} + +readString <- function(con) { + stringLen <- readInt(con) + string <- readBin(con, raw(), stringLen, endian = "big") + rawToChar(string) +} + +readInt <- function(con) { + readBin(con, integer(), n = 1, endian = "big") +} + +readDouble <- function(con) { + readBin(con, double(), n = 1, endian = "big") +} + +readBoolean <- function(con) { + as.logical(readInt(con)) +} + +readType <- function(con) { + rawToChar(readBin(con, "raw", n = 1L)) +} + +readDate <- function(con) { + as.Date(readString(con)) +} + +readTime <- function(con) { + t <- readDouble(con) + as.POSIXct(t, origin = "1970-01-01") +} + +# We only support lists where all elements are of same type +readList <- function(con) { + type <- readType(con) + len <- readInt(con) + if (len > 0) { + l <- vector("list", len) + for (i in 1:len) { + l[[i]] <- readTypedObject(con, type) + } + l + } else { + list() + } +} + +readRaw <- function(con) { + dataLen <- readInt(con) + data <- readBin(con, raw(), as.integer(dataLen), endian = "big") +} + +readRawLen <- function(con, dataLen) { + data <- readBin(con, raw(), as.integer(dataLen), endian = "big") +} + +readDeserialize <- function(con) { + # We have two cases that are possible - In one, the entire partition is + # encoded as a byte array, so we have only one value to read. If so just + # return firstData + dataLen <- readInt(con) + firstData <- unserialize( + readBin(con, raw(), as.integer(dataLen), endian = "big")) + + # Else, read things into a list + dataLen <- readInt(con) + if (length(dataLen) > 0 && dataLen > 0) { + data <- list(firstData) + while (length(dataLen) > 0 && dataLen > 0) { + data[[length(data) + 1L]] <- unserialize( + readBin(con, raw(), as.integer(dataLen), endian = "big")) + dataLen <- readInt(con) + } + unlist(data, recursive = FALSE) + } else { + firstData + } +} + +readDeserializeRows <- function(inputCon) { + # readDeserializeRows will deserialize a DataOutputStream composed of + # a list of lists. Since the DOS is one continuous stream and + # the number of rows varies, we put the readRow function in a while loop + # that termintates when the next row is empty. + data <- list() + while(TRUE) { + row <- readRow(inputCon) + if (length(row) == 0) { + break + } + data[[length(data) + 1L]] <- row + } + data # this is a list of named lists now +} + +readRowList <- function(obj) { + # readRowList is meant for use inside an lapply. As a result, it is + # necessary to open a standalone connection for the row and consume + # the numCols bytes inside the read function in order to correctly + # deserialize the row. 
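+  # (Sketch: given `serializedRows`, a list of raw vectors holding one
+  #  serialized row each, a caller would do `lapply(serializedRows, readRowList)`.)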
+ rawObj <- rawConnection(obj, "r+") + on.exit(close(rawObj)) + readRow(rawObj) +} + +readRow <- function(inputCon) { + numCols <- readInt(inputCon) + if (length(numCols) > 0 && numCols > 0) { + lapply(1:numCols, function(x) { + obj <- readObject(inputCon) + if (is.null(obj)) { + NA + } else { + obj + } + }) # each row is a list now + } else { + list() + } +} + +# Take a single column as Array[Byte] and deserialize it into an atomic vector +readCol <- function(inputCon, numRows) { + # sapply can not work with POSIXlt + do.call(c, lapply(1:numRows, function(x) { + value <- readObject(inputCon) + # Replace NULL with NA so we can coerce to vectors + if (is.null(value)) NA else value + })) +} diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R new file mode 100644 index 0000000000000..5fb1ccaa84ee2 --- /dev/null +++ b/R/pkg/R/generics.R @@ -0,0 +1,543 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############ RDD Actions and Transformations ############ + +#' @rdname aggregateRDD +#' @seealso reduce +#' @export +setGeneric("aggregateRDD", function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) + +#' @rdname cache-methods +#' @export +setGeneric("cache", function(x) { standardGeneric("cache") }) + +#' @rdname coalesce +#' @seealso repartition +#' @export +setGeneric("coalesce", function(x, numPartitions, ...) { standardGeneric("coalesce") }) + +#' @rdname checkpoint-methods +#' @export +setGeneric("checkpoint", function(x) { standardGeneric("checkpoint") }) + +#' @rdname collect-methods +#' @export +setGeneric("collect", function(x, ...) 
{ standardGeneric("collect") }) + +#' @rdname collect-methods +#' @export +setGeneric("collectAsMap", function(x) { standardGeneric("collectAsMap") }) + +#' @rdname collect-methods +#' @export +setGeneric("collectPartition", + function(x, partitionId) { + standardGeneric("collectPartition") + }) + +#' @rdname count +#' @export +setGeneric("count", function(x) { standardGeneric("count") }) + +#' @rdname countByValue +#' @export +setGeneric("countByValue", function(x) { standardGeneric("countByValue") }) + +#' @rdname distinct +#' @export +setGeneric("distinct", function(x, numPartitions = 1L) { standardGeneric("distinct") }) + +#' @rdname filterRDD +#' @export +setGeneric("filterRDD", function(x, f) { standardGeneric("filterRDD") }) + +#' @rdname first +#' @export +setGeneric("first", function(x) { standardGeneric("first") }) + +#' @rdname flatMap +#' @export +setGeneric("flatMap", function(X, FUN) { standardGeneric("flatMap") }) + +#' @rdname fold +#' @seealso reduce +#' @export +setGeneric("fold", function(x, zeroValue, op) { standardGeneric("fold") }) + +#' @rdname foreach +#' @export +setGeneric("foreach", function(x, func) { standardGeneric("foreach") }) + +#' @rdname foreach +#' @export +setGeneric("foreachPartition", function(x, func) { standardGeneric("foreachPartition") }) + +# The jrdd accessor function. +setGeneric("getJRDD", function(rdd, ...) { standardGeneric("getJRDD") }) + +#' @rdname glom +#' @export +setGeneric("glom", function(x) { standardGeneric("glom") }) + +#' @rdname keyBy +#' @export +setGeneric("keyBy", function(x, func) { standardGeneric("keyBy") }) + +#' @rdname lapplyPartition +#' @export +setGeneric("lapplyPartition", function(X, FUN) { standardGeneric("lapplyPartition") }) + +#' @rdname lapplyPartitionsWithIndex +#' @export +setGeneric("lapplyPartitionsWithIndex", + function(X, FUN) { + standardGeneric("lapplyPartitionsWithIndex") + }) + +#' @rdname lapply +#' @export +setGeneric("map", function(X, FUN) { standardGeneric("map") }) + +#' @rdname lapplyPartition +#' @export +setGeneric("mapPartitions", function(X, FUN) { standardGeneric("mapPartitions") }) + +#' @rdname lapplyPartitionsWithIndex +#' @export +setGeneric("mapPartitionsWithIndex", + function(X, FUN) { standardGeneric("mapPartitionsWithIndex") }) + +#' @rdname maximum +#' @export +setGeneric("maximum", function(x) { standardGeneric("maximum") }) + +#' @rdname minimum +#' @export +setGeneric("minimum", function(x) { standardGeneric("minimum") }) + +#' @rdname sumRDD +#' @export +setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") }) + +#' @rdname name +#' @export +setGeneric("name", function(x) { standardGeneric("name") }) + +#' @rdname numPartitions +#' @export +setGeneric("numPartitions", function(x) { standardGeneric("numPartitions") }) + +#' @rdname persist +#' @export +setGeneric("persist", function(x, newLevel) { standardGeneric("persist") }) + +#' @rdname pipeRDD +#' @export +setGeneric("pipeRDD", function(x, command, env = list()) { standardGeneric("pipeRDD")}) + +#' @rdname reduce +#' @export +setGeneric("reduce", function(x, func) { standardGeneric("reduce") }) + +#' @rdname repartition +#' @seealso coalesce +#' @export +setGeneric("repartition", function(x, numPartitions) { standardGeneric("repartition") }) + +#' @rdname sampleRDD +#' @export +setGeneric("sampleRDD", + function(x, withReplacement, fraction, seed) { + standardGeneric("sampleRDD") + }) + +#' @rdname saveAsObjectFile +#' @seealso objectFile +#' @export +setGeneric("saveAsObjectFile", function(x, path) { 
standardGeneric("saveAsObjectFile") }) + +#' @rdname saveAsTextFile +#' @export +setGeneric("saveAsTextFile", function(x, path) { standardGeneric("saveAsTextFile") }) + +#' @rdname setName +#' @export +setGeneric("setName", function(x, name) { standardGeneric("setName") }) + +#' @rdname sortBy +#' @export +setGeneric("sortBy", + function(x, func, ascending = TRUE, numPartitions = 1L) { + standardGeneric("sortBy") + }) + +#' @rdname take +#' @export +setGeneric("take", function(x, num) { standardGeneric("take") }) + +#' @rdname takeOrdered +#' @export +setGeneric("takeOrdered", function(x, num) { standardGeneric("takeOrdered") }) + +#' @rdname takeSample +#' @export +setGeneric("takeSample", + function(x, withReplacement, num, seed) { + standardGeneric("takeSample") + }) + +#' @rdname top +#' @export +setGeneric("top", function(x, num) { standardGeneric("top") }) + +#' @rdname unionRDD +#' @export +setGeneric("unionRDD", function(x, y) { standardGeneric("unionRDD") }) + +#' @rdname unpersist-methods +#' @export +setGeneric("unpersist", function(x, ...) { standardGeneric("unpersist") }) + +#' @rdname zipRDD +#' @export +setGeneric("zipRDD", function(x, other) { standardGeneric("zipRDD") }) + +#' @rdname zipWithIndex +#' @seealso zipWithUniqueId +#' @export +setGeneric("zipWithIndex", function(x) { standardGeneric("zipWithIndex") }) + +#' @rdname zipWithUniqueId +#' @seealso zipWithIndex +#' @export +setGeneric("zipWithUniqueId", function(x) { standardGeneric("zipWithUniqueId") }) + + +############ Binary Functions ############# + +#' @rdname countByKey +#' @export +setGeneric("countByKey", function(x) { standardGeneric("countByKey") }) + +#' @rdname flatMapValues +#' @export +setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") }) + +#' @rdname keys +#' @export +setGeneric("keys", function(x) { standardGeneric("keys") }) + +#' @rdname lookup +#' @export +setGeneric("lookup", function(x, key) { standardGeneric("lookup") }) + +#' @rdname mapValues +#' @export +setGeneric("mapValues", function(X, FUN) { standardGeneric("mapValues") }) + +#' @rdname values +#' @export +setGeneric("values", function(x) { standardGeneric("values") }) + + + +############ Shuffle Functions ############ + +#' @rdname aggregateByKey +#' @seealso foldByKey, combineByKey +#' @export +setGeneric("aggregateByKey", + function(x, zeroValue, seqOp, combOp, numPartitions) { + standardGeneric("aggregateByKey") + }) + +#' @rdname cogroup +#' @export +setGeneric("cogroup", + function(..., numPartitions) { + standardGeneric("cogroup") + }, + signature = "...") + +#' @rdname combineByKey +#' @seealso groupByKey, reduceByKey +#' @export +setGeneric("combineByKey", + function(x, createCombiner, mergeValue, mergeCombiners, numPartitions) { + standardGeneric("combineByKey") + }) + +#' @rdname foldByKey +#' @seealso aggregateByKey, combineByKey +#' @export +setGeneric("foldByKey", + function(x, zeroValue, func, numPartitions) { + standardGeneric("foldByKey") + }) + +#' @rdname join-methods +#' @export +setGeneric("fullOuterJoin", function(x, y, numPartitions) { standardGeneric("fullOuterJoin") }) + +#' @rdname groupByKey +#' @seealso reduceByKey +#' @export +setGeneric("groupByKey", function(x, numPartitions) { standardGeneric("groupByKey") }) + +#' @rdname join-methods +#' @export +setGeneric("join", function(x, y, ...) 
{ standardGeneric("join") }) + +#' @rdname join-methods +#' @export +setGeneric("leftOuterJoin", function(x, y, numPartitions) { standardGeneric("leftOuterJoin") }) + +#' @rdname partitionBy +#' @export +setGeneric("partitionBy", function(x, numPartitions, ...) { standardGeneric("partitionBy") }) + +#' @rdname reduceByKey +#' @seealso groupByKey +#' @export +setGeneric("reduceByKey", function(x, combineFunc, numPartitions) { standardGeneric("reduceByKey")}) + +#' @rdname reduceByKeyLocally +#' @seealso reduceByKey +#' @export +setGeneric("reduceByKeyLocally", + function(x, combineFunc) { + standardGeneric("reduceByKeyLocally") + }) + +#' @rdname join-methods +#' @export +setGeneric("rightOuterJoin", function(x, y, numPartitions) { standardGeneric("rightOuterJoin") }) + +#' @rdname sortByKey +#' @export +setGeneric("sortByKey", function(x, ascending = TRUE, numPartitions = 1L) { + standardGeneric("sortByKey") +}) + + +################### Broadcast Variable Methods ################# + +#' @rdname broadcast +#' @export +setGeneric("value", function(bcast) { standardGeneric("value") }) + + + +#################### DataFrame Methods ######################## + +#' @rdname schema +#' @export +setGeneric("columns", function(x) {standardGeneric("columns") }) + +#' @rdname schema +#' @export +setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) + +#' @rdname explain +#' @export +setGeneric("explain", function(x, ...) { standardGeneric("explain") }) + +#' @rdname filter +#' @export +setGeneric("filter", function(x, condition) { standardGeneric("filter") }) + +#' @rdname DataFrame +#' @export +setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") }) + +#' @rdname insertInto +#' @export +setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertInto") }) + +#' @rdname intersect +#' @export +setGeneric("intersect", function(x, y) { standardGeneric("intersect") }) + +#' @rdname isLocal +#' @export +setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) + +#' @rdname limit +#' @export +setGeneric("limit", function(x, num) {standardGeneric("limit") }) + +#' @rdname sortDF +#' @export +setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") }) + +#' @rdname schema +#' @export +setGeneric("printSchema", function(x) { standardGeneric("printSchema") }) + +#' @rdname registerTempTable +#' @export +setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") }) + +#' @rdname sampleDF +#' @export +setGeneric("sampleDF", + function(x, withReplacement, fraction, seed) { + standardGeneric("sampleDF") + }) + +#' @rdname saveAsParquetFile +#' @export +setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) + +#' @rdname saveAsTable +#' @export +setGeneric("saveAsTable", function(df, tableName, source, mode, ...) { + standardGeneric("saveAsTable") +}) + +#' @rdname saveAsTable +#' @export +setGeneric("saveDF", function(df, path, source, mode, ...) { standardGeneric("saveDF") }) + +#' @rdname schema +#' @export +setGeneric("schema", function(x) { standardGeneric("schema") }) + +#' @rdname select +#' @export +setGeneric("select", function(x, col, ...) { standardGeneric("select") } ) + +#' @rdname select +#' @export +setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") }) + +#' @rdname showDF +#' @export +setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) + +#' @rdname sortDF +#' @export +setGeneric("sortDF", function(x, col, ...) 
{ standardGeneric("sortDF") }) + +#' @rdname subtract +#' @export +setGeneric("subtract", function(x, y) { standardGeneric("subtract") }) + +#' @rdname tojson +#' @export +setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) + +#' @rdname DataFrame +#' @export +setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) + +#' @rdname unionAll +#' @export +setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) + +#' @rdname filter +#' @export +setGeneric("where", function(x, condition) { standardGeneric("where") }) + +#' @rdname withColumn +#' @export +setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn") }) + +#' @rdname withColumnRenamed +#' @export +setGeneric("withColumnRenamed", function(x, existingCol, newCol) { + standardGeneric("withColumnRenamed") }) + + +###################### Column Methods ########################## + +#' @rdname column +#' @export +setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) + +#' @rdname column +#' @export +setGeneric("asc", function(x) { standardGeneric("asc") }) + +#' @rdname column +#' @export +setGeneric("avg", function(x, ...) { standardGeneric("avg") }) + +#' @rdname column +#' @export +setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) + +#' @rdname column +#' @export +setGeneric("contains", function(x, ...) { standardGeneric("contains") }) +#' @rdname column +#' @export +setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) + +#' @rdname column +#' @export +setGeneric("desc", function(x) { standardGeneric("desc") }) + +#' @rdname column +#' @export +setGeneric("endsWith", function(x, ...) { standardGeneric("endsWith") }) + +#' @rdname column +#' @export +setGeneric("getField", function(x, ...) { standardGeneric("getField") }) + +#' @rdname column +#' @export +setGeneric("getItem", function(x, ...) { standardGeneric("getItem") }) + +#' @rdname column +#' @export +setGeneric("isNull", function(x) { standardGeneric("isNull") }) + +#' @rdname column +#' @export +setGeneric("isNotNull", function(x) { standardGeneric("isNotNull") }) + +#' @rdname column +#' @export +setGeneric("last", function(x) { standardGeneric("last") }) + +#' @rdname column +#' @export +setGeneric("like", function(x, ...) { standardGeneric("like") }) + +#' @rdname column +#' @export +setGeneric("lower", function(x) { standardGeneric("lower") }) + +#' @rdname column +#' @export +setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) + +#' @rdname column +#' @export +setGeneric("startsWith", function(x, ...) { standardGeneric("startsWith") }) + +#' @rdname column +#' @export +setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) + +#' @rdname column +#' @export +setGeneric("upper", function(x) { standardGeneric("upper") }) + diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R new file mode 100644 index 0000000000000..855fbdfc7c4ca --- /dev/null +++ b/R/pkg/R/group.R @@ -0,0 +1,135 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# group.R - GroupedData class and methods implemented in S4 OO classes + +#' @include generics.R jobj.R SQLTypes.R column.R +NULL + +setOldClass("jobj") + +#' @title S4 class that represents a GroupedData +#' @description GroupedDatas can be created using groupBy() on a DataFrame +#' @rdname GroupedData +#' @seealso groupBy +#' +#' @param sgd A Java object reference to the backing Scala GroupedData +#' @export +setClass("GroupedData", + slots = list(sgd = "jobj")) + +setMethod("initialize", "GroupedData", function(.Object, sgd) { + .Object@sgd <- sgd + .Object +}) + +#' @rdname DataFrame +groupedData <- function(sgd) { + new("GroupedData", sgd) +} + + +#' @rdname show +setMethod("show", "GroupedData", + function(object) { + cat("GroupedData\n") + }) + +#' Count +#' +#' Count the number of rows for each group. +#' The resulting DataFrame will also contain the grouping columns. +#' +#' @param x a GroupedData +#' @return a DataFrame +#' @export +#' @examples +#' \dontrun{ +#' count(groupBy(df, "name")) +#' } +setMethod("count", + signature(x = "GroupedData"), + function(x) { + dataFrame(callJMethod(x@sgd, "count")) + }) + +#' Agg +#' +#' Aggregates on the entire DataFrame without groups. +#' The resulting DataFrame will also contain the grouping columns. +#' +#' df2 <- agg(df, = ) +#' df2 <- agg(df, newColName = aggFunction(column)) +#' +#' @param x a GroupedData +#' @return a DataFrame +#' @rdname agg +#' @examples +#' \dontrun{ +#' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)' +#' df2 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum +#' } +setGeneric("agg", function (x, ...) { standardGeneric("agg") }) + +setMethod("agg", + signature(x = "GroupedData"), + function(x, ...) { + cols = list(...) + stopifnot(length(cols) > 0) + if (is.character(cols[[1]])) { + cols <- varargsToEnv(...) + sdf <- callJMethod(x@sgd, "agg", cols) + } else if (class(cols[[1]]) == "Column") { + ns <- names(cols) + if (!is.null(ns)) { + for (n in ns) { + if (n != "") { + cols[[n]] = alias(cols[[n]], n) + } + } + } + jcols <- lapply(cols, function(c) { c@jc }) + # the GroupedData.agg(col, cols*) API does not contain grouping Column + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "aggWithGrouping", + x@sgd, listToSeq(jcols)) + } else { + stop("agg can only support Column or character") + } + dataFrame(sdf) + }) + + +# sum/mean/avg/min/max +methods <- c("sum", "mean", "avg", "min", "max") + +createMethod <- function(name) { + setMethod(name, + signature(x = "GroupedData"), + function(x, ...) { + sdf <- callJMethod(x@sgd, name, toSeq(...)) + dataFrame(sdf) + }) +} + +createMethods <- function() { + for (name in methods) { + createMethod(name) + } +} + +createMethods() + diff --git a/R/pkg/R/jobj.R b/R/pkg/R/jobj.R new file mode 100644 index 0000000000000..a8a25230b636d --- /dev/null +++ b/R/pkg/R/jobj.R @@ -0,0 +1,104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# References to objects that exist on the JVM backend +# are maintained using the jobj. + +#' @include generics.R +NULL + +# Maintain a reference count of Java object references +# This allows us to GC the java object when it is safe +.validJobjs <- new.env(parent = emptyenv()) + +# List of object ids to be removed +.toRemoveJobjs <- new.env(parent = emptyenv()) + +# Check if jobj was created with the current SparkContext +isValidJobj <- function(jobj) { + if (exists(".scStartTime", envir = .sparkREnv)) { + jobj$appId == get(".scStartTime", envir = .sparkREnv) + } else { + FALSE + } +} + +getJobj <- function(objId) { + newObj <- jobj(objId) + if (exists(objId, .validJobjs)) { + .validJobjs[[objId]] <- .validJobjs[[objId]] + 1 + } else { + .validJobjs[[objId]] <- 1 + } + newObj +} + +# Handler for a java object that exists on the backend. +jobj <- function(objId) { + if (!is.character(objId)) { + stop("object id must be a character") + } + # NOTE: We need a new env for a jobj as we can only register + # finalizers for environments or external references pointers. + obj <- structure(new.env(parent = emptyenv()), class = "jobj") + obj$id <- objId + obj$appId <- get(".scStartTime", envir = .sparkREnv) + + # Register a finalizer to remove the Java object when this reference + # is garbage collected in R + reg.finalizer(obj, cleanup.jobj) + obj +} + +#' Print a JVM object reference. +#' +#' This function prints the type and id for an object stored +#' in the SparkR JVM backend. +#' +#' @param x The JVM object reference +#' @param ... further arguments passed to or from other methods +print.jobj <- function(x, ...) { + cls <- callJMethod(x, "getClass") + name <- callJMethod(cls, "getName") + cat("Java ref type", name, "id", x$id, "\n", sep = " ") +} + +cleanup.jobj <- function(jobj) { + if (isValidJobj(jobj)) { + objId <- jobj$id + # If we don't know anything about this jobj, ignore it + if (exists(objId, envir = .validJobjs)) { + .validJobjs[[objId]] <- .validJobjs[[objId]] - 1 + + if (.validJobjs[[objId]] == 0) { + rm(list = objId, envir = .validJobjs) + # NOTE: We cannot call removeJObject here as the finalizer may be run + # in the middle of another RPC. Thus we queue up this object Id to be removed + # and then run all the removeJObject when the next RPC is called. + .toRemoveJobjs[[objId]] <- 1 + } + } + } +} + +clearJobjs <- function() { + valid <- ls(.validJobjs) + rm(list = valid, envir = .validJobjs) + + removeList <- ls(.toRemoveJobjs) + rm(list = removeList, envir = .toRemoveJobjs) +} diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R new file mode 100644 index 0000000000000..5d64822859d1f --- /dev/null +++ b/R/pkg/R/pairRDD.R @@ -0,0 +1,787 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Operations supported on RDDs contains pairs (i.e key, value) +#' @include generics.R jobj.R RDD.R +NULL + +############ Actions and Transformations ############ + +#' Look up elements of a key in an RDD +#' +#' @description +#' \code{lookup} returns a list of values in this RDD for key key. +#' +#' @param x The RDD to collect +#' @param key The key to look up for +#' @return a list of values in this RDD for key key +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(c(1, 1), c(2, 2), c(1, 3)) +#' rdd <- parallelize(sc, pairs) +#' lookup(rdd, 1) # list(1, 3) +#'} +#' @rdname lookup +#' @aliases lookup,RDD-method +setMethod("lookup", + signature(x = "RDD", key = "ANY"), + function(x, key) { + partitionFunc <- function(part) { + filtered <- part[unlist(lapply(part, function(i) { identical(key, i[[1]]) }))] + lapply(filtered, function(i) { i[[2]] }) + } + valsRDD <- lapplyPartition(x, partitionFunc) + collect(valsRDD) + }) + +#' Count the number of elements for each key, and return the result to the +#' master as lists of (key, count) pairs. +#' +#' Same as countByKey in Spark. +#' +#' @param x The RDD to count keys. +#' @return list of (key, count) pairs, where count is number of each key in rdd. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(c("a", 1), c("b", 1), c("a", 1))) +#' countByKey(rdd) # ("a", 2L), ("b", 1L) +#'} +#' @rdname countByKey +#' @aliases countByKey,RDD-method +setMethod("countByKey", + signature(x = "RDD"), + function(x) { + keys <- lapply(x, function(item) { item[[1]] }) + countByValue(keys) + }) + +#' Return an RDD with the keys of each tuple. +#' +#' @param x The RDD from which the keys of each tuple is returned. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) +#' collect(keys(rdd)) # list(1, 3) +#'} +#' @rdname keys +#' @aliases keys,RDD +setMethod("keys", + signature(x = "RDD"), + function(x) { + func <- function(k) { + k[[1]] + } + lapply(x, func) + }) + +#' Return an RDD with the values of each tuple. +#' +#' @param x The RDD from which the values of each tuple is returned. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) +#' collect(values(rdd)) # list(2, 4) +#'} +#' @rdname values +#' @aliases values,RDD +setMethod("values", + signature(x = "RDD"), + function(x) { + func <- function(v) { + v[[2]] + } + lapply(x, func) + }) + +#' Applies a function to all values of the elements, without modifying the keys. +#' +#' The same as `mapValues()' in Spark. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on the value of each element. +#' @return a new RDD created by the transformation. 
+#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' makePairs <- lapply(rdd, function(x) { list(x, x) }) +#' collect(mapValues(makePairs, function(x) { x * 2) }) +#' Output: list(list(1,2), list(2,4), list(3,6), ...) +#'} +#' @rdname mapValues +#' @aliases mapValues,RDD,function-method +setMethod("mapValues", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + func <- function(x) { + list(x[[1]], FUN(x[[2]])) + } + lapply(X, func) + }) + +#' Pass each value in the key-value pair RDD through a flatMap function without +#' changing the keys; this also retains the original RDD's partitioning. +#' +#' The same as 'flatMapValues()' in Spark. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on the value of each element. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) +#' collect(flatMapValues(rdd, function(x) { x })) +#' Output: list(list(1,1), list(1,2), list(2,3), list(2,4)) +#'} +#' @rdname flatMapValues +#' @aliases flatMapValues,RDD,function-method +setMethod("flatMapValues", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + flatMapFunc <- function(x) { + lapply(FUN(x[[2]]), function(v) { list(x[[1]], v) }) + } + flatMap(X, flatMapFunc) + }) + +############ Shuffle Functions ############ + +#' Partition an RDD by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' For each element of this RDD, the partitioner is used to compute a hash +#' function and the RDD is partitioned using this hash value. +#' +#' @param x The RDD to partition. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param numPartitions Number of partitions to create. +#' @param ... Other optional arguments to partitionBy. +#' +#' @param partitionFunc The partition function to use. Uses a default hashCode +#' function if not provided +#' @return An RDD partitioned using the specified partitioner. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- partitionBy(rdd, 2L) +#' collectPartition(parts, 0L) # First partition should contain list(1, 2) and list(1, 4) +#'} +#' @rdname partitionBy +#' @aliases partitionBy,RDD,integer-method +setMethod("partitionBy", + signature(x = "RDD", numPartitions = "integer"), + function(x, numPartitions, partitionFunc = hashCode) { + + #if (missing(partitionFunc)) { + # partitionFunc <- hashCode + #} + + partitionFunc <- cleanClosure(partitionFunc) + serializedHashFuncBytes <- serialize(partitionFunc, connection = NULL) + + packageNamesArr <- serialize(.sparkREnv$.packages, + connection = NULL) + broadcastArr <- lapply(ls(.broadcastNames), function(name) { + get(name, .broadcastNames) }) + jrdd <- getJRDD(x) + + # We create a PairwiseRRDD that extends RDD[(Array[Byte], + # Array[Byte])], where the key is the hashed split, the value is + # the content (key-val pairs). + pairwiseRRDD <- newJObject("org.apache.spark.api.r.PairwiseRRDD", + callJMethod(jrdd, "rdd"), + as.integer(numPartitions), + serializedHashFuncBytes, + getSerializedMode(x), + packageNamesArr, + as.character(.sparkREnv$libname), + broadcastArr, + callJMethod(jrdd, "classTag")) + + # Create a corresponding partitioner. 
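+  # (The JVM-side HashPartitioner buckets the integer hashes computed above, so
+  # pairs whose keys hash to the same value land in the same partition.)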
+ rPartitioner <- newJObject("org.apache.spark.HashPartitioner", + as.integer(numPartitions)) + + # Call partitionBy on the obtained PairwiseRDD. + javaPairRDD <- callJMethod(pairwiseRRDD, "asJavaPairRDD") + javaPairRDD <- callJMethod(javaPairRDD, "partitionBy", rPartitioner) + + # Call .values() on the result to get back the final result, the + # shuffled acutal content key-val pairs. + r <- callJMethod(javaPairRDD, "values") + + RDD(r, serializedMode = "byte") + }) + +#' Group values by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and group values for each key in the RDD into a single sequence. +#' +#' @param x The RDD to group. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, list(V)) +#' @seealso reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- groupByKey(rdd, 2L) +#' grouped <- collect(parts) +#' grouped[[1]] # Should be a list(1, list(2, 4)) +#'} +#' @rdname groupByKey +#' @aliases groupByKey,RDD,integer-method +setMethod("groupByKey", + signature(x = "RDD", numPartitions = "integer"), + function(x, numPartitions) { + shuffled <- partitionBy(x, numPartitions) + groupVals <- function(part) { + vals <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + appendList <- function(acc, i) { + addItemToAccumulator(acc, i) + acc + } + makeList <- function(i) { + acc <- initAccumulator() + addItemToAccumulator(acc, i) + acc + } + # Each item in the partition is list of (K, V) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, vals, pred, + appendList, makeList) + }) + # extract out data field + vals <- eapply(vals, + function(i) { + length(i$data) <- i$counter + i$data + }) + # Every key in the environment contains a list + # Convert that to list(K, Seq[V]) + convertEnvsToList(keys, vals) + } + lapplyPartition(shuffled, groupVals) + }) + +#' Merge values by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and merges the values for each key using an associative reduce function. +#' +#' @param x The RDD to reduce by key. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param combineFunc The associative reduce function to use. +#' @param numPartitions Number of partitions to create. 
+#' @return An RDD where each element is list(K, V') where V' is the merged +#' value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- reduceByKey(rdd, "+", 2L) +#' reduced <- collect(parts) +#' reduced[[1]] # Should be a list(1, 6) +#'} +#' @rdname reduceByKey +#' @aliases reduceByKey,RDD,integer-method +setMethod("reduceByKey", + signature(x = "RDD", combineFunc = "ANY", numPartitions = "integer"), + function(x, combineFunc, numPartitions) { + reduceVals <- function(part) { + vals <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, vals, pred, combineFunc, identity) + }) + convertEnvsToList(keys, vals) + } + locallyReduced <- lapplyPartition(x, reduceVals) + shuffled <- partitionBy(locallyReduced, numPartitions) + lapplyPartition(shuffled, reduceVals) + }) + +#' Merge values by key locally +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and merges the values for each key using an associative reduce function, but return the +#' results immediately to the driver as an R list. +#' +#' @param x The RDD to reduce by key. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param combineFunc The associative reduce function to use. +#' @return A list of elements of type list(K, V') where V' is the merged value for each key +#' @seealso reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' reduced <- reduceByKeyLocally(rdd, "+") +#' reduced # list(list(1, 6), list(1.1, 3)) +#'} +#' @rdname reduceByKeyLocally +#' @aliases reduceByKeyLocally,RDD,integer-method +setMethod("reduceByKeyLocally", + signature(x = "RDD", combineFunc = "ANY"), + function(x, combineFunc) { + reducePart <- function(part) { + vals <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, vals, pred, combineFunc, identity) + }) + list(list(keys, vals)) # return hash to avoid re-compute in merge + } + mergeParts <- function(accum, x) { + pred <- function(item) { + exists(item$hash, accum[[1]]) + } + lapply(ls(x[[1]]), + function(name) { + item <- list(x[[1]][[name]], x[[2]][[name]]) + item$hash <- name + updateOrCreatePair(item, accum[[1]], accum[[2]], pred, combineFunc, identity) + }) + accum + } + reduced <- mapPartitions(x, reducePart) + merged <- reduce(reduced, mergeParts) + convertEnvsToList(merged[[1]], merged[[2]]) + }) + +#' Combine values by key +#' +#' Generic function to combine the elements for each key using a custom set of +#' aggregation functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], +#' for a "combined type" C. Note that V and C can be different -- for example, one +#' might group an RDD of type (Int, Int) into an RDD of type (Int, Seq[Int]). + +#' Users provide three functions: +#' \itemize{ +#' \item createCombiner, which turns a V into a C (e.g., creates a one-element list) +#' \item mergeValue, to merge a V into a C (e.g., adds it to the end of a list) - +#' \item mergeCombiners, to combine two C's into a single one (e.g., concatentates +#' two lists). +#' } +#' +#' @param x The RDD to combine. 
Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param createCombiner Create a combiner (C) given a value (V) +#' @param mergeValue Merge the given value (V) with an existing combiner (C) +#' @param mergeCombiners Merge two combiners and return a new combiner +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, C) where C is the combined type +#' +#' @seealso groupByKey, reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- combineByKey(rdd, function(x) { x }, "+", "+", 2L) +#' combined <- collect(parts) +#' combined[[1]] # Should be a list(1, 6) +#'} +#' @rdname combineByKey +#' @aliases combineByKey,RDD,ANY,ANY,ANY,integer-method +setMethod("combineByKey", + signature(x = "RDD", createCombiner = "ANY", mergeValue = "ANY", + mergeCombiners = "ANY", numPartitions = "integer"), + function(x, createCombiner, mergeValue, mergeCombiners, numPartitions) { + combineLocally <- function(part) { + combiners <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(item[[1]]) + updateOrCreatePair(item, keys, combiners, pred, mergeValue, createCombiner) + }) + convertEnvsToList(keys, combiners) + } + locallyCombined <- lapplyPartition(x, combineLocally) + shuffled <- partitionBy(locallyCombined, numPartitions) + mergeAfterShuffle <- function(part) { + combiners <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(item[[1]]) + updateOrCreatePair(item, keys, combiners, pred, mergeCombiners, identity) + }) + convertEnvsToList(keys, combiners) + } + lapplyPartition(shuffled, mergeAfterShuffle) + }) + +#' Aggregate a pair RDD by each key. +#' +#' Aggregate the values of each key in an RDD, using given combine functions +#' and a neutral "zero value". This function can return a different result type, +#' U, than the type of the values in this RDD, V. Thus, we need one operation +#' for merging a V into a U and one operation for merging two U's, The former +#' operation is used for merging values within a partition, and the latter is +#' used for merging values between partitions. To avoid memory allocation, both +#' of these functions are allowed to modify and return their first argument +#' instead of creating a new U. +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param seqOp A function to aggregate the values of each key. It may return +#' a different result type from the type of the values. +#' @param combOp A function to aggregate results of seqOp. +#' @return An RDD containing the aggregation result. 
+#' @seealso foldByKey, combineByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) +#' zeroValue <- list(0, 0) +#' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } +#' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } +#' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) +#' # list(list(1, list(3, 2)), list(2, list(7, 2))) +#'} +#' @rdname aggregateByKey +#' @aliases aggregateByKey,RDD,ANY,ANY,ANY,integer-method +setMethod("aggregateByKey", + signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", + combOp = "ANY", numPartitions = "integer"), + function(x, zeroValue, seqOp, combOp, numPartitions) { + createCombiner <- function(v) { + do.call(seqOp, list(zeroValue, v)) + } + + combineByKey(x, createCombiner, seqOp, combOp, numPartitions) + }) + +#' Fold a pair RDD by each key. +#' +#' Aggregate the values of each key in an RDD, using an associative function "func" +#' and a neutral "zero value" which may be added to the result an arbitrary +#' number of times, and must not change the result (e.g., 0 for addition, or +#' 1 for multiplication.). +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param func An associative function for folding values of each key. +#' @return An RDD containing the aggregation result. +#' @seealso aggregateByKey, combineByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) +#' foldByKey(rdd, 0, "+", 2L) # list(list(1, 3), list(2, 7)) +#'} +#' @rdname foldByKey +#' @aliases foldByKey,RDD,ANY,ANY,integer-method +setMethod("foldByKey", + signature(x = "RDD", zeroValue = "ANY", + func = "ANY", numPartitions = "integer"), + function(x, zeroValue, func, numPartitions) { + aggregateByKey(x, zeroValue, func, func, numPartitions) + }) + +############ Binary Functions ############# + +#' Join two RDDs +#' +#' @description +#' \code{join} This function joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return a new RDD containing all pairs of elements with matching keys in +#' two input RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' join(rdd1, rdd2, 2L) # list(list(1, list(1, 2)), list(1, list(1, 3)) +#'} +#' @rdname join-methods +#' @aliases join,RDD,RDD-method +setMethod("join", + signature(x = "RDD", y = "RDD"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(FALSE, FALSE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numToInt(numPartitions)), + doJoin) + }) + +#' Left outer join two RDDs +#' +#' @description +#' \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. 
Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, v) in x, the resulting RDD will either contain +#' all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL)) +#' if no elements in rdd2 have key k. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' leftOuterJoin(rdd1, rdd2, 2L) +#' # list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) +#'} +#' @rdname join-methods +#' @aliases leftOuterJoin,RDD,RDD-method +setMethod("leftOuterJoin", + signature(x = "RDD", y = "RDD", numPartitions = "integer"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(FALSE, TRUE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) + }) + +#' Right outer join two RDDs +#' +#' @description +#' \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, w) in y, the resulting RDD will either contain +#' all pairs (k, (v, w)) for (k, v) in x, or the pair (k, (NULL, w)) +#' if no elements in x have key k. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rightOuterJoin(rdd1, rdd2, 2L) +#' # list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) +#'} +#' @rdname join-methods +#' @aliases rightOuterJoin,RDD,RDD-method +setMethod("rightOuterJoin", + signature(x = "RDD", y = "RDD", numPartitions = "integer"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(TRUE, FALSE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) + }) + +#' Full outer join two RDDs +#' +#' @description +#' \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, v) in x and (k, w) in y, the resulting RDD +#' will contain all pairs (k, (v, w)) for both (k, v) in x and +#' (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements +#' in x/y have key k. 
+#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) +#' rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' fullOuterJoin(rdd1, rdd2, 2L) # list(list(1, list(2, 1)), +#' # list(1, list(3, 1)), +#' # list(2, list(NULL, 4))) +#' # list(3, list(3, NULL)), +#'} +#' @rdname join-methods +#' @aliases fullOuterJoin,RDD,RDD-method +setMethod("fullOuterJoin", + signature(x = "RDD", y = "RDD", numPartitions = "integer"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(TRUE, TRUE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) + }) + +#' For each key k in several RDDs, return a resulting RDD that +#' whose values are a list of values for the key in all RDDs. +#' +#' @param ... Several RDDs. +#' @param numPartitions Number of partitions to create. +#' @return a new RDD containing all pairs of elements with values in a list +#' in all RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' cogroup(rdd1, rdd2, numPartitions = 2L) +#' # list(list(1, list(1, list(2, 3))), list(2, list(list(4), list())) +#'} +#' @rdname cogroup +#' @aliases cogroup,RDD-method +setMethod("cogroup", + "RDD", + function(..., numPartitions) { + rdds <- list(...) + rddsLen <- length(rdds) + for (i in 1:rddsLen) { + rdds[[i]] <- lapply(rdds[[i]], + function(x) { list(x[[1]], list(i, x[[2]])) }) + } + union.rdd <- Reduce(unionRDD, rdds) + group.func <- function(vlist) { + res <- list() + length(res) <- rddsLen + for (x in vlist) { + i <- x[[1]] + acc <- res[[i]] + # Create an accumulator. + if (is.null(acc)) { + acc <- initAccumulator() + } + addItemToAccumulator(acc, x[[2]]) + res[[i]] <- acc + } + lapply(res, function(acc) { + if (is.null(acc)) { + list() + } else { + acc$data + } + }) + } + cogroup.rdd <- mapValues(groupByKey(union.rdd, numPartitions), + group.func) + }) + +#' Sort a (k, v) pair RDD by k. +#' +#' @param x A (k, v) pair RDD to be sorted. +#' @param ascending A flag to indicate whether the sorting is ascending or descending. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where all (k, v) pair elements are sorted. 
+#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(3, 1), list(2, 2), list(1, 3))) +#' collect(sortByKey(rdd)) # list (list(1, 3), list(2, 2), list(3, 1)) +#'} +#' @rdname sortByKey +#' @aliases sortByKey,RDD,RDD-method +setMethod("sortByKey", + signature(x = "RDD"), + function(x, ascending = TRUE, numPartitions = SparkR::numPartitions(x)) { + rangeBounds <- list() + + if (numPartitions > 1) { + rddSize <- count(x) + # constant from Spark's RangePartitioner + maxSampleSize <- numPartitions * 20 + fraction <- min(maxSampleSize / max(rddSize, 1), 1.0) + + samples <- collect(keys(sampleRDD(x, FALSE, fraction, 1L))) + + # Note: the built-in R sort() function only works on atomic vectors + samples <- sort(unlist(samples, recursive = FALSE), decreasing = !ascending) + + if (length(samples) > 0) { + rangeBounds <- lapply(seq_len(numPartitions - 1), + function(i) { + j <- ceiling(length(samples) * i / numPartitions) + samples[j] + }) + } + } + + rangePartitionFunc <- function(key) { + partition <- 0 + + # TODO: Use binary search instead of linear search, similar with Spark + while (partition < length(rangeBounds) && key > rangeBounds[[partition + 1]]) { + partition <- partition + 1 + } + + if (ascending) { + partition + } else { + numPartitions - partition - 1 + } + } + + partitionFunc <- function(part) { + sortKeyValueList(part, decreasing = !ascending) + } + + newRDD <- partitionBy(x, numPartitions, rangePartitionFunc) + lapplyPartition(newRDD, partitionFunc) + }) + diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R new file mode 100644 index 0000000000000..8a9c0c652ce24 --- /dev/null +++ b/R/pkg/R/serialize.R @@ -0,0 +1,195 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility functions to serialize R objects so they can be read in Java. + +# Type mapping from R to Java +# +# NULL -> Void +# integer -> Int +# character -> String +# logical -> Boolean +# double, numeric -> Double +# raw -> Array[Byte] +# Date -> Date +# POSIXct,POSIXlt -> Time +# +# list[T] -> Array[T], where T is one of above mentioned types +# environment -> Map[String, T], where T is a native type +# jobj -> Object, where jobj is an object created in the backend + +writeObject <- function(con, object, writeType = TRUE) { + # NOTE: In R vectors have same type as objects. So we don't support + # passing in vectors as arrays and instead require arrays to be passed + # as lists. 
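+ # Illustrative example of the distinction above (hypothetical caller code):
+ #   writeObject(con, list(1L, 2L, 3L))  # serialized as an Array[Int]
+ #   writeObject(con, c(1L, 2L, 3L))     # class is "integer", so it is tagged
+ #                                       # as a single Int, not as an array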
+ type <- class(object)[[1]] # class of POSIXlt is c("POSIXlt", "POSIXt") + if (writeType) { + writeType(con, type) + } + switch(type, + NULL = writeVoid(con), + integer = writeInt(con, object), + character = writeString(con, object), + logical = writeBoolean(con, object), + double = writeDouble(con, object), + numeric = writeDouble(con, object), + raw = writeRaw(con, object), + list = writeList(con, object), + jobj = writeJobj(con, object), + environment = writeEnv(con, object), + Date = writeDate(con, object), + POSIXlt = writeTime(con, object), + POSIXct = writeTime(con, object), + stop(paste("Unsupported type for serialization", type))) +} + +writeVoid <- function(con) { + # no value for NULL +} + +writeJobj <- function(con, value) { + if (!isValidJobj(value)) { + stop("invalid jobj ", value$id) + } + writeString(con, value$id) +} + +writeString <- function(con, value) { + writeInt(con, as.integer(nchar(value) + 1)) + writeBin(value, con, endian = "big") +} + +writeInt <- function(con, value) { + writeBin(as.integer(value), con, endian = "big") +} + +writeDouble <- function(con, value) { + writeBin(value, con, endian = "big") +} + +writeBoolean <- function(con, value) { + # TRUE becomes 1, FALSE becomes 0 + writeInt(con, as.integer(value)) +} + +writeRawSerialize <- function(outputCon, batch) { + outputSer <- serialize(batch, ascii = FALSE, connection = NULL) + writeRaw(outputCon, outputSer) +} + +writeRowSerialize <- function(outputCon, rows) { + invisible(lapply(rows, function(r) { + bytes <- serializeRow(r) + writeRaw(outputCon, bytes) + })) +} + +serializeRow <- function(row) { + rawObj <- rawConnection(raw(0), "wb") + on.exit(close(rawObj)) + writeRow(rawObj, row) + rawConnectionValue(rawObj) +} + +writeRow <- function(con, row) { + numCols <- length(row) + writeInt(con, numCols) + for (i in 1:numCols) { + writeObject(con, row[[i]]) + } +} + +writeRaw <- function(con, batch) { + writeInt(con, length(batch)) + writeBin(batch, con, endian = "big") +} + +writeType <- function(con, class) { + type <- switch(class, + NULL = "n", + integer = "i", + character = "c", + logical = "b", + double = "d", + numeric = "d", + raw = "r", + list = "l", + jobj = "j", + environment = "e", + Date = "D", + POSIXlt = 't', + POSIXct = 't', + stop(paste("Unsupported type for serialization", class))) + writeBin(charToRaw(type), con) +} + +# Used to pass arrays where all the elements are of the same type +writeList <- function(con, arr) { + # All elements should be of same type + elemType <- unique(sapply(arr, function(elem) { class(elem) })) + stopifnot(length(elemType) <= 1) + + # TODO: Empty lists are given type "character" right now. + # This may not work if the Java side expects array of any other type. + if (length(elemType) == 0) { + elemType <- class("somestring") + } + + writeType(con, elemType) + writeInt(con, length(arr)) + + if (length(arr) > 0) { + for (a in arr) { + writeObject(con, a, FALSE) + } + } +} + +# Used to pass in hash maps required on Java side. +writeEnv <- function(con, env) { + len <- length(env) + + writeInt(con, len) + if (len > 0) { + writeList(con, as.list(ls(env))) + vals <- lapply(ls(env), function(x) { env[[x]] }) + writeList(con, as.list(vals)) + } +} + +writeDate <- function(con, date) { + writeString(con, as.character(date)) +} + +writeTime <- function(con, time) { + writeDouble(con, as.double(time)) +} + +# Used to serialize in a list of objects where each +# object can be of a different type. 
Serialization format is +# for each object +writeArgs <- function(con, args) { + if (length(args) > 0) { + for (a in args) { + writeObject(con, a) + } + } +} + +writeStrings <- function(con, stringList) { + writeLines(unlist(stringList), con) +} diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R new file mode 100644 index 0000000000000..bc82df01f0fff --- /dev/null +++ b/R/pkg/R/sparkR.R @@ -0,0 +1,266 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.sparkREnv <- new.env() + +sparkR.onLoad <- function(libname, pkgname) { + .sparkREnv$libname <- libname +} + +# Utility function that returns TRUE if we have an active connection to the +# backend and FALSE otherwise +connExists <- function(env) { + tryCatch({ + exists(".sparkRCon", envir = env) && isOpen(env[[".sparkRCon"]]) + }, error = function(err) { + return(FALSE) + }) +} + +#' Stop the Spark context. +#' +#' Also terminates the backend this R session is connected to +sparkR.stop <- function() { + env <- .sparkREnv + if (exists(".sparkRCon", envir = env)) { + # cat("Stopping SparkR\n") + if (exists(".sparkRjsc", envir = env)) { + sc <- get(".sparkRjsc", envir = env) + callJMethod(sc, "stop") + rm(".sparkRjsc", envir = env) + } + + if (exists(".backendLaunched", envir = env)) { + callJStatic("SparkRHandler", "stopBackend") + } + + # Also close the connection and remove it from our env + conn <- get(".sparkRCon", envir = env) + close(conn) + + rm(".sparkRCon", envir = env) + rm(".scStartTime", envir = env) + } + + if (exists(".monitorConn", envir = env)) { + conn <- get(".monitorConn", envir = env) + close(conn) + rm(".monitorConn", envir = env) + } + + # Clear all broadcast variables we have + # as the jobj will not be valid if we restart the JVM + clearBroadcastVariables() + + # Clear jobj maps + clearJobjs() +} + +#' Initialize a new Spark Context. +#' +#' This function initializes a new SparkContext. +#' +#' @param master The Spark master URL. +#' @param appName Application name to register with cluster manager +#' @param sparkHome Spark Home directory +#' @param sparkEnvir Named list of environment variables to set on worker nodes. +#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors. +#' @param sparkJars Character string vector of jar files to pass to the worker nodes. +#' @param sparkRLibDir The path where R is installed on the worker nodes. 
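+#' @return A jobj referencing the created JavaSparkContext.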
+#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init("local[2]", "SparkR", "/home/spark") +#' sc <- sparkR.init("local[2]", "SparkR", "/home/spark", +#' list(spark.executor.memory="1g")) +#' sc <- sparkR.init("yarn-client", "SparkR", "/home/spark", +#' list(spark.executor.memory="1g"), +#' list(LD_LIBRARY_PATH="/directory of JVM libraries (libjvm.so) on workers/"), +#' c("jarfile1.jar","jarfile2.jar")) +#'} + +sparkR.init <- function( + master = "", + appName = "SparkR", + sparkHome = Sys.getenv("SPARK_HOME"), + sparkEnvir = list(), + sparkExecutorEnv = list(), + sparkJars = "", + sparkRLibDir = "") { + + if (exists(".sparkRjsc", envir = .sparkREnv)) { + cat("Re-using existing Spark Context. Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n") + return(get(".sparkRjsc", envir = .sparkREnv)) + } + + sparkMem <- Sys.getenv("SPARK_MEM", "512m") + jars <- suppressWarnings(normalizePath(as.character(sparkJars))) + + # Classpath separator is ";" on Windows + # URI needs four /// as from http://stackoverflow.com/a/18522792 + if (.Platform$OS.type == "unix") { + collapseChar <- ":" + uriSep <- "//" + } else { + collapseChar <- ";" + uriSep <- "////" + } + + existingPort <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "") + if (existingPort != "") { + backendPort <- existingPort + } else { + path <- tempfile(pattern = "backend_port") + launchBackend( + args = path, + sparkHome = sparkHome, + jars = jars, + sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell")) + # wait atmost 100 seconds for JVM to launch + wait <- 0.1 + for (i in 1:25) { + Sys.sleep(wait) + if (file.exists(path)) { + break + } + wait <- wait * 1.25 + } + if (!file.exists(path)) { + stop("JVM is not ready after 10 seconds") + } + f <- file(path, open='rb') + backendPort <- readInt(f) + monitorPort <- readInt(f) + close(f) + file.remove(path) + if (length(backendPort) == 0 || backendPort == 0 || + length(monitorPort) == 0 || monitorPort == 0) { + stop("JVM failed to launch") + } + assign(".monitorConn", socketConnection(port = monitorPort), envir = .sparkREnv) + assign(".backendLaunched", 1, envir = .sparkREnv) + } + + .sparkREnv$backendPort <- backendPort + tryCatch({ + connectBackend("localhost", backendPort) + }, error = function(err) { + stop("Failed to connect JVM\n") + }) + + if (nchar(sparkHome) != 0) { + sparkHome <- normalizePath(sparkHome) + } + + if (nchar(sparkRLibDir) != 0) { + .sparkREnv$libname <- sparkRLibDir + } + + sparkEnvirMap <- new.env() + for (varname in names(sparkEnvir)) { + sparkEnvirMap[[varname]] <- sparkEnvir[[varname]] + } + + sparkExecutorEnvMap <- new.env() + if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) { + sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) + } + for (varname in names(sparkExecutorEnv)) { + sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]] + } + + nonEmptyJars <- Filter(function(x) { x != "" }, jars) + localJarPaths <- sapply(nonEmptyJars, function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) + + # Set the start time to identify jobjs + # Seconds resolution is good enough for this purpose, so use ints + assign(".scStartTime", as.integer(Sys.time()), envir = .sparkREnv) + + assign( + ".sparkRjsc", + callJStatic( + "org.apache.spark.api.r.RRDD", + "createSparkContext", + master, + appName, + as.character(sparkHome), + as.list(localJarPaths), + sparkEnvirMap, + sparkExecutorEnvMap), + envir = .sparkREnv + ) + + sc <- get(".sparkRjsc", envir = 
.sparkREnv) + + # Register a finalizer to sleep 1 seconds on R exit to make RStudio happy + reg.finalizer(.sparkREnv, function(x) { Sys.sleep(1) }, onexit = TRUE) + + sc +} + +#' Initialize a new SQLContext. +#' +#' This function creates a SparkContext from an existing JavaSparkContext and +#' then uses it to initialize a new SQLContext +#' +#' @param jsc The existing JavaSparkContext created with SparkR.init() +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#'} + +sparkRSQL.init <- function(jsc) { + if (exists(".sparkRSQLsc", envir = .sparkREnv)) { + return(get(".sparkRSQLsc", envir = .sparkREnv)) + } + + sqlCtx <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "createSQLContext", + jsc) + assign(".sparkRSQLsc", sqlCtx, envir = .sparkREnv) + sqlCtx +} + +#' Initialize a new HiveContext. +#' +#' This function creates a HiveContext from an existing JavaSparkContext +#' +#' @param jsc The existing JavaSparkContext created with SparkR.init() +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRHive.init(sc) +#'} + +sparkRHive.init <- function(jsc) { + if (exists(".sparkRHivesc", envir = .sparkREnv)) { + return(get(".sparkRHivesc", envir = .sparkREnv)) + } + + ssc <- callJMethod(jsc, "sc") + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.HiveContext", ssc) + }, error = function(err) { + stop("Spark SQL is not built with Hive support") + }) + + assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv) + hiveCtx +} diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R new file mode 100644 index 0000000000000..c337fb0751e72 --- /dev/null +++ b/R/pkg/R/utils.R @@ -0,0 +1,467 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utilities and Helpers + +# Given a JList, returns an R list containing the same elements, the number +# of which is optionally upper bounded by `logicalUpperBound` (by default, +# return all elements). Takes care of deserializations and type conversions. +convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL, + serializedMode = "byte") { + arrSize <- callJMethod(jList, "size") + + # Datasets with serializedMode == "string" (such as an RDD directly generated by textFile()): + # each partition is not dense-packed into one Array[Byte], and `arrSize` + # here corresponds to number of logical elements. Thus we can prune here. + if (serializedMode == "string" && !is.null(logicalUpperBound)) { + arrSize <- min(arrSize, logicalUpperBound) + } + + results <- if (arrSize > 0) { + lapply(0:(arrSize - 1), + function(index) { + obj <- callJMethod(jList, "get", as.integer(index)) + + # Assume it is either an R object or a Java obj ref. 
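+ # Three cases are handled below: a "jobj" reference (expected to be a
+ # scala.Tuple2 of serialized key/value byte arrays, as in a
+ # JavaPairRDD[Array[Byte], Array[Byte]]), a raw vector (a whole serialized
+ # partition, or serialized rows for RRDDs converted from DataFrames), and
+ # anything else (a primitive already converted to its R type by the backend).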
+ if (inherits(obj, "jobj")) { + if (isInstanceOf(obj, "scala.Tuple2")) { + # JavaPairRDD[Array[Byte], Array[Byte]]. + + keyBytes = callJMethod(obj, "_1") + valBytes = callJMethod(obj, "_2") + res <- list(unserialize(keyBytes), + unserialize(valBytes)) + } else { + stop(paste("utils.R: convertJListToRList only supports", + "RDD[Array[Byte]] and", + "JavaPairRDD[Array[Byte], Array[Byte]] for now")) + } + } else { + if (inherits(obj, "raw")) { + if (serializedMode == "byte") { + # RDD[Array[Byte]]. `obj` is a whole partition. + res <- unserialize(obj) + # For serialized datasets, `obj` (and `rRaw`) here corresponds to + # one whole partition dense-packed together. We deserialize the + # whole partition first, then cap the number of elements to be returned. + } else if (serializedMode == "row") { + res <- readRowList(obj) + # For DataFrames that have been converted to RRDDs, we call readRowList + # which will read in each row of the RRDD as a list and deserialize + # each element. + flatten <<- FALSE + # Use global assignment to change the flatten flag. This means + # we don't have to worry about the default argument in other functions + # e.g. collect + } + # TODO: is it possible to distinguish element boundary so that we can + # unserialize only what we need? + if (!is.null(logicalUpperBound)) { + res <- head(res, n = logicalUpperBound) + } + } else { + # obj is of a primitive Java type, is simplified to R's + # corresponding type. + res <- list(obj) + } + } + res + }) + } else { + list() + } + + if (flatten) { + as.list(unlist(results, recursive = FALSE)) + } else { + as.list(results) + } +} + +# Returns TRUE if `name` refers to an RDD in the given environment `env` +isRDD <- function(name, env) { + obj <- get(name, envir = env) + inherits(obj, "RDD") +} + +#' Compute the hashCode of an object +#' +#' Java-style function to compute the hashCode for the given object. Returns +#' an integer value. +#' +#' @details +#' This only works for integer, numeric and character types right now. +#' +#' @param key the object to be hashed +#' @return the hash code as an integer +#' @export +#' @examples +#' hashCode(1L) # 1 +#' hashCode(1.0) # 1072693248 +#' hashCode("1") # 49 +hashCode <- function(key) { + if (class(key) == "integer") { + as.integer(key[[1]]) + } else if (class(key) == "numeric") { + # Convert the double to long and then calculate the hash code + rawVec <- writeBin(key[[1]], con = raw()) + intBits <- packBits(rawToBits(rawVec), "integer") + as.integer(bitwXor(intBits[2], intBits[1])) + } else if (class(key) == "character") { + .Call("stringHashCode", key) + } else { + warning(paste("Could not hash object, returning 0", sep = "")) + as.integer(0) + } +} + +# Create a new RDD with serializedMode == "byte". +# Return itself if already in "byte" format. +serializeToBytes <- function(rdd) { + if (!inherits(rdd, "RDD")) { + stop("Argument 'rdd' is not an RDD type.") + } + if (getSerializedMode(rdd) != "byte") { + ser.rdd <- lapply(rdd, function(x) { x }) + return(ser.rdd) + } else { + return(rdd) + } +} + +# Create a new RDD with serializedMode == "string". +# Return itself if already in "string" format. 
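+#
+# A minimal usage sketch (assumes an existing SparkContext `sc`; not run):
+#   rdd <- parallelize(sc, list(1, 2, 3))   # serializedMode is "byte"
+#   strRdd <- serializeToString(rdd)        # elements become "1", "2", "3"
+#   getSerializedMode(strRdd)               # should now be "string"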
+serializeToString <- function(rdd) { + if (!inherits(rdd, "RDD")) { + stop("Argument 'rdd' is not an RDD type.") + } + if (getSerializedMode(rdd) != "string") { + ser.rdd <- lapply(rdd, function(x) { toString(x) }) + # force it to create jrdd using "string" + getJRDD(ser.rdd, serializedMode = "string") + return(ser.rdd) + } else { + return(rdd) + } +} + +# Fast append to list by using an accumulator. +# http://stackoverflow.com/questions/17046336/here-we-go-again-append-an-element-to-a-list-in-r +# +# The accumulator should has three fields size, counter and data. +# This function amortizes the allocation cost by doubling +# the size of the list every time it fills up. +addItemToAccumulator <- function(acc, item) { + if(acc$counter == acc$size) { + acc$size <- acc$size * 2 + length(acc$data) <- acc$size + } + acc$counter <- acc$counter + 1 + acc$data[[acc$counter]] <- item +} + +initAccumulator <- function() { + acc <- new.env() + acc$counter <- 0 + acc$data <- list(NULL) + acc$size <- 1 + acc +} + +# Utility function to sort a list of key value pairs +# Used in unit tests +sortKeyValueList <- function(kv_list, decreasing = FALSE) { + keys <- sapply(kv_list, function(x) x[[1]]) + kv_list[order(keys, decreasing = decreasing)] +} + +# Utility function to generate compact R lists from grouped rdd +# Used in Join-family functions +# param: +# tagged_list R list generated via groupByKey with tags(1L, 2L, ...) +# cnull Boolean list where each element determines whether the corresponding list should +# be converted to list(NULL) +genCompactLists <- function(tagged_list, cnull) { + len <- length(tagged_list) + lists <- list(vector("list", len), vector("list", len)) + index <- list(1, 1) + + for (x in tagged_list) { + tag <- x[[1]] + idx <- index[[tag]] + lists[[tag]][[idx]] <- x[[2]] + index[[tag]] <- idx + 1 + } + + len <- lapply(index, function(x) x - 1) + for (i in (1:2)) { + if (cnull[[i]] && len[[i]] == 0) { + lists[[i]] <- list(NULL) + } else { + length(lists[[i]]) <- len[[i]] + } + } + + lists +} + +# Utility function to merge compact R lists +# Used in Join-family functions +# param: +# left/right Two compact lists ready for Cartesian product +mergeCompactLists <- function(left, right) { + result <- list() + length(result) <- length(left) * length(right) + index <- 1 + for (i in left) { + for (j in right) { + result[[index]] <- list(i, j) + index <- index + 1 + } + } + result +} + +# Utility function to wrapper above two operations +# Used in Join-family functions +# param (same as genCompactLists): +# tagged_list R list generated via groupByKey with tags(1L, 2L, ...) 
+# cnull Boolean list where each element determines whether the corresponding list should +# be converted to list(NULL) +joinTaggedList <- function(tagged_list, cnull) { + lists <- genCompactLists(tagged_list, cnull) + mergeCompactLists(lists[[1]], lists[[2]]) +} + +# Utility function to reduce a key-value list with predicate +# Used in *ByKey functions +# param +# pair key-value pair +# keys/vals env of key/value with hashes +# updateOrCreatePred predicate function +# updateFn update or merge function for existing pair, similar with `mergeVal` @combineByKey +# createFn create function for new pair, similar with `createCombiner` @combinebykey +updateOrCreatePair <- function(pair, keys, vals, updateOrCreatePred, updateFn, createFn) { + # assume hashVal bind to `$hash`, key/val with index 1/2 + hashVal <- pair$hash + key <- pair[[1]] + val <- pair[[2]] + if (updateOrCreatePred(pair)) { + assign(hashVal, do.call(updateFn, list(get(hashVal, envir = vals), val)), envir = vals) + } else { + assign(hashVal, do.call(createFn, list(val)), envir = vals) + assign(hashVal, key, envir = keys) + } +} + +# Utility function to convert key&values envs into key-val list +convertEnvsToList <- function(keys, vals) { + lapply(ls(keys), + function(name) { + list(keys[[name]], vals[[name]]) + }) +} + +# Utility function to capture the varargs into environment object +varargsToEnv <- function(...) { + pairs <- as.list(substitute(list(...)))[-1L] + env <- new.env() + for (name in names(pairs)) { + env[[name]] <- pairs[[name]] + } + env +} + +getStorageLevel <- function(newLevel = c("DISK_ONLY", + "DISK_ONLY_2", + "MEMORY_AND_DISK", + "MEMORY_AND_DISK_2", + "MEMORY_AND_DISK_SER", + "MEMORY_AND_DISK_SER_2", + "MEMORY_ONLY", + "MEMORY_ONLY_2", + "MEMORY_ONLY_SER", + "MEMORY_ONLY_SER_2", + "OFF_HEAP")) { + match.arg(newLevel) + storageLevel <- switch(newLevel, + "DISK_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY"), + "DISK_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY_2"), + "MEMORY_AND_DISK" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK"), + "MEMORY_AND_DISK_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_2"), + "MEMORY_AND_DISK_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER"), + "MEMORY_AND_DISK_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER_2"), + "MEMORY_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY"), + "MEMORY_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_2"), + "MEMORY_ONLY_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER"), + "MEMORY_ONLY_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER_2"), + "OFF_HEAP" = callJStatic("org.apache.spark.storage.StorageLevel", "OFF_HEAP")) +} + +# Utility function for functions where an argument needs to be integer but we want to allow +# the user to type (for example) `5` instead of `5L` to avoid a confusing error message. +numToInt <- function(num) { + if (as.integer(num) != num) { + warning(paste("Coercing", as.list(sys.call())[[2]], "to integer.")) + } + as.integer(num) +} + +# create a Seq in JVM +toSeq <- function(...) 
{ + callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", list(...)) +} + +# create a Seq in JVM from a list +listToSeq <- function(l) { + callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", l) +} + +# Utility function to recursively traverse the Abstract Syntax Tree (AST) of a +# user defined function (UDF), and to examine variables in the UDF to decide +# if their values should be included in the new function environment. +# param +# node The current AST node in the traversal. +# oldEnv The original function environment. +# defVars An Accumulator of variables names defined in the function's calling environment, +# including function argument and local variable names. +# checkedFunc An environment of function objects examined during cleanClosure. It can +# be considered as a "name"-to-"list of functions" mapping. +# newEnv A new function environment to store necessary function dependencies, an output argument. +processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { + nodeLen <- length(node) + + if (nodeLen > 1 && typeof(node) == "language") { + # Recursive case: current AST node is an internal node, check for its children. + if (length(node[[1]]) > 1) { + for (i in 1:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else { # if node[[1]] is length of 1, check for some R special functions. + nodeChar <- as.character(node[[1]]) + if (nodeChar == "{" || nodeChar == "(") { # Skip start symbol. + for (i in 2:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else if (nodeChar == "<-" || nodeChar == "=" || + nodeChar == "<<-") { # Assignment Ops. + defVar <- node[[2]] + if (length(defVar) == 1 && typeof(defVar) == "symbol") { + # Add the defined variable name into defVars. + addItemToAccumulator(defVars, as.character(defVar)) + } else { + processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv) + } + for (i in 3:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else if (nodeChar == "function") { # Function definition. + # Add parameter names. + newArgs <- names(node[[2]]) + lapply(newArgs, function(arg) { addItemToAccumulator(defVars, arg) }) + for (i in 3:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else if (nodeChar == "$") { # Skip the field. + processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv) + } else if (nodeChar == "::" || nodeChar == ":::") { + processClosure(node[[3]], oldEnv, defVars, checkedFuncs, newEnv) + } else { + for (i in 1:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } + } + } else if (nodeLen == 1 && + (typeof(node) == "symbol" || typeof(node) == "language")) { + # Base case: current AST node is a leaf node and a symbol or a function call. + nodeChar <- as.character(node) + if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable. + func.env <- oldEnv + topEnv <- parent.env(.GlobalEnv) + # Search in function environment, and function's enclosing environments + # up to global environment. There is no need to look into package environments + # above the global or namespace environment that is not SparkR below the global, + # as they are assumed to be loaded on workers. + while (!identical(func.env, topEnv)) { + # Namespaces other than "SparkR" will not be searched. 
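+ # That is, a symbol is captured from func.env only if func.env is not a
+ # package namespace, or if it is the SparkR namespace and the symbol is a
+ # non-exported SparkR internal (exported SparkR functions are assumed to
+ # be available on the workers).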
+ if (!isNamespace(func.env) || + (getNamespaceName(func.env) == "SparkR" && + !(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals. + # Set parameter 'inherits' to FALSE since we do not need to search in + # attached package environments. + if (tryCatch(exists(nodeChar, envir = func.env, inherits = FALSE), + error = function(e) { FALSE })) { + obj <- get(nodeChar, envir = func.env, inherits = FALSE) + if (is.function(obj)) { # If the node is a function call. + funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F, + ifnotfound = list(list(NULL)))[[1]] + found <- sapply(funcList, function(func) { + ifelse(identical(func, obj), TRUE, FALSE) + }) + if (sum(found) > 0) { # If function has been examined, ignore. + break + } + # Function has not been examined, record it and recursively clean its closure. + assign(nodeChar, + if (is.null(funcList[[1]])) { + list(obj) + } else { + append(funcList, obj) + }, + envir = checkedFuncs) + obj <- cleanClosure(obj, checkedFuncs) + } + assign(nodeChar, obj, envir = newEnv) + break + } + } + + # Continue to search in enclosure. + func.env <- parent.env(func.env) + } + } + } +} + +# Utility function to get user defined function (UDF) dependencies (closure). +# More specifically, this function captures the values of free variables defined +# outside a UDF, and stores them in the function's environment. +# param +# func A function whose closure needs to be captured. +# checkedFunc An environment of function objects examined during cleanClosure. It can be +# considered as a "name"-to-"list of functions" mapping. +# return value +# a new version of func that has an correct environment (closure). +cleanClosure <- function(func, checkedFuncs = new.env()) { + if (is.function(func)) { + newEnv <- new.env(parent = .GlobalEnv) + func.body <- body(func) + oldEnv <- environment(func) + # defVars is an Accumulator of variables names defined in the function's calling + # environment. First, function's arguments are added to defVars. + defVars <- initAccumulator() + argNames <- names(as.list(args(func))) + for (i in 1:(length(argNames) - 1)) { # Remove the ending NULL in pairlist. + addItemToAccumulator(defVars, argNames[i]) + } + # Recursively examine variables in the function body. + processClosure(func.body, oldEnv, defVars, checkedFuncs, newEnv) + environment(func) <- newEnv + } + func +} diff --git a/R/pkg/R/zzz.R b/R/pkg/R/zzz.R new file mode 100644 index 0000000000000..80d796d467943 --- /dev/null +++ b/R/pkg/R/zzz.R @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +.onLoad <- function(libname, pkgname) { + sparkR.onLoad(libname, pkgname) +} + diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R new file mode 100644 index 0000000000000..8fe711b622086 --- /dev/null +++ b/R/pkg/inst/profile/general.R @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.First <- function() { + home <- Sys.getenv("SPARK_HOME") + .libPaths(c(file.path(home, "R", "lib"), .libPaths())) + Sys.setenv(NOAWT=1) +} diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R new file mode 100644 index 0000000000000..7a7f2031152a0 --- /dev/null +++ b/R/pkg/inst/profile/shell.R @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.First <- function() { + home <- Sys.getenv("SPARK_HOME") + .libPaths(c(file.path(home, "R", "lib"), .libPaths())) + Sys.setenv(NOAWT=1) + + library(utils) + library(SparkR) + sc <- sparkR.init(Sys.getenv("MASTER", unset = "")) + assign("sc", sc, envir=.GlobalEnv) + sqlCtx <- sparkRSQL.init(sc) + assign("sqlCtx", sqlCtx, envir=.GlobalEnv) + cat("\n Welcome to SparkR!") + cat("\n Spark context is available as sc, SQL context is available as sqlCtx\n") +} diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R new file mode 100644 index 0000000000000..ca4218f3819f8 --- /dev/null +++ b/R/pkg/inst/tests/test_binaryFile.R @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions on binary files") + +# JavaSparkContext handle +sc <- sparkR.init() + +mockFile = c("Spark is pretty.", "Spark is awesome.") + +test_that("saveAsObjectFile()/objectFile() following textFile() works", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1, 1) + saveAsObjectFile(rdd, fileName2) + rdd <- objectFile(sc, fileName2) + expect_equal(collect(rdd), as.list(mockFile)) + + unlink(fileName1) + unlink(fileName2, recursive = TRUE) +}) + +test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + + l <- list(1, 2, 3) + rdd <- parallelize(sc, l, 1) + saveAsObjectFile(rdd, fileName) + rdd <- objectFile(sc, fileName) + expect_equal(collect(rdd), l) + + unlink(fileName, recursive = TRUE) +}) + +test_that("saveAsObjectFile()/objectFile() following RDD transformations works", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1) + + words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + counts <- reduceByKey(wordCount, "+", 2L) + + saveAsObjectFile(counts, fileName2) + counts <- objectFile(sc, fileName2) + + output <- collect(counts) + expected <- list(list("awesome.", 1), list("Spark", 2), list("pretty.", 1), + list("is", 2)) + expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) + + unlink(fileName1) + unlink(fileName2, recursive = TRUE) +}) + +test_that("saveAsObjectFile()/objectFile() works with multiple paths", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + + rdd1 <- parallelize(sc, "Spark is pretty.") + saveAsObjectFile(rdd1, fileName1) + rdd2 <- parallelize(sc, "Spark is awesome.") + saveAsObjectFile(rdd2, fileName2) + + rdd <- objectFile(sc, c(fileName1, fileName2)) + expect_true(count(rdd) == 2) + + unlink(fileName1, recursive = TRUE) + unlink(fileName2, recursive = TRUE) +}) + diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R new file mode 100644 index 0000000000000..c15553ba28517 --- /dev/null +++ b/R/pkg/inst/tests/test_binary_function.R @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("binary functions") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Data +nums <- 1:10 +rdd <- parallelize(sc, nums, 2L) + +# File content +mockFile <- c("Spark is pretty.", "Spark is awesome.") + +test_that("union on two RDDs", { + actual <- collect(unionRDD(rdd, rdd)) + expect_equal(actual, as.list(rep(nums, 2))) + + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + text.rdd <- textFile(sc, fileName) + union.rdd <- unionRDD(rdd, text.rdd) + actual <- collect(union.rdd) + expect_equal(actual, c(as.list(nums), mockFile)) + expect_true(getSerializedMode(union.rdd) == "byte") + + rdd<- map(text.rdd, function(x) {x}) + union.rdd <- unionRDD(rdd, text.rdd) + actual <- collect(union.rdd) + expect_equal(actual, as.list(c(mockFile, mockFile))) + expect_true(getSerializedMode(union.rdd) == "byte") + + unlink(fileName) +}) + +test_that("cogroup on two RDDs", { + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) + rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) + cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) + actual <- collect(cogroup.rdd) + expect_equal(actual, + list(list(1, list(list(1), list(2, 3))), list(2, list(list(4), list())))) + + rdd1 <- parallelize(sc, list(list("a", 1), list("a", 4))) + rdd2 <- parallelize(sc, list(list("b", 2), list("a", 3))) + cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) + actual <- collect(cogroup.rdd) + + expected <- list(list("b", list(list(), list(2))), list("a", list(list(1, 4), list(3)))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) +}) diff --git a/R/pkg/inst/tests/test_broadcast.R b/R/pkg/inst/tests/test_broadcast.R new file mode 100644 index 0000000000000..fee91a427d6d5 --- /dev/null +++ b/R/pkg/inst/tests/test_broadcast.R @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +context("broadcast variables") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Partitioned data +nums <- 1:2 +rrdd <- parallelize(sc, nums, 2L) + +test_that("using broadcast variable", { + randomMat <- matrix(nrow=10, ncol=10, data=rnorm(100)) + randomMatBr <- broadcast(sc, randomMat) + + useBroadcast <- function(x) { + sum(value(randomMatBr) * x) + } + actual <- collect(lapply(rrdd, useBroadcast)) + expected <- list(sum(randomMat) * 1, sum(randomMat) * 2) + expect_equal(actual, expected) +}) + +test_that("without using broadcast variable", { + randomMat <- matrix(nrow=10, ncol=10, data=rnorm(100)) + + useBroadcast <- function(x) { + sum(randomMat * x) + } + actual <- collect(lapply(rrdd, useBroadcast)) + expected <- list(sum(randomMat) * 1, sum(randomMat) * 2) + expect_equal(actual, expected) +}) diff --git a/R/pkg/inst/tests/test_context.R b/R/pkg/inst/tests/test_context.R new file mode 100644 index 0000000000000..e4aab37436a74 --- /dev/null +++ b/R/pkg/inst/tests/test_context.R @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("test functions in sparkR.R") + +test_that("repeatedly starting and stopping SparkR", { + for (i in 1:4) { + sc <- sparkR.init() + rdd <- parallelize(sc, 1:20, 2L) + expect_equal(count(rdd), 20) + sparkR.stop() + } +}) + +test_that("rdd GC across sparkR.stop", { + sparkR.stop() + sc <- sparkR.init() # sc should get id 0 + rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 + rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 + sparkR.stop() + + sc <- sparkR.init() # sc should get id 0 again + + # GC rdd1 before creating rdd3 and rdd2 after + rm(rdd1) + gc() + + rdd3 <- parallelize(sc, 1:20, 2L) # rdd3 should get id 1 now + rdd4 <- parallelize(sc, 1:10, 2L) # rdd4 should get id 2 now + + rm(rdd2) + gc() + + count(rdd3) + count(rdd4) +}) diff --git a/R/pkg/inst/tests/test_includePackage.R b/R/pkg/inst/tests/test_includePackage.R new file mode 100644 index 0000000000000..8152b448d0870 --- /dev/null +++ b/R/pkg/inst/tests/test_includePackage.R @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("include R packages") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Partitioned data +nums <- 1:2 +rdd <- parallelize(sc, nums, 2L) + +test_that("include inside function", { + # Only run the test if plyr is installed. + if ("plyr" %in% rownames(installed.packages())) { + suppressPackageStartupMessages(library(plyr)) + generateData <- function(x) { + suppressPackageStartupMessages(library(plyr)) + attach(airquality) + result <- transform(Ozone, logOzone = log(Ozone)) + result + } + + data <- lapplyPartition(rdd, generateData) + actual <- collect(data) + } +}) + +test_that("use include package", { + # Only run the test if plyr is installed. + if ("plyr" %in% rownames(installed.packages())) { + suppressPackageStartupMessages(library(plyr)) + generateData <- function(x) { + attach(airquality) + result <- transform(Ozone, logOzone = log(Ozone)) + result + } + + includePackage(sc, plyr) + data <- lapplyPartition(rdd, generateData) + actual <- collect(data) + } +}) diff --git a/R/pkg/inst/tests/test_parallelize_collect.R b/R/pkg/inst/tests/test_parallelize_collect.R new file mode 100644 index 0000000000000..fff028657db37 --- /dev/null +++ b/R/pkg/inst/tests/test_parallelize_collect.R @@ -0,0 +1,109 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("parallelize() and collect()") + +# Mock data +numVector <- c(-10:97) +numList <- list(sqrt(1), sqrt(2), sqrt(3), 4 ** 10) +strVector <- c("Dexter Morgan: I suppose I should be upset, even feel", + "violated, but I'm not. No, in fact, I think this is a friendly", + "message, like \"Hey, wanna play?\" and yes, I want to play. ", + "I really, really do.") +strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ", + "other times it helps me control the chaos.", + "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ", + "raising me. But they're both dead now. I didn't kill them. 
Honest.") + +numPairs <- list(list(1, 1), list(1, 2), list(2, 2), list(2, 3)) +strPairs <- list(list(strList, strList), list(strList, strList)) + +# JavaSparkContext handle +jsc <- sparkR.init() + +# Tests + +test_that("parallelize() on simple vectors and lists returns an RDD", { + numVectorRDD <- parallelize(jsc, numVector, 1) + numVectorRDD2 <- parallelize(jsc, numVector, 10) + numListRDD <- parallelize(jsc, numList, 1) + numListRDD2 <- parallelize(jsc, numList, 4) + strVectorRDD <- parallelize(jsc, strVector, 2) + strVectorRDD2 <- parallelize(jsc, strVector, 3) + strListRDD <- parallelize(jsc, strList, 4) + strListRDD2 <- parallelize(jsc, strList, 1) + + rdds <- c(numVectorRDD, + numVectorRDD2, + numListRDD, + numListRDD2, + strVectorRDD, + strVectorRDD2, + strListRDD, + strListRDD2) + + for (rdd in rdds) { + expect_true(inherits(rdd, "RDD")) + expect_true(.hasSlot(rdd, "jrdd") + && inherits(rdd@jrdd, "jobj") + && isInstanceOf(rdd@jrdd, "org.apache.spark.api.java.JavaRDD")) + } +}) + +test_that("collect(), following a parallelize(), gives back the original collections", { + numVectorRDD <- parallelize(jsc, numVector, 10) + expect_equal(collect(numVectorRDD), as.list(numVector)) + + numListRDD <- parallelize(jsc, numList, 1) + numListRDD2 <- parallelize(jsc, numList, 4) + expect_equal(collect(numListRDD), as.list(numList)) + expect_equal(collect(numListRDD2), as.list(numList)) + + strVectorRDD <- parallelize(jsc, strVector, 2) + strVectorRDD2 <- parallelize(jsc, strVector, 3) + expect_equal(collect(strVectorRDD), as.list(strVector)) + expect_equal(collect(strVectorRDD2), as.list(strVector)) + + strListRDD <- parallelize(jsc, strList, 4) + strListRDD2 <- parallelize(jsc, strList, 1) + expect_equal(collect(strListRDD), as.list(strList)) + expect_equal(collect(strListRDD2), as.list(strList)) +}) + +test_that("regression: collect() following a parallelize() does not drop elements", { + # 10 %/% 6 = 1, ceiling(10 / 6) = 2 + collLen <- 10 + numPart <- 6 + expected <- runif(collLen) + actual <- collect(parallelize(jsc, expected, numPart)) + expect_equal(actual, as.list(expected)) +}) + +test_that("parallelize() and collect() work for lists of pairs (pairwise data)", { + # use the pairwise logical to indicate pairwise data + numPairsRDDD1 <- parallelize(jsc, numPairs, 1) + numPairsRDDD2 <- parallelize(jsc, numPairs, 2) + numPairsRDDD3 <- parallelize(jsc, numPairs, 3) + expect_equal(collect(numPairsRDDD1), numPairs) + expect_equal(collect(numPairsRDDD2), numPairs) + expect_equal(collect(numPairsRDDD3), numPairs) + # can also leave out the parameter name, if the params are supplied in order + strPairsRDDD1 <- parallelize(jsc, strPairs, 1) + strPairsRDDD2 <- parallelize(jsc, strPairs, 2) + expect_equal(collect(strPairsRDDD1), strPairs) + expect_equal(collect(strPairsRDDD2), strPairs) +}) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R new file mode 100644 index 0000000000000..b76e4db03e715 --- /dev/null +++ b/R/pkg/inst/tests/test_rdd.R @@ -0,0 +1,645 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("basic RDD functions") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Data +nums <- 1:10 +rdd <- parallelize(sc, nums, 2L) + +intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) +intRdd <- parallelize(sc, intPairs, 2L) + +test_that("get number of partitions in RDD", { + expect_equal(numPartitions(rdd), 2) + expect_equal(numPartitions(intRdd), 2) +}) + +test_that("first on RDD", { + expect_true(first(rdd) == 1) + newrdd <- lapply(rdd, function(x) x + 1) + expect_true(first(newrdd) == 2) +}) + +test_that("count and length on RDD", { + expect_equal(count(rdd), 10) + expect_equal(length(rdd), 10) +}) + +test_that("count by values and keys", { + mods <- lapply(rdd, function(x) { x %% 3 }) + actual <- countByValue(mods) + expected <- list(list(0, 3L), list(1, 4L), list(2, 3L)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + actual <- countByKey(intRdd) + expected <- list(list(2L, 2L), list(1L, 2L)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("lapply on RDD", { + multiples <- lapply(rdd, function(x) { 2 * x }) + actual <- collect(multiples) + expect_equal(actual, as.list(nums * 2)) +}) + +test_that("lapplyPartition on RDD", { + sums <- lapplyPartition(rdd, function(part) { sum(unlist(part)) }) + actual <- collect(sums) + expect_equal(actual, list(15, 40)) +}) + +test_that("mapPartitions on RDD", { + sums <- mapPartitions(rdd, function(part) { sum(unlist(part)) }) + actual <- collect(sums) + expect_equal(actual, list(15, 40)) +}) + +test_that("flatMap() on RDDs", { + flat <- flatMap(intRdd, function(x) { list(x, x) }) + actual <- collect(flat) + expect_equal(actual, rep(intPairs, each=2)) +}) + +test_that("filterRDD on RDD", { + filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 }) + actual <- collect(filtered.rdd) + expect_equal(actual, list(2, 4, 6, 8, 10)) + + filtered.rdd <- Filter(function(x) { x[[2]] < 0 }, intRdd) + actual <- collect(filtered.rdd) + expect_equal(actual, list(list(1L, -1))) + + # Filter out all elements. 
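+  # none of the values in 1:10 is greater than 10, so the filtered RDD should be empty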
+ filtered.rdd <- filterRDD(rdd, function(x) { x > 10 }) + actual <- collect(filtered.rdd) + expect_equal(actual, list()) +}) + +test_that("lookup on RDD", { + vals <- lookup(intRdd, 1L) + expect_equal(vals, list(-1, 200)) + + vals <- lookup(intRdd, 3L) + expect_equal(vals, list()) +}) + +test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { + rdd2 <- rdd + for (i in 1:12) + rdd2 <- lapplyPartitionsWithIndex( + rdd2, function(split, part) { + part <- as.list(unlist(part) * split + i) + }) + rdd2 <- lapply(rdd2, function(x) x + x) + actual <- collect(rdd2) + expected <- list(24, 24, 24, 24, 24, + 168, 170, 172, 174, 176) + expect_equal(actual, expected) +}) + +test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkpoint()", { + # RDD + rdd2 <- rdd + # PipelinedRDD + rdd2 <- lapplyPartitionsWithIndex( + rdd2, + function(split, part) { + part <- as.list(unlist(part) * split) + }) + + cache(rdd2) + expect_true(rdd2@env$isCached) + rdd2 <- lapply(rdd2, function(x) x) + expect_false(rdd2@env$isCached) + + unpersist(rdd2) + expect_false(rdd2@env$isCached) + + persist(rdd2, "MEMORY_AND_DISK") + expect_true(rdd2@env$isCached) + rdd2 <- lapply(rdd2, function(x) x) + expect_false(rdd2@env$isCached) + + unpersist(rdd2) + expect_false(rdd2@env$isCached) + + tempDir <- tempfile(pattern = "checkpoint") + setCheckpointDir(sc, tempDir) + checkpoint(rdd2) + expect_true(rdd2@env$isCheckpointed) + + rdd2 <- lapply(rdd2, function(x) x) + expect_false(rdd2@env$isCached) + expect_false(rdd2@env$isCheckpointed) + + # make sure the data is collectable + collect(rdd2) + + unlink(tempDir) +}) + +test_that("reduce on RDD", { + sum <- reduce(rdd, "+") + expect_equal(sum, 55) + + # Also test with an inline function + sumInline <- reduce(rdd, function(x, y) { x + y }) + expect_equal(sumInline, 55) +}) + +test_that("lapply with dependency", { + fa <- 5 + multiples <- lapply(rdd, function(x) { fa * x }) + actual <- collect(multiples) + + expect_equal(actual, as.list(nums * 5)) +}) + +test_that("lapplyPartitionsWithIndex on RDDs", { + func <- function(splitIndex, part) { list(splitIndex, Reduce("+", part)) } + actual <- collect(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE) + expect_equal(actual, list(list(0, 15), list(1, 40))) + + pairsRDD <- parallelize(sc, list(list(1, 2), list(3, 4), list(4, 8)), 1L) + partitionByParity <- function(key) { if (key %% 2 == 1) 0 else 1 } + mkTup <- function(splitIndex, part) { list(splitIndex, part) } + actual <- collect(lapplyPartitionsWithIndex( + partitionBy(pairsRDD, 2L, partitionByParity), + mkTup), + FALSE) + expect_equal(actual, list(list(0, list(list(1, 2), list(3, 4))), + list(1, list(list(4, 8))))) +}) + +test_that("sampleRDD() on RDDs", { + expect_equal(unlist(collect(sampleRDD(rdd, FALSE, 1.0, 2014L))), nums) +}) + +test_that("takeSample() on RDDs", { + # ported from RDDSuite.scala, modified seeds + data <- parallelize(sc, 1:100, 2L) + for (seed in 4:5) { + s <- takeSample(data, FALSE, 20L, seed) + expect_equal(length(s), 20L) + expect_equal(length(unique(s)), 20L) + for (elem in s) { + expect_true(elem >= 1 && elem <= 100) + } + } + for (seed in 4:5) { + s <- takeSample(data, FALSE, 200L, seed) + expect_equal(length(s), 100L) + expect_equal(length(unique(s)), 100L) + for (elem in s) { + expect_true(elem >= 1 && elem <= 100) + } + } + for (seed in 4:5) { + s <- takeSample(data, TRUE, 20L, seed) + expect_equal(length(s), 20L) + for (elem in s) { + expect_true(elem >= 1 && elem <= 100) + } + } + for (seed in 4:5) { + s <- 
takeSample(data, TRUE, 100L, seed) + expect_equal(length(s), 100L) + # Chance of getting all distinct elements is astronomically low, so test we + # got < 100 + expect_true(length(unique(s)) < 100L) + } + for (seed in 4:5) { + s <- takeSample(data, TRUE, 200L, seed) + expect_equal(length(s), 200L) + # Chance of getting all distinct elements is still quite low, so test we + # got < 100 + expect_true(length(unique(s)) < 100L) + } +}) + +test_that("mapValues() on pairwise RDDs", { + multiples <- mapValues(intRdd, function(x) { x * 2 }) + actual <- collect(multiples) + expected <- lapply(intPairs, function(x) { + list(x[[1]], x[[2]] * 2) + }) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("flatMapValues() on pairwise RDDs", { + l <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) + actual <- collect(flatMapValues(l, function(x) { x })) + expect_equal(actual, list(list(1,1), list(1,2), list(2,3), list(2,4))) + + # Generate x to x+1 for every value + actual <- collect(flatMapValues(intRdd, function(x) { x:(x + 1) })) + expect_equal(actual, + list(list(1L, -1), list(1L, 0), list(2L, 100), list(2L, 101), + list(2L, 1), list(2L, 2), list(1L, 200), list(1L, 201))) +}) + +test_that("reduceByKeyLocally() on PairwiseRDDs", { + pairs <- parallelize(sc, list(list(1, 2), list(1.1, 3), list(1, 4)), 2L) + actual <- reduceByKeyLocally(pairs, "+") + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(1, 6), list(1.1, 3)))) + + pairs <- parallelize(sc, list(list("abc", 1.2), list(1.1, 0), list("abc", 1.3), + list("bb", 5)), 4L) + actual <- reduceByKeyLocally(pairs, "+") + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("abc", 2.5), list(1.1, 0), list("bb", 5)))) +}) + +test_that("distinct() on RDDs", { + nums.rep2 <- rep(1:10, 2) + rdd.rep2 <- parallelize(sc, nums.rep2, 2L) + uniques <- distinct(rdd.rep2) + actual <- sort(unlist(collect(uniques))) + expect_equal(actual, nums) +}) + +test_that("maximum() on RDDs", { + max <- maximum(rdd) + expect_equal(max, 10) +}) + +test_that("minimum() on RDDs", { + min <- minimum(rdd) + expect_equal(min, 1) +}) + +test_that("sumRDD() on RDDs", { + sum <- sumRDD(rdd) + expect_equal(sum, 55) +}) + +test_that("keyBy on RDDs", { + func <- function(x) { x*x } + keys <- keyBy(rdd, func) + actual <- collect(keys) + expect_equal(actual, lapply(nums, function(x) { list(func(x), x) })) +}) + +test_that("repartition/coalesce on RDDs", { + rdd <- parallelize(sc, 1:20, 4L) # each partition contains 5 elements + + # repartition + r1 <- repartition(rdd, 2) + expect_equal(numPartitions(r1), 2L) + count <- length(collectPartition(r1, 0L)) + expect_true(count >= 8 && count <= 12) + + r2 <- repartition(rdd, 6) + expect_equal(numPartitions(r2), 6L) + count <- length(collectPartition(r2, 0L)) + expect_true(count >=0 && count <= 4) + + # coalesce + r3 <- coalesce(rdd, 1) + expect_equal(numPartitions(r3), 1L) + count <- length(collectPartition(r3, 0L)) + expect_equal(count, 20) +}) + +test_that("sortBy() on RDDs", { + sortedRdd <- sortBy(rdd, function(x) { x * x }, ascending = FALSE) + actual <- collect(sortedRdd) + expect_equal(actual, as.list(sort(nums, decreasing = TRUE))) + + rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L) + sortedRdd2 <- sortBy(rdd2, function(x) { x * x }) + actual <- collect(sortedRdd2) + expect_equal(actual, as.list(nums)) +}) + +test_that("takeOrdered() on RDDs", { + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) + rdd <- parallelize(sc, l) + actual <- takeOrdered(rdd, 6L) + 
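# takeOrdered(rdd, 6L) should return the 6 smallest values in ascending order +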
expect_equal(actual, as.list(sort(unlist(l)))[1:6]) + + l <- list("e", "d", "c", "d", "a") + rdd <- parallelize(sc, l) + actual <- takeOrdered(rdd, 3L) + expect_equal(actual, as.list(sort(unlist(l)))[1:3]) +}) + +test_that("top() on RDDs", { + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) + rdd <- parallelize(sc, l) + actual <- top(rdd, 6L) + expect_equal(actual, as.list(sort(unlist(l), decreasing = TRUE))[1:6]) + + l <- list("e", "d", "c", "d", "a") + rdd <- parallelize(sc, l) + actual <- top(rdd, 3L) + expect_equal(actual, as.list(sort(unlist(l), decreasing = TRUE))[1:3]) +}) + +test_that("fold() on RDDs", { + actual <- fold(rdd, 0, "+") + expect_equal(actual, Reduce("+", nums, 0)) + + rdd <- parallelize(sc, list()) + actual <- fold(rdd, 0, "+") + expect_equal(actual, 0) +}) + +test_that("aggregateRDD() on RDDs", { + rdd <- parallelize(sc, list(1, 2, 3, 4)) + zeroValue <- list(0, 0) + seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } + combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } + actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp) + expect_equal(actual, list(10, 4)) + + rdd <- parallelize(sc, list()) + actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp) + expect_equal(actual, list(0, 0)) +}) + +test_that("zipWithUniqueId() on RDDs", { + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) + actual <- collect(zipWithUniqueId(rdd)) + expected <- list(list("a", 0), list("b", 3), list("c", 1), + list("d", 4), list("e", 2)) + expect_equal(actual, expected) + + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) + actual <- collect(zipWithUniqueId(rdd)) + expected <- list(list("a", 0), list("b", 1), list("c", 2), + list("d", 3), list("e", 4)) + expect_equal(actual, expected) +}) + +test_that("zipWithIndex() on RDDs", { + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) + actual <- collect(zipWithIndex(rdd)) + expected <- list(list("a", 0), list("b", 1), list("c", 2), + list("d", 3), list("e", 4)) + expect_equal(actual, expected) + + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) + actual <- collect(zipWithIndex(rdd)) + expected <- list(list("a", 0), list("b", 1), list("c", 2), + list("d", 3), list("e", 4)) + expect_equal(actual, expected) +}) + +test_that("glom() on RDD", { + rdd <- parallelize(sc, as.list(1:4), 2L) + actual <- collect(glom(rdd)) + expect_equal(actual, list(list(1, 2), list(3, 4))) +}) + +test_that("keys() on RDDs", { + keys <- keys(intRdd) + actual <- collect(keys) + expect_equal(actual, lapply(intPairs, function(x) { x[[1]] })) +}) + +test_that("values() on RDDs", { + values <- values(intRdd) + actual <- collect(values) + expect_equal(actual, lapply(intPairs, function(x) { x[[2]] })) +}) + +test_that("pipeRDD() on RDDs", { + actual <- collect(pipeRDD(rdd, "more")) + expected <- as.list(as.character(1:10)) + expect_equal(actual, expected) + + trailed.rdd <- parallelize(sc, c("1", "", "2\n", "3\n\r\n")) + actual <- collect(pipeRDD(trailed.rdd, "sort")) + expected <- list("", "1", "2", "3") + expect_equal(actual, expected) + + rev.nums <- 9:0 + rev.rdd <- parallelize(sc, rev.nums, 2L) + actual <- collect(pipeRDD(rev.rdd, "sort")) + expected <- as.list(as.character(c(5:9, 0:4))) + expect_equal(actual, expected) +}) + +test_that("zipRDD() on RDDs", { + rdd1 <- parallelize(sc, 0:4, 2) + rdd2 <- parallelize(sc, 1000:1004, 2) + actual <- collect(zipRDD(rdd1, rdd2)) + expect_equal(actual, + list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004))) + + mockFile = c("Spark is pretty.", "Spark is 
awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName, 1) + actual <- collect(zipRDD(rdd, rdd)) + expected <- lapply(mockFile, function(x) { list(x ,x) }) + expect_equal(actual, expected) + + rdd1 <- parallelize(sc, 0:1, 1) + actual <- collect(zipRDD(rdd1, rdd)) + expected <- lapply(0:1, function(x) { list(x, mockFile[x + 1]) }) + expect_equal(actual, expected) + + rdd1 <- map(rdd, function(x) { x }) + actual <- collect(zipRDD(rdd, rdd1)) + expected <- lapply(mockFile, function(x) { list(x, x) }) + expect_equal(actual, expected) + + unlink(fileName) +}) + +test_that("join() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,1), list(2,4))) + rdd2 <- parallelize(sc, list(list(1,2), list(1,3))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(1, list(1, 2)), list(1, list(1, 3))))) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",4))) + rdd2 <- parallelize(sc, list(list("a",2), list("a",3))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("a", list(1, 2)), list("a", list(1, 3))))) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(actual, list()) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(actual, list()) +}) + +test_that("leftOuterJoin() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,1), list(2,4))) + rdd2 <- parallelize(sc, list(list(1,2), list(1,3))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",4))) + rdd2 <- parallelize(sc, list(list("a",2), list("a",3))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(4, NULL)), list("a", list(1, 2)), list("a", list(1, 3))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(1, NULL)), list(2, list(2, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(2, NULL)), list("a", list(1, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) +}) + +test_that("rightOuterJoin() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,2), list(1,3))) + rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",2), list("a",3))) + rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1))) + 
expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(3, list(NULL, 3)), list(4, list(NULL, 4))))) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("d", list(NULL, 4)), list("c", list(NULL, 3))))) +}) + +test_that("fullOuterJoin() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,2), list(1,3), list(3,3))) + rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4)), list(3, list(3, NULL))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",2), list("a",3), list("c", 1))) + rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1)), list("c", list(1, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), list(3, list(NULL, 3)), list(4, list(NULL, 4))))) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), list("d", list(NULL, 4)), list("c", list(NULL, 3))))) +}) + +test_that("sortByKey() on pairwise RDDs", { + numPairsRdd <- map(rdd, function(x) { list (x, x) }) + sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE) + actual <- collect(sortedRdd) + numPairs <- lapply(nums, function(x) { list (x, x) }) + expect_equal(actual, sortKeyValueList(numPairs, decreasing = TRUE)) + + rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L) + numPairsRdd2 <- map(rdd2, function(x) { list (x, x) }) + sortedRdd2 <- sortByKey(numPairsRdd2) + actual <- collect(sortedRdd2) + expect_equal(actual, numPairs) + + # sort by string keys + l <- list(list("a", 1), list("b", 2), list("1", 3), list("d", 4), list("2", 5)) + rdd3 <- parallelize(sc, l, 2L) + sortedRdd3 <- sortByKey(rdd3) + actual <- collect(sortedRdd3) + expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) + + # test on the boundary cases + + # boundary case 1: the RDD to be sorted has only 1 partition + rdd4 <- parallelize(sc, l, 1L) + sortedRdd4 <- sortByKey(rdd4) + actual <- collect(sortedRdd4) + expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) + + # boundary case 2: the sorted RDD has only 1 partition + rdd5 <- parallelize(sc, l, 2L) + sortedRdd5 <- sortByKey(rdd5, numPartitions = 1L) + actual <- collect(sortedRdd5) + expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) + + # boundary case 3: the RDD to be sorted has only 1 element 
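+  # sorting a single-element RDD should return that element unchanged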
+ l2 <- list(list("a", 1)) + rdd6 <- parallelize(sc, l2, 2L) + sortedRdd6 <- sortByKey(rdd6) + actual <- collect(sortedRdd6) + expect_equal(actual, l2) + + # boundary case 4: the RDD to be sorted has 0 element + l3 <- list() + rdd7 <- parallelize(sc, l3, 2L) + sortedRdd7 <- sortByKey(rdd7) + actual <- collect(sortedRdd7) + expect_equal(actual, l3) +}) + +test_that("collectAsMap() on a pairwise RDD", { + rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(`1` = 2, `3` = 4)) + + rdd <- parallelize(sc, list(list("a", 1), list("b", 2))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(a = 1, b = 2)) + + rdd <- parallelize(sc, list(list(1.1, 2.2), list(1.2, 2.4))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(`1.1` = 2.2, `1.2` = 2.4)) + + rdd <- parallelize(sc, list(list(1, "a"), list(2, "b"))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(`1` = "a", `2` = "b")) +}) diff --git a/R/pkg/inst/tests/test_shuffle.R b/R/pkg/inst/tests/test_shuffle.R new file mode 100644 index 0000000000000..d1da8232aea81 --- /dev/null +++ b/R/pkg/inst/tests/test_shuffle.R @@ -0,0 +1,209 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("partitionBy, groupByKey, reduceByKey etc.") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Data +intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) +intRdd <- parallelize(sc, intPairs, 2L) + +doublePairs <- list(list(1.5, -1), list(2.5, 100), list(2.5, 1), list(1.5, 200)) +doubleRdd <- parallelize(sc, doublePairs, 2L) + +numPairs <- list(list(1L, 100), list(2L, 200), list(4L, -1), list(3L, 1), + list(3L, 0)) +numPairsRdd <- parallelize(sc, numPairs, length(numPairs)) + +strList <- list("Dexter Morgan: Blood. 
Sometimes it sets my teeth on edge and ", + "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ") +strListRDD <- parallelize(sc, strList, 4) + +test_that("groupByKey for integers", { + grouped <- groupByKey(intRdd, 2L) + + actual <- collect(grouped) + + expected <- list(list(2L, list(100, 1)), list(1L, list(-1, 200))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("groupByKey for doubles", { + grouped <- groupByKey(doubleRdd, 2L) + + actual <- collect(grouped) + + expected <- list(list(1.5, list(-1, 200)), list(2.5, list(100, 1))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("reduceByKey for ints", { + reduced <- reduceByKey(intRdd, "+", 2L) + + actual <- collect(reduced) + + expected <- list(list(2L, 101), list(1L, 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("reduceByKey for doubles", { + reduced <- reduceByKey(doubleRdd, "+", 2L) + actual <- collect(reduced) + + expected <- list(list(1.5, 199), list(2.5, 101)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("combineByKey for ints", { + reduced <- combineByKey(intRdd, function(x) { x }, "+", "+", 2L) + + actual <- collect(reduced) + + expected <- list(list(2L, 101), list(1L, 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("combineByKey for doubles", { + reduced <- combineByKey(doubleRdd, function(x) { x }, "+", "+", 2L) + actual <- collect(reduced) + + expected <- list(list(1.5, 199), list(2.5, 101)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("aggregateByKey", { + # test aggregateByKey for int keys + rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) + + zeroValue <- list(0, 0) + seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } + combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } + aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) + + actual <- collect(aggregatedRDD) + + expected <- list(list(1, list(3, 2)), list(2, list(7, 2))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test aggregateByKey for string keys + rdd <- parallelize(sc, list(list("a", 1), list("a", 2), list("b", 3), list("b", 4))) + + zeroValue <- list(0, 0) + seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } + combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } + aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) + + actual <- collect(aggregatedRDD) + + expected <- list(list("a", list(3, 2)), list("b", list(7, 2))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("foldByKey", { + # test foldByKey for int keys + folded <- foldByKey(intRdd, 0, "+", 2L) + + actual <- collect(folded) + + expected <- list(list(2L, 101), list(1L, 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test foldByKey for double keys + folded <- foldByKey(doubleRdd, 0, "+", 2L) + + actual <- collect(folded) + + expected <- list(list(1.5, 199), list(2.5, 101)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test foldByKey for string keys + stringKeyPairs <- list(list("a", -1), list("b", 100), list("b", 1), list("a", 200)) + + stringKeyRDD <- parallelize(sc, stringKeyPairs) + folded <- foldByKey(stringKeyRDD, 0, "+", 2L) + + actual <- collect(folded) + + expected <- list(list("b", 101), list("a", 199)) + 
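# foldByKey with "+" sums the values per key: "b" = 100 + 1, "a" = -1 + 200 +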
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test foldByKey for empty pair RDD + rdd <- parallelize(sc, list()) + folded <- foldByKey(rdd, 0, "+", 2L) + actual <- collect(folded) + expected <- list() + expect_equal(actual, expected) + + # test foldByKey for RDD with only 1 pair + rdd <- parallelize(sc, list(list(1, 1))) + folded <- foldByKey(rdd, 0, "+", 2L) + actual <- collect(folded) + expected <- list(list(1, 1)) + expect_equal(actual, expected) +}) + +test_that("partitionBy() partitions data correctly", { + # Partition by magnitude + partitionByMagnitude <- function(key) { if (key >= 3) 1 else 0 } + + resultRDD <- partitionBy(numPairsRdd, 2L, partitionByMagnitude) + + expected_first <- list(list(1, 100), list(2, 200)) # key < 3 + expected_second <- list(list(4, -1), list(3, 1), list(3, 0)) # key >= 3 + actual_first <- collectPartition(resultRDD, 0L) + actual_second <- collectPartition(resultRDD, 1L) + + expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first)) + expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) +}) + +test_that("partitionBy works with dependencies", { + kOne <- 1 + partitionByParity <- function(key) { if (key %% 2 == kOne) 7 else 4 } + + # Partition by parity + resultRDD <- partitionBy(numPairsRdd, numPartitions = 2L, partitionByParity) + + # keys even; 100 %% 2 == 0 + expected_first <- list(list(2, 200), list(4, -1)) + # keys odd; 3 %% 2 == 1 + expected_second <- list(list(1, 100), list(3, 1), list(3, 0)) + actual_first <- collectPartition(resultRDD, 0L) + actual_second <- collectPartition(resultRDD, 1L) + + expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first)) + expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) +}) + +test_that("test partitionBy with string keys", { + words <- flatMap(strListRDD, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + resultRDD <- partitionBy(wordCount, 2L) + expected_first <- list(list("Dexter", 1), list("Dexter", 1)) + expected_second <- list(list("and", 1), list("and", 1)) + + actual_first <- Filter(function(item) { item[[1]] == "Dexter" }, + collectPartition(resultRDD, 0L)) + actual_second <- Filter(function(item) { item[[1]] == "and" }, + collectPartition(resultRDD, 1L)) + + expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first)) + expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) +}) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R new file mode 100644 index 0000000000000..cf5cf6d1692af --- /dev/null +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -0,0 +1,695 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +library(testthat) + +context("SparkSQL functions") + +# Tests for SparkSQL functions in SparkR + +sc <- sparkR.init() + +sqlCtx <- sparkRSQL.init(sc) + +mockLines <- c("{\"name\":\"Michael\"}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}") +jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +parquetPath <- tempfile(pattern="sparkr-test", fileext=".parquet") +writeLines(mockLines, jsonPath) + +test_that("infer types", { + expect_equal(infer_type(1L), "integer") + expect_equal(infer_type(1.0), "double") + expect_equal(infer_type("abc"), "string") + expect_equal(infer_type(TRUE), "boolean") + expect_equal(infer_type(as.Date("2015-03-11")), "date") + expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp") + expect_equal(infer_type(c(1L, 2L)), + list(type = 'array', elementType = "integer", containsNull = TRUE)) + expect_equal(infer_type(list(1L, 2L)), + list(type = 'array', elementType = "integer", containsNull = TRUE)) + expect_equal(infer_type(list(a = 1L, b = "2")), + list(type = "struct", + fields = list(list(name = "a", type = "integer", nullable = TRUE), + list(name = "b", type = "string", nullable = TRUE)))) + e <- new.env() + assign("a", 1L, envir = e) + expect_equal(infer_type(e), + list(type = "map", keyType = "string", valueType = "integer", + valueContainsNull = TRUE)) +}) + +test_that("create DataFrame from RDD", { + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) + df <- createDataFrame(sqlCtx, rdd, list("a", "b")) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- createDataFrame(sqlCtx, rdd) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("_1", "_2")) + + fields <- list(list(name = "a", type = "integer", nullable = TRUE), + list(name = "b", type = "string", nullable = TRUE)) + schema <- list(type = "struct", fields = fields) + df <- createDataFrame(sqlCtx, rdd, schema) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) + df <- createDataFrame(sqlCtx, rdd) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) +}) + +test_that("toDF", { + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) + df <- toDF(rdd, list("a", "b")) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- toDF(rdd) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("_1", "_2")) + + fields <- list(list(name = "a", type = "integer", nullable = TRUE), + list(name = "b", type = "string", nullable = TRUE)) + schema <- list(type = "struct", fields = fields) + df <- toDF(rdd, schema) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) + df <- toDF(rdd) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), 
list(c("a", "int"), c("b", "string"))) +}) + +test_that("create DataFrame from list or data.frame", { + l <- list(list(1, 2), list(3, 4)) + df <- createDataFrame(sqlCtx, l, c("a", "b")) + expect_equal(columns(df), c("a", "b")) + + l <- list(list(a=1, b=2), list(a=3, b=4)) + df <- createDataFrame(sqlCtx, l) + expect_equal(columns(df), c("a", "b")) + + a <- 1:3 + b <- c("a", "b", "c") + ldf <- data.frame(a, b) + df <- createDataFrame(sqlCtx, ldf) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + expect_equal(count(df), 3) + ldf2 <- collect(df) + expect_equal(ldf$a, ldf2$a) +}) + +test_that("create DataFrame with different data types", { + l <- list(a = 1L, b = 2, c = TRUE, d = "ss", e = as.Date("2012-12-13"), + f = as.POSIXct("2015-03-15 12:13:14.056")) + df <- createDataFrame(sqlCtx, list(l)) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "double"), c("c", "boolean"), + c("d", "string"), c("e", "date"), c("f", "timestamp"))) + expect_equal(count(df), 1) + expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) +}) + +# TODO: enable this test after fix serialization for nested object +#test_that("create DataFrame with nested array and struct", { +# e <- new.env() +# assign("n", 3L, envir = e) +# l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) +# df <- createDataFrame(sqlCtx, list(l), c("a", "b", "c", "d")) +# expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"), +# c("c", "map"), c("d", "struct"))) +# expect_equal(count(df), 1) +# ldf <- collect(df) +# expect_equal(ldf[1,], l[[1]]) +#}) + +test_that("jsonFile() on a local file returns a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) +}) + +test_that("jsonRDD() on a RDD with json string", { + rdd <- parallelize(sc, mockLines) + expect_true(count(rdd) == 3) + df <- jsonRDD(sqlCtx, rdd) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) + + rdd2 <- flatMap(rdd, function(x) c(x, x)) + df <- jsonRDD(sqlCtx, rdd2) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 6) +}) + +test_that("test cache, uncache and clearCache", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + cacheTable(sqlCtx, "table1") + uncacheTable(sqlCtx, "table1") + clearCache(sqlCtx) + dropTempTable(sqlCtx, "table1") +}) + +test_that("test tableNames and tables", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + expect_true(length(tableNames(sqlCtx)) == 1) + df <- tables(sqlCtx) + expect_true(count(df) == 1) + dropTempTable(sqlCtx, "table1") +}) + +test_that("registerTempTable() results in a queryable table and sql() results in a new DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + newdf <- sql(sqlCtx, "SELECT * FROM table1 where name = 'Michael'") + expect_true(inherits(newdf, "DataFrame")) + expect_true(count(newdf) == 1) + dropTempTable(sqlCtx, "table1") +}) + +test_that("insertInto() on a registered table", { + df <- loadDF(sqlCtx, jsonPath, "json") + saveDF(df, parquetPath, "parquet", "overwrite") + dfParquet <- loadDF(sqlCtx, parquetPath, "parquet") + + lines <- c("{\"name\":\"Bob\", \"age\":24}", + "{\"name\":\"James\", \"age\":35}") + jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".tmp") + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + writeLines(lines, jsonPath2) + df2 <- loadDF(sqlCtx, jsonPath2, "json") + saveDF(df2, parquetPath2, "parquet", "overwrite") + 
dfParquet2 <- loadDF(sqlCtx, parquetPath2, "parquet") + + registerTempTable(dfParquet, "table1") + insertInto(dfParquet2, "table1") + expect_true(count(sql(sqlCtx, "select * from table1")) == 5) + expect_true(first(sql(sqlCtx, "select * from table1 order by age"))$name == "Michael") + dropTempTable(sqlCtx, "table1") + + registerTempTable(dfParquet, "table1") + insertInto(dfParquet2, "table1", overwrite = TRUE) + expect_true(count(sql(sqlCtx, "select * from table1")) == 2) + expect_true(first(sql(sqlCtx, "select * from table1 order by age"))$name == "Bob") + dropTempTable(sqlCtx, "table1") +}) + +test_that("table() returns a new DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + tabledf <- table(sqlCtx, "table1") + expect_true(inherits(tabledf, "DataFrame")) + expect_true(count(tabledf) == 3) + dropTempTable(sqlCtx, "table1") +}) + +test_that("toRDD() returns an RRDD", { + df <- jsonFile(sqlCtx, jsonPath) + testRDD <- toRDD(df) + expect_true(inherits(testRDD, "RDD")) + expect_true(count(testRDD) == 3) +}) + +test_that("union on two RDDs created from DataFrames returns an RRDD", { + df <- jsonFile(sqlCtx, jsonPath) + RDD1 <- toRDD(df) + RDD2 <- toRDD(df) + unioned <- unionRDD(RDD1, RDD2) + expect_true(inherits(unioned, "RDD")) + expect_true(SparkR:::getSerializedMode(unioned) == "byte") + expect_true(collect(unioned)[[2]]$name == "Andy") +}) + +test_that("union on mixed serialization types correctly returns a byte RRDD", { + # Byte RDD + nums <- 1:10 + rdd <- parallelize(sc, nums, 2L) + + # String RDD + textLines <- c("Michael", + "Andy, 30", + "Justin, 19") + textPath <- tempfile(pattern="sparkr-textLines", fileext=".tmp") + writeLines(textLines, textPath) + textRDD <- textFile(sc, textPath) + + df <- jsonFile(sqlCtx, jsonPath) + dfRDD <- toRDD(df) + + unionByte <- unionRDD(rdd, dfRDD) + expect_true(inherits(unionByte, "RDD")) + expect_true(SparkR:::getSerializedMode(unionByte) == "byte") + expect_true(collect(unionByte)[[1]] == 1) + expect_true(collect(unionByte)[[12]]$name == "Andy") + + unionString <- unionRDD(textRDD, dfRDD) + expect_true(inherits(unionString, "RDD")) + expect_true(SparkR:::getSerializedMode(unionString) == "byte") + expect_true(collect(unionString)[[1]] == "Michael") + expect_true(collect(unionString)[[5]]$name == "Andy") +}) + +test_that("objectFile() works with row serialization", { + objectPath <- tempfile(pattern="spark-test", fileext=".tmp") + df <- jsonFile(sqlCtx, jsonPath) + dfRDD <- toRDD(df) + saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) + objectIn <- objectFile(sc, objectPath) + + expect_true(inherits(objectIn, "RDD")) + expect_equal(SparkR:::getSerializedMode(objectIn), "byte") + expect_equal(collect(objectIn)[[2]]$age, 30) +}) + +test_that("lapply() on a DataFrame returns an RDD with the correct columns", { + df <- jsonFile(sqlCtx, jsonPath) + testRDD <- lapply(df, function(row) { + row$newCol <- row$age + 5 + row + }) + expect_true(inherits(testRDD, "RDD")) + collected <- collect(testRDD) + expect_true(collected[[1]]$name == "Michael") + expect_true(collected[[2]]$newCol == "35") +}) + +test_that("collect() returns a data.frame", { + df <- jsonFile(sqlCtx, jsonPath) + rdf <- collect(df) + expect_true(is.data.frame(rdf)) + expect_true(names(rdf)[1] == "age") + expect_true(nrow(rdf) == 3) + expect_true(ncol(rdf) == 2) +}) + +test_that("limit() returns DataFrame with the correct number of rows", { + df <- jsonFile(sqlCtx, jsonPath) + dfLimited <- limit(df, 2) + expect_true(inherits(dfLimited, "DataFrame")) + 
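# the limited DataFrame should contain exactly 2 of the 3 input rows +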
expect_true(count(dfLimited) == 2) +}) + +test_that("collect() and take() on a DataFrame return the same number of rows and columns", { + df <- jsonFile(sqlCtx, jsonPath) + expect_true(nrow(collect(df)) == nrow(take(df, 10))) + expect_true(ncol(collect(df)) == ncol(take(df, 10))) +}) + +test_that("multiple pipeline transformations starting with a DataFrame result in an RDD with the correct values", { + df <- jsonFile(sqlCtx, jsonPath) + first <- lapply(df, function(row) { + row$age <- row$age + 5 + row + }) + second <- lapply(first, function(row) { + row$testCol <- if (row$age == 35 && !is.na(row$age)) TRUE else FALSE + row + }) + expect_true(inherits(second, "RDD")) + expect_true(count(second) == 3) + expect_true(collect(second)[[2]]$age == 35) + expect_true(collect(second)[[2]]$testCol) + expect_false(collect(second)[[3]]$testCol) +}) + +test_that("cache(), persist(), and unpersist() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + expect_false(df@env$isCached) + cache(df) + expect_true(df@env$isCached) + + unpersist(df) + expect_false(df@env$isCached) + + persist(df, "MEMORY_AND_DISK") + expect_true(df@env$isCached) + + unpersist(df) + expect_false(df@env$isCached) + + # make sure the data is collectable + expect_true(is.data.frame(collect(df))) +}) + +test_that("schema(), dtypes(), columns(), names() return the correct values/format", { + df <- jsonFile(sqlCtx, jsonPath) + testSchema <- schema(df) + expect_true(length(testSchema$fields()) == 2) + expect_true(testSchema$fields()[[1]]$dataType.toString() == "LongType") + expect_true(testSchema$fields()[[2]]$dataType.simpleString() == "string") + expect_true(testSchema$fields()[[1]]$name() == "age") + + testTypes <- dtypes(df) + expect_true(length(testTypes[[1]]) == 2) + expect_true(testTypes[[1]][1] == "age") + + testCols <- columns(df) + expect_true(length(testCols) == 2) + expect_true(testCols[2] == "name") + + testNames <- names(df) + expect_true(length(testNames) == 2) + expect_true(testNames[2] == "name") +}) + +test_that("head() and first() return the correct data", { + df <- jsonFile(sqlCtx, jsonPath) + testHead <- head(df) + expect_true(nrow(testHead) == 3) + expect_true(ncol(testHead) == 2) + + testHead2 <- head(df, 2) + expect_true(nrow(testHead2) == 2) + expect_true(ncol(testHead2) == 2) + + testFirst <- first(df) + expect_true(nrow(testFirst) == 1) +}) + +test_that("distinct() on DataFrames", { + lines <- c("{\"name\":\"Michael\"}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}", + "{\"name\":\"Justin\", \"age\":19}") + jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPathWithDup) + + df <- jsonFile(sqlCtx, jsonPathWithDup) + uniques <- distinct(df) + expect_true(inherits(uniques, "DataFrame")) + expect_true(count(uniques) == 3) +}) + +test_that("sampleDF on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + sampled <- sampleDF(df, FALSE, 1.0) + expect_equal(nrow(collect(sampled)), count(df)) + expect_true(inherits(sampled, "DataFrame")) + sampled2 <- sampleDF(df, FALSE, 0.1) + expect_true(count(sampled2) < 3) +}) + +test_that("select operators", { + df <- select(jsonFile(sqlCtx, jsonPath), "name", "age") + expect_true(inherits(df$name, "Column")) + expect_true(inherits(df[[2]], "Column")) + expect_true(inherits(df[["age"]], "Column")) + + expect_true(inherits(df[,1], "DataFrame")) + expect_equal(columns(df[,1]), c("name")) + expect_equal(columns(df[,"age"]), c("age")) + df2 <- df[,c("age", "name")] + expect_true(inherits(df2, "DataFrame")) + 
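# columns should come back in the order given in the selection +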
expect_equal(columns(df2), c("age", "name")) + + df$age2 <- df$age + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == df$age)), 2) + df$age2 <- df$age * 2 + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == df$age * 2)), 2) +}) + +test_that("select with column", { + df <- jsonFile(sqlCtx, jsonPath) + df1 <- select(df, "name") + expect_true(columns(df1) == c("name")) + expect_true(count(df1) == 3) + + df2 <- select(df, df$age) + expect_true(columns(df2) == c("age")) + expect_true(count(df2) == 3) +}) + +test_that("selectExpr() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + selected <- selectExpr(df, "age * 2") + expect_true(names(selected) == "(age * 2)") + expect_equal(collect(selected), collect(select(df, df$age * 2L))) + + selected2 <- selectExpr(df, "name as newName", "abs(age) as age") + expect_equal(names(selected2), c("newName", "age")) + expect_true(count(selected2) == 3) +}) + +test_that("column calculation", { + df <- jsonFile(sqlCtx, jsonPath) + d <- collect(select(df, alias(df$age + 1, "age2"))) + expect_true(names(d) == c("age2")) + df2 <- select(df, lower(df$name), abs(df$age)) + expect_true(inherits(df2, "DataFrame")) + expect_true(count(df2) == 3) +}) + +test_that("load() from json file", { + df <- loadDF(sqlCtx, jsonPath, "json") + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) +}) + +test_that("save() as parquet file", { + df <- loadDF(sqlCtx, jsonPath, "json") + saveDF(df, parquetPath, "parquet", mode="overwrite") + df2 <- loadDF(sqlCtx, parquetPath, "parquet") + expect_true(inherits(df2, "DataFrame")) + expect_true(count(df2) == 3) +}) + +test_that("test HiveContext", { + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + df <- createExternalTable(hiveCtx, "json", jsonPath, "json") + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) + df2 <- sql(hiveCtx, "select * from json") + expect_true(inherits(df2, "DataFrame")) + expect_true(count(df2) == 3) + + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + saveAsTable(df, "json", "json", "append", path = jsonPath2) + df3 <- sql(hiveCtx, "select * from json") + expect_true(inherits(df3, "DataFrame")) + expect_true(count(df3) == 6) +}) + +test_that("column operators", { + c <- SparkR:::col("a") + c2 <- (- c + 1 - 2) * 3 / 4.0 + c3 <- (c + c2 - c2) * c2 %% c2 + c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3) +}) + +test_that("column functions", { + c <- SparkR:::col("a") + c2 <- min(c) + max(c) + sum(c) + avg(c) + count(c) + abs(c) + sqrt(c) + c3 <- lower(c) + upper(c) + first(c) + last(c) + c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string") +}) + +test_that("string operators", { + df <- jsonFile(sqlCtx, jsonPath) + expect_equal(count(where(df, like(df$name, "A%"))), 1) + expect_equal(count(where(df, startsWith(df$name, "A"))), 1) + expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") + expect_equal(collect(select(df, cast(df$age, "string")))[[2, 1]], "30") +}) + +test_that("group by", { + df <- jsonFile(sqlCtx, jsonPath) + df1 <- agg(df, name = "max", age = "sum") + expect_true(1 == count(df1)) + df1 <- agg(df, age2 = max(df$age)) + expect_true(1 == count(df1)) + expect_equal(columns(df1), c("age2")) + + gd <- groupBy(df, "name") + expect_true(inherits(gd, "GroupedData")) + df2 <- count(gd) + expect_true(inherits(df2, 
"DataFrame")) + expect_true(3 == count(df2)) + + df3 <- agg(gd, age = "sum") + expect_true(inherits(df3, "DataFrame")) + expect_true(3 == count(df3)) + + df3 <- agg(gd, age = sum(df$age)) + expect_true(inherits(df3, "DataFrame")) + expect_true(3 == count(df3)) + expect_equal(columns(df3), c("name", "age")) + + df4 <- sum(gd, "age") + expect_true(inherits(df4, "DataFrame")) + expect_true(3 == count(df4)) + expect_true(3 == count(mean(gd, "age"))) + expect_true(3 == count(max(gd, "age"))) +}) + +test_that("sortDF() and orderBy() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + sorted <- sortDF(df, df$age) + expect_true(collect(sorted)[1,2] == "Michael") + + sorted2 <- sortDF(df, "name") + expect_true(collect(sorted2)[2,"age"] == 19) + + sorted3 <- orderBy(df, asc(df$age)) + expect_true(is.na(first(sorted3)$age)) + expect_true(collect(sorted3)[2, "age"] == 19) + + sorted4 <- orderBy(df, desc(df$name)) + expect_true(first(sorted4)$name == "Michael") + expect_true(collect(sorted4)[3,"name"] == "Andy") +}) + +test_that("filter() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + filtered <- filter(df, "age > 20") + expect_true(count(filtered) == 1) + expect_true(collect(filtered)$name == "Andy") + filtered2 <- where(df, df$name != "Michael") + expect_true(count(filtered2) == 2) + expect_true(collect(filtered2)$age[2] == 19) +}) + +test_that("join() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + + mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", + "{\"name\":\"Andy\", \"test\": \"no\"}", + "{\"name\":\"Justin\", \"test\": \"yes\"}", + "{\"name\":\"Bob\", \"test\": \"yes\"}") + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(mockLines2, jsonPath2) + df2 <- jsonFile(sqlCtx, jsonPath2) + + joined <- join(df, df2) + expect_equal(names(joined), c("age", "name", "name", "test")) + expect_true(count(joined) == 12) + + joined2 <- join(df, df2, df$name == df2$name) + expect_equal(names(joined2), c("age", "name", "name", "test")) + expect_true(count(joined2) == 3) + + joined3 <- join(df, df2, df$name == df2$name, "right_outer") + expect_equal(names(joined3), c("age", "name", "name", "test")) + expect_true(count(joined3) == 4) + expect_true(is.na(collect(orderBy(joined3, joined3$age))$age[2])) + + joined4 <- select(join(df, df2, df$name == df2$name, "outer"), + alias(df$age + 5, "newAge"), df$name, df2$test) + expect_equal(names(joined4), c("newAge", "name", "test")) + expect_true(count(joined4) == 4) + expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24) +}) + +test_that("toJSON() returns an RDD of the correct values", { + df <- jsonFile(sqlCtx, jsonPath) + testRDD <- toJSON(df) + expect_true(inherits(testRDD, "RDD")) + expect_true(SparkR:::getSerializedMode(testRDD) == "string") + expect_equal(collect(testRDD)[[1]], mockLines[1]) +}) + +test_that("showDF()", { + df <- jsonFile(sqlCtx, jsonPath) + expect_output(showDF(df), "age name \nnull Michael\n30 Andy \n19 Justin ") +}) + +test_that("isLocal()", { + df <- jsonFile(sqlCtx, jsonPath) + expect_false(isLocal(df)) +}) + +test_that("unionAll(), subtract(), and intersect() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + + lines <- c("{\"name\":\"Bob\", \"age\":24}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"James\", \"age\":35}") + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPath2) + df2 <- loadDF(sqlCtx, jsonPath2, "json") + + unioned <- sortDF(unionAll(df, df2), df$age) + expect_true(inherits(unioned, "DataFrame")) + 
expect_true(count(unioned) == 6) + expect_true(first(unioned)$name == "Michael") + + subtracted <- sortDF(subtract(df, df2), desc(df$age)) + expect_true(inherits(unioned, "DataFrame")) + expect_true(count(subtracted) == 2) + expect_true(first(subtracted)$name == "Justin") + + intersected <- sortDF(intersect(df, df2), df$age) + expect_true(inherits(unioned, "DataFrame")) + expect_true(count(intersected) == 1) + expect_true(first(intersected)$name == "Andy") +}) + +test_that("withColumn() and withColumnRenamed()", { + df <- jsonFile(sqlCtx, jsonPath) + newDF <- withColumn(df, "newAge", df$age + 2) + expect_true(length(columns(newDF)) == 3) + expect_true(columns(newDF)[3] == "newAge") + expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32) + + newDF2 <- withColumnRenamed(df, "age", "newerAge") + expect_true(length(columns(newDF2)) == 2) + expect_true(columns(newDF2)[1] == "newerAge") +}) + +test_that("saveDF() on DataFrame and works with parquetFile", { + df <- jsonFile(sqlCtx, jsonPath) + saveDF(df, parquetPath, "parquet", mode="overwrite") + parquetDF <- parquetFile(sqlCtx, parquetPath) + expect_true(inherits(parquetDF, "DataFrame")) + expect_equal(count(df), count(parquetDF)) +}) + +test_that("parquetFile works with multiple input paths", { + df <- jsonFile(sqlCtx, jsonPath) + saveDF(df, parquetPath, "parquet", mode="overwrite") + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + saveDF(df, parquetPath2, "parquet", mode="overwrite") + parquetDF <- parquetFile(sqlCtx, parquetPath, parquetPath2) + expect_true(inherits(parquetDF, "DataFrame")) + expect_true(count(parquetDF) == count(df)*2) +}) + +unlink(parquetPath) +unlink(jsonPath) diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/test_take.R new file mode 100644 index 0000000000000..7f4c7c315d787 --- /dev/null +++ b/R/pkg/inst/tests/test_take.R @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("tests RDD function take()") + +# Mock data +numVector <- c(-10:97) +numList <- list(sqrt(1), sqrt(2), sqrt(3), 4 ** 10) +strVector <- c("Dexter Morgan: I suppose I should be upset, even feel", + "violated, but I'm not. No, in fact, I think this is a friendly", + "message, like \"Hey, wanna play?\" and yes, I want to play. ", + "I really, really do.") +strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ", + "other times it helps me control the chaos.", + "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ", + "raising me. But they're both dead now. I didn't kill them. 
Honest.") + +# JavaSparkContext handle +jsc <- sparkR.init() + +test_that("take() gives back the original elements in correct count and order", { + numVectorRDD <- parallelize(jsc, numVector, 10) + # case: number of elements to take is less than the size of the first partition + expect_equal(take(numVectorRDD, 1), as.list(head(numVector, n = 1))) + # case: number of elements to take is the same as the size of the first partition + expect_equal(take(numVectorRDD, 11), as.list(head(numVector, n = 11))) + # case: number of elements to take is greater than all elements + expect_equal(take(numVectorRDD, length(numVector)), as.list(numVector)) + expect_equal(take(numVectorRDD, length(numVector) + 1), as.list(numVector)) + + numListRDD <- parallelize(jsc, numList, 1) + numListRDD2 <- parallelize(jsc, numList, 4) + expect_equal(take(numListRDD, 3), take(numListRDD2, 3)) + expect_equal(take(numListRDD, 5), take(numListRDD2, 5)) + expect_equal(take(numListRDD, 1), as.list(head(numList, n = 1))) + expect_equal(take(numListRDD2, 999), numList) + + strVectorRDD <- parallelize(jsc, strVector, 2) + strVectorRDD2 <- parallelize(jsc, strVector, 3) + expect_equal(take(strVectorRDD, 4), as.list(strVector)) + expect_equal(take(strVectorRDD2, 2), as.list(head(strVector, n = 2))) + + strListRDD <- parallelize(jsc, strList, 4) + strListRDD2 <- parallelize(jsc, strList, 1) + expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3))) + expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1))) + + expect_true(length(take(strListRDD, 0)) == 0) + expect_true(length(take(strVectorRDD, 0)) == 0) + expect_true(length(take(numListRDD, 0)) == 0) + expect_true(length(take(numVectorRDD, 0)) == 0) +}) + diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R new file mode 100644 index 0000000000000..6b87b4b3e0b08 --- /dev/null +++ b/R/pkg/inst/tests/test_textFile.R @@ -0,0 +1,162 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +context("the textFile() function") + +# JavaSparkContext handle +sc <- sparkR.init() + +mockFile = c("Spark is pretty.", "Spark is awesome.") + +test_that("textFile() on a local file returns an RDD", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + expect_true(inherits(rdd, "RDD")) + expect_true(count(rdd) > 0) + expect_true(count(rdd) == 2) + + unlink(fileName) +}) + +test_that("textFile() followed by a collect() returns the same content", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + expect_equal(collect(rdd), as.list(mockFile)) + + unlink(fileName) +}) + +test_that("textFile() word count works as expected", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + + words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + counts <- reduceByKey(wordCount, "+", 2L) + output <- collect(counts) + expected <- list(list("pretty.", 1), list("is", 2), list("awesome.", 1), + list("Spark", 2)) + expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) + + unlink(fileName) +}) + +test_that("several transformations on RDD created by textFile()", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) # RDD + for (i in 1:10) { + # PipelinedRDD initially created from RDD + rdd <- lapply(rdd, function(x) paste(x, x)) + } + collect(rdd) + + unlink(fileName) +}) + +test_that("textFile() followed by a saveAsTextFile() returns the same content", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1, 1L) + saveAsTextFile(rdd, fileName2) + rdd <- textFile(sc, fileName2) + expect_equal(collect(rdd), as.list(mockFile)) + + unlink(fileName1) + unlink(fileName2) +}) + +test_that("saveAsTextFile() on a parallelized list works as expected", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + l <- list(1, 2, 3) + rdd <- parallelize(sc, l, 1L) + saveAsTextFile(rdd, fileName) + rdd <- textFile(sc, fileName) + expect_equal(collect(rdd), lapply(l, function(x) {toString(x)})) + + unlink(fileName) +}) + +test_that("textFile() and saveAsTextFile() word count works as expected", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1) + + words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + counts <- reduceByKey(wordCount, "+", 2L) + + saveAsTextFile(counts, fileName2) + rdd <- textFile(sc, fileName2) + + output <- collect(rdd) + expected <- list(list("awesome.", 1), list("Spark", 2), + list("pretty.", 1), list("is", 2)) + expectedStr <- lapply(expected, function(x) { toString(x) }) + expect_equal(sortKeyValueList(output), sortKeyValueList(expectedStr)) + + unlink(fileName1) + unlink(fileName2) +}) + +test_that("textFile() on multiple paths", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines("Spark is pretty.", fileName1) + writeLines("Spark is awesome.", fileName2) 
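+  # Each mock file above holds a single line, so the RDD built from both paths below should have two elements.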
+ + rdd <- textFile(sc, c(fileName1, fileName2)) + expect_true(count(rdd) == 2) + + unlink(fileName1) + unlink(fileName2) +}) + +test_that("Pipelined operations on RDDs created using textFile", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + + lengths <- lapply(rdd, function(x) { length(x) }) + expect_equal(collect(lengths), list(1, 1)) + + lengthsPipelined <- lapply(lengths, function(x) { x + 10 }) + expect_equal(collect(lengthsPipelined), list(11, 11)) + + lengths30 <- lapply(lengthsPipelined, function(x) { x + 20 }) + expect_equal(collect(lengths30), list(31, 31)) + + lengths20 <- lapply(lengths, function(x) { x + 20 }) + expect_equal(collect(lengths20), list(21, 21)) + + unlink(fileName) +}) + diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R new file mode 100644 index 0000000000000..9c5bb427932b4 --- /dev/null +++ b/R/pkg/inst/tests/test_utils.R @@ -0,0 +1,137 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions in utils.R") + +# JavaSparkContext handle +sc <- sparkR.init() + +test_that("convertJListToRList() gives back (deserializes) the original JLists + of strings and integers", { + # It's hard to manually create a Java List using rJava, since it does not + # support generics well. Instead, we rely on collect() returning a + # JList. + nums <- as.list(1:10) + rdd <- parallelize(sc, nums, 1L) + jList <- callJMethod(rdd@jrdd, "collect") + rList <- convertJListToRList(jList, flatten = TRUE) + expect_equal(rList, nums) + + strs <- as.list("hello", "spark") + rdd <- parallelize(sc, strs, 2L) + jList <- callJMethod(rdd@jrdd, "collect") + rList <- convertJListToRList(jList, flatten = TRUE) + expect_equal(rList, strs) +}) + +test_that("serializeToBytes on RDD", { + # File content + mockFile <- c("Spark is pretty.", "Spark is awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + text.rdd <- textFile(sc, fileName) + expect_true(getSerializedMode(text.rdd) == "string") + ser.rdd <- serializeToBytes(text.rdd) + expect_equal(collect(ser.rdd), as.list(mockFile)) + expect_true(getSerializedMode(ser.rdd) == "byte") + + unlink(fileName) +}) + +test_that("cleanClosure on R functions", { + y <- c(1, 2, 3) + g <- function(x) { x + 1 } + f <- function(x) { g(x) + y } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 2) # y, g + actual <- get("y", envir = env, inherits = FALSE) + expect_equal(actual, y) + actual <- get("g", envir = env, inherits = FALSE) + expect_equal(actual, g) + + # Test for nested enclosures and package variables. 
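+  # cleanClosure(f) should capture only the free variables "y" and "g" through the chain funcEnv -> env2 -> .GlobalEnv, while base functions such as log are left out; the assertions below verify this.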
+ env2 <- new.env() + funcEnv <- new.env(parent = env2) + f <- function(x) { log(g(x) + y) } + environment(f) <- funcEnv # enclosing relationship: f -> funcEnv -> env2 -> .GlobalEnv + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 2) # "min" should not be included + actual <- get("y", envir = env, inherits = FALSE) + expect_equal(actual, y) + actual <- get("g", envir = env, inherits = FALSE) + expect_equal(actual, g) + + base <- c(1, 2, 3) + l <- list(field = matrix(1)) + field <- matrix(2) + defUse <- 3 + g <- function(x) { x + y } + f <- function(x) { + defUse <- base::as.integer(x) + 1 # Test for access operators `::`. + lapply(x, g) + 1 # Test for capturing function call "g"'s closure as a argument of lapply. + l$field[1,1] <- 3 # Test for access operators `$`. + res <- defUse + l$field[1,] # Test for def-use chain of "defUse", and "" symbol. + f(res) # Test for recursive calls. + } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 3) # Only "g", "l" and "f". No "base", "field" or "defUse". + expect_true("g" %in% ls(env)) + expect_true("l" %in% ls(env)) + expect_true("f" %in% ls(env)) + expect_equal(get("l", envir = env, inherits = FALSE), l) + # "y" should be in the environemnt of g. + newG <- get("g", envir = env, inherits = FALSE) + env <- environment(newG) + expect_equal(length(ls(env)), 1) + actual <- get("y", envir = env, inherits = FALSE) + expect_equal(actual, y) + + # Test for function (and variable) definitions. + f <- function(x) { + g <- function(y) { y * 2 } + g(x) + } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 0) # "y" and "g" should not be included. + + # Test for overriding variables in base namespace (Issue: SparkR-196). + nums <- as.list(1:10) + rdd <- parallelize(sc, nums, 2L) + t = 4 # Override base::t in .GlobalEnv. + f <- function(x) { x > t } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(ls(env), "t") + expect_equal(get("t", envir = env, inherits = FALSE), t) + actual <- collect(lapply(rdd, f)) + expected <- as.list(c(rep(FALSE, 4), rep(TRUE, 6))) + expect_equal(actual, expected) + + # Test for broadcast variables. + a <- matrix(nrow=10, ncol=10, data=rnorm(100)) + aBroadcast <- broadcast(sc, a) + normMultiply <- function(x) { norm(aBroadcast$value) * x } + newnormMultiply <- SparkR:::cleanClosure(normMultiply) + env <- environment(newnormMultiply) + expect_equal(ls(env), "aBroadcast") + expect_equal(get("aBroadcast", envir = env, inherits = FALSE), aBroadcast) +}) diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R new file mode 100644 index 0000000000000..3584b418a71a9 --- /dev/null +++ b/R/pkg/inst/worker/daemon.R @@ -0,0 +1,52 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# Worker daemon + +rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +script <- paste(rLibDir, "SparkR/worker/worker.R", sep = "/") + +# preload SparkR package, speedup worker +.libPaths(c(rLibDir, .libPaths())) +suppressPackageStartupMessages(library(SparkR)) + +port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) +inputCon <- socketConnection(port = port, open = "rb", blocking = TRUE, timeout = 3600) + +while (TRUE) { + ready <- socketSelect(list(inputCon)) + if (ready) { + port <- SparkR:::readInt(inputCon) + # There is a small chance that it could be interrupted by signal, retry one time + if (length(port) == 0) { + port <- SparkR:::readInt(inputCon) + if (length(port) == 0) { + cat("quitting daemon\n") + quit(save = "no") + } + } + p <- parallel:::mcfork() + if (inherits(p, "masterProcess")) { + close(inputCon) + Sys.setenv(SPARKR_WORKER_PORT = port) + source(script) + # Set SIGUSR1 so that child can exit + tools::pskill(Sys.getpid(), tools::SIGUSR1) + parallel:::mcexit(0L) + } + } +} diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R new file mode 100644 index 0000000000000..c6542928e8ddd --- /dev/null +++ b/R/pkg/inst/worker/worker.R @@ -0,0 +1,128 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Worker class + +rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +# Set libPaths to include SparkR package as loadNamespace needs this +# TODO: Figure out if we can avoid this by not loading any objects that require +# SparkR namespace +.libPaths(c(rLibDir, .libPaths())) +suppressPackageStartupMessages(library(SparkR)) + +port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) +inputCon <- socketConnection(port = port, blocking = TRUE, open = "rb") +outputCon <- socketConnection(port = port, blocking = TRUE, open = "wb") + +# read the index of the current partition inside the RDD +partition <- SparkR:::readInt(inputCon) + +deserializer <- SparkR:::readString(inputCon) +serializer <- SparkR:::readString(inputCon) + +# Include packages as required +packageNames <- unserialize(SparkR:::readRaw(inputCon)) +for (pkg in packageNames) { + suppressPackageStartupMessages(require(as.character(pkg), character.only=TRUE)) +} + +# read function dependencies +funcLen <- SparkR:::readInt(inputCon) +computeFunc <- unserialize(SparkR:::readRawLen(inputCon, funcLen)) +env <- environment(computeFunc) +parent.env(env) <- .GlobalEnv # Attach under global environment. + +# Read and set broadcast variables +numBroadcastVars <- SparkR:::readInt(inputCon) +if (numBroadcastVars > 0) { + for (bcast in seq(1:numBroadcastVars)) { + bcastId <- SparkR:::readInt(inputCon) + value <- unserialize(SparkR:::readRaw(inputCon)) + setBroadcastValue(bcastId, value) + } +} + +# If -1: read as normal RDD; if >= 0, treat as pairwise RDD and treat the int +# as number of partitions to create. 
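+# For example, with numPartitions = 2, a tuple whose key hashes to 5 is placed in bucket "1" (5 %% 2) by the hashing step below.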
+numPartitions <- SparkR:::readInt(inputCon) + +isEmpty <- SparkR:::readInt(inputCon) + +if (isEmpty != 0) { + + if (numPartitions == -1) { + if (deserializer == "byte") { + # Now read as many characters as described in funcLen + data <- SparkR:::readDeserialize(inputCon) + } else if (deserializer == "string") { + data <- as.list(readLines(inputCon)) + } else if (deserializer == "row") { + data <- SparkR:::readDeserializeRows(inputCon) + } + output <- computeFunc(partition, data) + if (serializer == "byte") { + SparkR:::writeRawSerialize(outputCon, output) + } else if (serializer == "row") { + SparkR:::writeRowSerialize(outputCon, output) + } else { + SparkR:::writeStrings(outputCon, output) + } + } else { + if (deserializer == "byte") { + # Now read as many characters as described in funcLen + data <- SparkR:::readDeserialize(inputCon) + } else if (deserializer == "string") { + data <- readLines(inputCon) + } else if (deserializer == "row") { + data <- SparkR:::readDeserializeRows(inputCon) + } + + res <- new.env() + + # Step 1: hash the data to an environment + hashTupleToEnvir <- function(tuple) { + # NOTE: execFunction is the hash function here + hashVal <- computeFunc(tuple[[1]]) + bucket <- as.character(hashVal %% numPartitions) + acc <- res[[bucket]] + # Create a new accumulator + if (is.null(acc)) { + acc <- SparkR:::initAccumulator() + } + SparkR:::addItemToAccumulator(acc, tuple) + res[[bucket]] <- acc + } + invisible(lapply(data, hashTupleToEnvir)) + + # Step 2: write out all of the environment as key-value pairs. + for (name in ls(res)) { + SparkR:::writeInt(outputCon, 2L) + SparkR:::writeInt(outputCon, as.integer(name)) + # Truncate the accumulator list to the number of elements we have + length(res[[name]]$data) <- res[[name]]$counter + SparkR:::writeRawSerialize(outputCon, res[[name]]$data) + } + } +} + +# End of output +if (serializer %in% c("byte", "row")) { + SparkR:::writeInt(outputCon, 0L) +} + +close(outputCon) +close(inputCon) diff --git a/R/pkg/src/Makefile b/R/pkg/src/Makefile new file mode 100644 index 0000000000000..a55a56fe80e10 --- /dev/null +++ b/R/pkg/src/Makefile @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +all: sharelib + +sharelib: string_hash_code.c + R CMD SHLIB -o SparkR.so string_hash_code.c + +clean: + rm -f *.o + rm -f *.so + +.PHONY: all clean diff --git a/R/pkg/src/Makefile.win b/R/pkg/src/Makefile.win new file mode 100644 index 0000000000000..aa486d8228371 --- /dev/null +++ b/R/pkg/src/Makefile.win @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +all: sharelib + +sharelib: string_hash_code.c + R CMD SHLIB -o SparkR.dll string_hash_code.c + +clean: + rm -f *.o + rm -f *.dll + +.PHONY: all clean diff --git a/R/pkg/src/string_hash_code.c b/R/pkg/src/string_hash_code.c new file mode 100644 index 0000000000000..e3274b9a0c547 --- /dev/null +++ b/R/pkg/src/string_hash_code.c @@ -0,0 +1,49 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +/* + * A C function for R extension which implements the Java String hash algorithm. + * Refer to http://en.wikipedia.org/wiki/Java_hashCode%28%29#The_java.lang.String_hash_function + * + */ + +#include +#include + +/* for compatibility with R before 3.1 */ +#ifndef IS_SCALAR +#define IS_SCALAR(x, type) (TYPEOF(x) == (type) && XLENGTH(x) == 1) +#endif + +SEXP stringHashCode(SEXP string) { + const char* str; + R_xlen_t len, i; + int hashCode = 0; + + if (!IS_SCALAR(string, STRSXP)) { + error("invalid input"); + } + + str = CHAR(asChar(string)); + len = XLENGTH(asChar(string)); + + for (i = 0; i < len; i++) { + hashCode = (hashCode << 5) - hashCode + *str++; + } + + return ScalarInteger(hashCode); +} diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R new file mode 100644 index 0000000000000..4f8a1ed2d83ef --- /dev/null +++ b/R/pkg/tests/run-all.R @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +library(testthat) +library(SparkR) + +test_package("SparkR") diff --git a/R/run-tests.sh b/R/run-tests.sh new file mode 100755 index 0000000000000..e82ad0ba2cd06 --- /dev/null +++ b/R/run-tests.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +FWDIR="$(cd `dirname $0`; pwd)" + +FAILED=0 +LOGFILE=$FWDIR/unit-tests.out +rm -f $LOGFILE + +SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +FAILED=$((PIPESTATUS[0]||$FAILED)) + +if [[ $FAILED != 0 ]]; then + cat $LOGFILE + echo -en "\033[31m" # Red + echo "Had test failures; see logs." + echo -en "\033[0m" # No color + exit -1 +else + echo -en "\033[32m" # Green + echo "Tests passed." + echo -en "\033[0m" # No color +fi diff --git a/bagel/src/test/resources/log4j.properties b/bagel/src/test/resources/log4j.properties index 853ef0ed2986f..edbecdae92096 100644 --- a/bagel/src/test/resources/log4j.properties +++ b/bagel/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/bin/load-spark-env.cmd b/bin/load-spark-env.cmd new file mode 100644 index 0000000000000..36d932c453b6f --- /dev/null +++ b/bin/load-spark-env.cmd @@ -0,0 +1,59 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem This script loads spark-env.cmd if it exists, and ensures it is only loaded once. +rem spark-env.cmd is loaded from SPARK_CONF_DIR if set, or within the current directory's +rem conf/ subdirectory. 
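+rem For example, if SPARK_CONF_DIR is set to C:\my-spark-conf (an illustrative path), spark-env.cmd is loaded from that directory when it exists there; otherwise the script falls back to %~dp0..\..\conf.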
+ +if [%SPARK_ENV_LOADED%] == [] ( + set SPARK_ENV_LOADED=1 + + if not [%SPARK_CONF_DIR%] == [] ( + set user_conf_dir=%SPARK_CONF_DIR% + ) else ( + set user_conf_dir=%~dp0..\..\conf + ) + + call :LoadSparkEnv +) + +rem Setting SPARK_SCALA_VERSION if not already set. + +set ASSEMBLY_DIR2=%SPARK_HOME%/assembly/target/scala-2.11 +set ASSEMBLY_DIR1=%SPARK_HOME%/assembly/target/scala-2.10 + +if [%SPARK_SCALA_VERSION%] == [] ( + + if exist %ASSEMBLY_DIR2% if exist %ASSEMBLY_DIR1% ( + echo "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." + echo "Either clean one of them or, set SPARK_SCALA_VERSION=2.11 in spark-env.cmd." + exit 1 + ) + if exist %ASSEMBLY_DIR2% ( + set SPARK_SCALA_VERSION=2.11 + ) else ( + set SPARK_SCALA_VERSION=2.10 + ) +) +exit /b 0 + +:LoadSparkEnv +if exist "%user_conf_dir%\spark-env.cmd" ( + call "%user_conf_dir%\spark-env.cmd" +) diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index 2d7070c25d328..95779e9ddbb18 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -20,6 +20,7 @@ # This script loads spark-env.sh if it exists, and ensures it is only loaded once. # spark-env.sh is loaded from SPARK_CONF_DIR if set, or within the current directory's # conf/ subdirectory. +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" if [ -z "$SPARK_ENV_LOADED" ]; then export SPARK_ENV_LOADED=1 @@ -41,8 +42,8 @@ fi if [ -z "$SPARK_SCALA_VERSION" ]; then - ASSEMBLY_DIR2="$SPARK_HOME/assembly/target/scala-2.11" - ASSEMBLY_DIR1="$SPARK_HOME/assembly/target/scala-2.10" + ASSEMBLY_DIR2="$FWDIR/assembly/target/scala-2.11" + ASSEMBLY_DIR1="$FWDIR/assembly/target/scala-2.10" if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then echo -e "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." 1>&2 diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 4f5eb5e20614d..09b4149c2a439 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -20,8 +20,7 @@ rem rem Figure out where the Spark framework is installed set SPARK_HOME=%~dp0.. -rem Load environment variables from conf\spark-env.cmd, if it exists -if exist "%SPARK_HOME%\conf\spark-env.cmd" call "%SPARK_HOME%\conf\spark-env.cmd" +call %SPARK_HOME%\bin\load-spark-env.cmd rem Figure out which Python to use. if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( diff --git a/bin/run-example2.cmd b/bin/run-example2.cmd index b49d0dcb4ff2d..c3e0221fb62e3 100644 --- a/bin/run-example2.cmd +++ b/bin/run-example2.cmd @@ -25,8 +25,7 @@ set FWDIR=%~dp0..\ rem Export this as SPARK_HOME set SPARK_HOME=%FWDIR% -rem Load environment variables from conf\spark-env.cmd, if it exists -if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" +call %SPARK_HOME%\bin\load-spark-env.cmd rem Test that an argument was given if not "x%1"=="x" goto arg_given diff --git a/bin/spark-class b/bin/spark-class index c03946d92e2e4..c49d97ce5cf25 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -82,13 +82,22 @@ if [ $(command -v "$JAR_CMD") ] ; then fi fi +LAUNCH_CLASSPATH="$SPARK_ASSEMBLY_JAR" + +# Add the launcher build dir to the classpath if requested. +if [ -n "$SPARK_PREPEND_CLASSES" ]; then + LAUNCH_CLASSPATH="$SPARK_HOME/launcher/target/scala-$SPARK_SCALA_VERSION/classes:$LAUNCH_CLASSPATH" +fi + +export _SPARK_ASSEMBLY="$SPARK_ASSEMBLY_JAR" + # The launcher library will print arguments separated by a NULL character, to allow arguments with # characters that would be otherwise interpreted by the shell. Read that in a while loop, populating # an array that will be used to exec the final command. 
CMD=() while IFS= read -d '' -r ARG; do CMD+=("$ARG") -done < <("$RUNNER" -cp "$SPARK_ASSEMBLY_JAR" org.apache.spark.launcher.Main "$@") +done < <("$RUNNER" -cp "$LAUNCH_CLASSPATH" org.apache.spark.launcher.Main "$@") if [ "${CMD[0]}" = "usage" ]; then "${CMD[@]}" diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 4ce727bc99128..3d068dd3a2739 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -20,8 +20,7 @@ rem rem Figure out where the Spark framework is installed set SPARK_HOME=%~dp0.. -rem Load environment variables from conf\spark-env.cmd, if it exists -if exist "%SPARK_HOME%\conf\spark-env.cmd" call "%SPARK_HOME%\conf\spark-env.cmd" +call %SPARK_HOME%\bin\load-spark-env.cmd rem Test that an argument was given if "x%1"=="x" ( @@ -47,13 +46,22 @@ if "%SPARK_ASSEMBLY_JAR%"=="0" ( exit /b 1 ) +set LAUNCH_CLASSPATH=%SPARK_ASSEMBLY_JAR% + +rem Add the launcher build dir to the classpath if requested. +if not "x%SPARK_PREPEND_CLASSES%"=="x" ( + set LAUNCH_CLASSPATH=%SPARK_HOME%\launcher\target\scala-%SPARK_SCALA_VERSION%\classes;%LAUNCH_CLASSPATH% +) + +set _SPARK_ASSEMBLY=%SPARK_ASSEMBLY_JAR% + rem Figure out where java is. set RUNNER=java if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java rem The launcher library prints the command to be executed in a single line suitable for being rem executed by the batch interpreter. So read all the output of the launcher into a variable. -for /f "tokens=*" %%i in ('cmd /C ""%RUNNER%" -cp %SPARK_ASSEMBLY_JAR% org.apache.spark.launcher.Main %*"') do ( +for /f "tokens=*" %%i in ('cmd /C ""%RUNNER%" -cp %LAUNCH_CLASSPATH% org.apache.spark.launcher.Main %*"') do ( set SPARK_CMD=%%i ) %SPARK_CMD% diff --git a/bin/sparkR b/bin/sparkR new file mode 100755 index 0000000000000..8c918e2b09aef --- /dev/null +++ b/bin/sparkR @@ -0,0 +1,39 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Figure out where Spark is installed +export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + +source "$SPARK_HOME"/bin/load-spark-env.sh + +function usage() { + if [ -n "$1" ]; then + echo $1 + fi + echo "Usage: ./bin/sparkR [options]" 1>&2 + "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit $2 +} +export -f usage + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + usage +fi + +exec "$SPARK_HOME"/bin/spark-submit sparkr-shell-main "$@" diff --git a/bin/sparkR.cmd b/bin/sparkR.cmd new file mode 100644 index 0000000000000..d7b60183ca8e0 --- /dev/null +++ b/bin/sparkR.cmd @@ -0,0 +1,23 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. 
+rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem This is the entry point for running SparkR. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. + +cmd /V /E /C %~dp0sparkR2.cmd %* diff --git a/bin/sparkR2.cmd b/bin/sparkR2.cmd new file mode 100644 index 0000000000000..e47f22c7300bb --- /dev/null +++ b/bin/sparkR2.cmd @@ -0,0 +1,26 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem Figure out where the Spark framework is installed +set SPARK_HOME=%~dp0.. 
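+rem %~dp0 expands to the directory containing this script (bin\), so SPARK_HOME resolves to its parent directory.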
+ +call %SPARK_HOME%\bin\load-spark-env.cmd + + +call %SPARK_HOME%\bin\spark-submit2.cmd sparkr-shell-main %* diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index 89eec7d4b7f61..3a2a88219818f 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -6,7 +6,7 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n # Settings to quiet third party logs that are too verbose -log4j.logger.org.eclipse.jetty=WARN -log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/core/pom.xml b/core/pom.xml index 6cd1965ec37c2..e80829b7a7f3d 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -442,4 +442,55 @@ + + + Windows + + + Windows + + + + \ + .bat + + + + unix + + + unix + + + + / + .sh + + + + sparkr + + + + org.codehaus.mojo + exec-maven-plugin + 1.3.2 + + + sparkr-pkg + compile + + exec + + + + + ..${path.separator}R${path.separator}install-dev${script.extension} + + + + + + + diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index 89eec7d4b7f61..3a2a88219818f 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -6,7 +6,7 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n # Settings to quiet third party logs that are too verbose -log4j.logger.org.eclipse.jetty=WARN -log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js index 14ba37d7c9bd9..013db8df9b363 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js +++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js @@ -30,7 +30,7 @@ $(function() { stripeSummaryTable(); - $("input:checkbox").click(function() { + $('input[type="checkbox"]').click(function() { var column = "table ." + $(this).attr("name"); $(column).toggle(); stripeSummaryTable(); @@ -39,15 +39,15 @@ $(function() { $("#select-all-metrics").click(function() { if (this.checked) { // Toggle all un-checked options. - $('input:checkbox:not(:checked)').trigger('click'); + $('input[type="checkbox"]:not(:checked)').trigger('click'); } else { // Toggle all checked options. - $('input:checkbox:checked').trigger('click'); + $('input[type="checkbox"]:checked').trigger('click'); } }); // Trigger a click on the checkbox if a user clicks the label next to it. 
$("span.additional-metric-title").click(function() { - $(this).parent().find('input:checkbox').trigger('click'); + $(this).parent().find('input[type="checkbox"]').trigger('click'); }); }); diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 9b05c9623b704..715b259057569 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -22,7 +22,7 @@ import java.lang.ref.{ReferenceQueue, WeakReference} import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{RDDCheckpointData, RDD} import org.apache.spark.util.Utils /** @@ -33,6 +33,7 @@ private case class CleanRDD(rddId: Int) extends CleanupTask private case class CleanShuffle(shuffleId: Int) extends CleanupTask private case class CleanBroadcast(broadcastId: Long) extends CleanupTask private case class CleanAccum(accId: Long) extends CleanupTask +private case class CleanCheckpoint(rddId: Int) extends CleanupTask /** * A WeakReference associated with a CleanupTask. @@ -94,12 +95,12 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { @volatile private var stopped = false /** Attach a listener object to get information of when objects are cleaned. */ - def attachListener(listener: CleanerListener) { + def attachListener(listener: CleanerListener): Unit = { listeners += listener } /** Start the cleaner. */ - def start() { + def start(): Unit = { cleaningThread.setDaemon(true) cleaningThread.setName("Spark Context Cleaner") cleaningThread.start() @@ -108,7 +109,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { /** * Stop the cleaning thread and wait until the thread has finished running its current task. */ - def stop() { + def stop(): Unit = { stopped = true // Interrupt the cleaning thread, but wait until the current task has finished before // doing so. This guards against the race condition where a cleaning thread may @@ -121,7 +122,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Register a RDD for cleanup when it is garbage collected. */ - def registerRDDForCleanup(rdd: RDD[_]) { + def registerRDDForCleanup(rdd: RDD[_]): Unit = { registerForCleanup(rdd, CleanRDD(rdd.id)) } @@ -130,17 +131,22 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Register a ShuffleDependency for cleanup when it is garbage collected. */ - def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]) { + def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]): Unit = { registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId)) } /** Register a Broadcast for cleanup when it is garbage collected. */ - def registerBroadcastForCleanup[T](broadcast: Broadcast[T]) { + def registerBroadcastForCleanup[T](broadcast: Broadcast[T]): Unit = { registerForCleanup(broadcast, CleanBroadcast(broadcast.id)) } + /** Register a RDDCheckpointData for cleanup when it is garbage collected. */ + def registerRDDCheckpointDataForCleanup[T](rdd: RDD[_], parentId: Int): Unit = { + registerForCleanup(rdd, CleanCheckpoint(parentId)) + } + /** Register an object for cleanup. 
*/ - private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) { + private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = { referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue) } @@ -164,6 +170,8 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks) case CleanAccum(accId) => doCleanupAccum(accId, blocking = blockOnCleanupTasks) + case CleanCheckpoint(rddId) => + doCleanCheckpoint(rddId) } } } @@ -175,7 +183,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Perform RDD cleanup. */ - def doCleanupRDD(rddId: Int, blocking: Boolean) { + def doCleanupRDD(rddId: Int, blocking: Boolean): Unit = { try { logDebug("Cleaning RDD " + rddId) sc.unpersistRDD(rddId, blocking) @@ -187,7 +195,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Perform shuffle cleanup, asynchronously. */ - def doCleanupShuffle(shuffleId: Int, blocking: Boolean) { + def doCleanupShuffle(shuffleId: Int, blocking: Boolean): Unit = { try { logDebug("Cleaning shuffle " + shuffleId) mapOutputTrackerMaster.unregisterShuffle(shuffleId) @@ -200,7 +208,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Perform broadcast cleanup. */ - def doCleanupBroadcast(broadcastId: Long, blocking: Boolean) { + def doCleanupBroadcast(broadcastId: Long, blocking: Boolean): Unit = { try { logDebug(s"Cleaning broadcast $broadcastId") broadcastManager.unbroadcast(broadcastId, true, blocking) @@ -212,7 +220,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Perform accumulator cleanup. */ - def doCleanupAccum(accId: Long, blocking: Boolean) { + def doCleanupAccum(accId: Long, blocking: Boolean): Unit = { try { logDebug("Cleaning accumulator " + accId) Accumulators.remove(accId) @@ -223,6 +231,18 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } } + /** Perform checkpoint cleanup. */ + def doCleanCheckpoint(rddId: Int): Unit = { + try { + logDebug("Cleaning rdd checkpoint data " + rddId) + RDDCheckpointData.clearRDDCheckpointData(sc, rddId) + logInfo("Cleaned rdd checkpoint data " + rddId) + } + catch { + case e: Exception => logError("Error cleaning rdd checkpoint data " + rddId, e) + } + } + private def blockManagerMaster = sc.env.blockManager.master private def broadcastManager = sc.env.broadcastManager private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 21c6e6ffa6666..4e7bf51fc0622 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -17,10 +17,12 @@ package org.apache.spark +import java.util.concurrent.{Executors, TimeUnit} + import scala.collection.mutable import org.apache.spark.scheduler._ -import org.apache.spark.util.{SystemClock, Clock} +import org.apache.spark.util.{Clock, SystemClock, Utils} /** * An agent that dynamically allocates and removes executors based on the workload. 
@@ -78,16 +80,16 @@ private[spark] class ExecutorAllocationManager( Integer.MAX_VALUE) // How long there must be backlogged tasks for before an addition is triggered (seconds) - private val schedulerBacklogTimeout = conf.getLong( - "spark.dynamicAllocation.schedulerBacklogTimeout", 5) + private val schedulerBacklogTimeoutS = conf.getTimeAsSeconds( + "spark.dynamicAllocation.schedulerBacklogTimeout", "5s") - // Same as above, but used only after `schedulerBacklogTimeout` is exceeded - private val sustainedSchedulerBacklogTimeout = conf.getLong( - "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", schedulerBacklogTimeout) + // Same as above, but used only after `schedulerBacklogTimeoutS` is exceeded + private val sustainedSchedulerBacklogTimeoutS = conf.getTimeAsSeconds( + "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", s"${schedulerBacklogTimeoutS}s") // How long an executor must be idle for before it is removed (seconds) - private val executorIdleTimeout = conf.getLong( - "spark.dynamicAllocation.executorIdleTimeout", 600) + private val executorIdleTimeoutS = conf.getTimeAsSeconds( + "spark.dynamicAllocation.executorIdleTimeout", "600s") // During testing, the methods to actually kill and add executors are mocked out private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) @@ -129,6 +131,10 @@ private[spark] class ExecutorAllocationManager( // Listener for Spark events that impact the allocation policy private val listener = new ExecutorAllocationListener + // Executor that handles the scheduling task. + private val executor = Executors.newSingleThreadScheduledExecutor( + Utils.namedThreadFactory("spark-dynamic-executor-allocation")) + /** * Verify that the settings specified through the config are valid. * If not, throw an appropriate exception. @@ -144,14 +150,14 @@ private[spark] class ExecutorAllocationManager( throw new SparkException(s"spark.dynamicAllocation.minExecutors ($minNumExecutors) must " + s"be less than or equal to spark.dynamicAllocation.maxExecutors ($maxNumExecutors)!") } - if (schedulerBacklogTimeout <= 0) { + if (schedulerBacklogTimeoutS <= 0) { throw new SparkException("spark.dynamicAllocation.schedulerBacklogTimeout must be > 0!") } - if (sustainedSchedulerBacklogTimeout <= 0) { + if (sustainedSchedulerBacklogTimeoutS <= 0) { throw new SparkException( "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout must be > 0!") } - if (executorIdleTimeout <= 0) { + if (executorIdleTimeoutS <= 0) { throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be > 0!") } // Require external shuffle service for dynamic allocation @@ -173,32 +179,24 @@ private[spark] class ExecutorAllocationManager( } /** - * Register for scheduler callbacks to decide when to add and remove executors. + * Register for scheduler callbacks to decide when to add and remove executors, and start + * the scheduling task. */ def start(): Unit = { listenerBus.addListener(listener) - startPolling() + + val scheduleTask = new Runnable() { + override def run(): Unit = Utils.logUncaughtExceptions(schedule()) + } + executor.scheduleAtFixedRate(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS) } /** - * Start the main polling thread that keeps track of when to add and remove executors. + * Stop the allocation manager. 
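+   * This shuts down the single-threaded scheduling executor and waits up to 10 seconds for it to terminate.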
*/ - private def startPolling(): Unit = { - val t = new Thread { - override def run(): Unit = { - while (true) { - try { - schedule() - } catch { - case e: Exception => logError("Exception in dynamic executor allocation thread!", e) - } - Thread.sleep(intervalMillis) - } - } - } - t.setName("spark-dynamic-executor-allocation") - t.setDaemon(true) - t.start() + def stop(): Unit = { + executor.shutdown() + executor.awaitTermination(10, TimeUnit.SECONDS) } /** @@ -264,8 +262,8 @@ private[spark] class ExecutorAllocationManager( } else if (addTime != NOT_SET && now >= addTime) { val delta = addExecutors(maxNeeded) logDebug(s"Starting timer to add more executors (to " + - s"expire in $sustainedSchedulerBacklogTimeout seconds)") - addTime += sustainedSchedulerBacklogTimeout * 1000 + s"expire in $sustainedSchedulerBacklogTimeoutS seconds)") + addTime += sustainedSchedulerBacklogTimeoutS * 1000 delta } else { 0 @@ -353,7 +351,7 @@ private[spark] class ExecutorAllocationManager( val removeRequestAcknowledged = testing || client.killExecutor(executorId) if (removeRequestAcknowledged) { logInfo(s"Removing executor $executorId because it has been idle for " + - s"$executorIdleTimeout seconds (new desired total will be ${numExistingExecutors - 1})") + s"$executorIdleTimeoutS seconds (new desired total will be ${numExistingExecutors - 1})") executorsPendingToRemove.add(executorId) true } else { @@ -409,8 +407,8 @@ private[spark] class ExecutorAllocationManager( private def onSchedulerBacklogged(): Unit = synchronized { if (addTime == NOT_SET) { logDebug(s"Starting timer to add executors because pending tasks " + - s"are building up (to expire in $schedulerBacklogTimeout seconds)") - addTime = clock.getTimeMillis + schedulerBacklogTimeout * 1000 + s"are building up (to expire in $schedulerBacklogTimeoutS seconds)") + addTime = clock.getTimeMillis + schedulerBacklogTimeoutS * 1000 } } @@ -433,8 +431,8 @@ private[spark] class ExecutorAllocationManager( if (executorIds.contains(executorId)) { if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) { logDebug(s"Starting idle timer for $executorId because there are no more tasks " + - s"scheduled to run on the executor (to expire in $executorIdleTimeout seconds)") - removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeout * 1000 + s"scheduled to run on the executor (to expire in $executorIdleTimeoutS seconds)") + removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeoutS * 1000 } } else { logWarning(s"Attempted to mark unknown executor $executorId idle") diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 715f292f03469..e3bd16f1cbf24 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -17,15 +17,15 @@ package org.apache.spark -import scala.concurrent.duration._ -import scala.collection.mutable +import java.util.concurrent.{ScheduledFuture, TimeUnit, Executors} -import akka.actor.{Actor, Cancellable} +import scala.collection.mutable import org.apache.spark.executor.TaskMetrics +import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext} import org.apache.spark.storage.BlockManagerId import org.apache.spark.scheduler.{SlaveLost, TaskScheduler} -import org.apache.spark.util.ActorLogReceive +import org.apache.spark.util.Utils /** * A heartbeat from executors to the driver. 
This is a shared message used by several internal @@ -37,6 +37,12 @@ private[spark] case class Heartbeat( taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics blockManagerId: BlockManagerId) +/** + * An event that SparkContext uses to notify HeartbeatReceiver that SparkContext.taskScheduler is + * created. + */ +private[spark] case object TaskSchedulerIsSet + private[spark] case object ExpireDeadHosts private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) @@ -44,36 +50,68 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** * Lives in the driver to receive heartbeats from executors.. */ -private[spark] class HeartbeatReceiver(sc: SparkContext, scheduler: TaskScheduler) - extends Actor with ActorLogReceive with Logging { +private[spark] class HeartbeatReceiver(sc: SparkContext) + extends ThreadSafeRpcEndpoint with Logging { + + override val rpcEnv: RpcEnv = sc.env.rpcEnv + + private[spark] var scheduler: TaskScheduler = null // executor ID -> timestamp of when the last heartbeat from this executor was received private val executorLastSeen = new mutable.HashMap[String, Long] + + // "spark.network.timeout" uses "seconds", while `spark.storage.blockManagerSlaveTimeoutMs` uses + // "milliseconds" + private val slaveTimeoutMs = + sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", "120s") + private val executorTimeoutMs = + sc.conf.getTimeAsSeconds("spark.network.timeout", s"${slaveTimeoutMs}ms") * 1000 - private val executorTimeoutMs = sc.conf.getLong("spark.network.timeout", - sc.conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", 120)) * 1000 - - private val checkTimeoutIntervalMs = sc.conf.getLong("spark.network.timeoutInterval", - sc.conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60)) * 1000 - - private var timeoutCheckingTask: Cancellable = null + // "spark.network.timeoutInterval" uses "seconds", while + // "spark.storage.blockManagerTimeoutIntervalMs" uses "milliseconds" + private val timeoutIntervalMs = + sc.conf.getTimeAsMs("spark.storage.blockManagerTimeoutIntervalMs", "60s") + private val checkTimeoutIntervalMs = + sc.conf.getTimeAsSeconds("spark.network.timeoutInterval", s"${timeoutIntervalMs}ms") * 1000 - override def preStart(): Unit = { - import context.dispatcher - timeoutCheckingTask = context.system.scheduler.schedule(0.seconds, - checkTimeoutIntervalMs.milliseconds, self, ExpireDeadHosts) - super.preStart() + private var timeoutCheckingTask: ScheduledFuture[_] = null + + private val timeoutCheckingThread = Executors.newSingleThreadScheduledExecutor( + Utils.namedThreadFactory("heartbeat-timeout-checking-thread")) + + private val killExecutorThread = Executors.newSingleThreadExecutor( + Utils.namedThreadFactory("kill-executor-thread")) + + override def onStart(): Unit = { + timeoutCheckingTask = timeoutCheckingThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + Option(self).foreach(_.send(ExpireDeadHosts)) + } + }, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS) } - - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case Heartbeat(executorId, taskMetrics, blockManagerId) => - val unknownExecutor = !scheduler.executorHeartbeatReceived( - executorId, taskMetrics, blockManagerId) - val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) - executorLastSeen(executorId) = System.currentTimeMillis() - sender ! 
response + + override def receive: PartialFunction[Any, Unit] = { case ExpireDeadHosts => expireDeadHosts() + case TaskSchedulerIsSet => + scheduler = sc.taskScheduler + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) => + if (scheduler != null) { + val unknownExecutor = !scheduler.executorHeartbeatReceived( + executorId, taskMetrics, blockManagerId) + val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) + executorLastSeen(executorId) = System.currentTimeMillis() + context.reply(response) + } else { + // Because Executor will sleep several seconds before sending the first "Heartbeat", this + // case rarely happens. However, if it really happens, log it and ask the executor to + // register itself again. + logWarning(s"Dropping $heartbeat because TaskScheduler is not ready yet") + context.reply(HeartbeatResponse(reregisterBlockManager = true)) + } } private def expireDeadHosts(): Unit = { @@ -84,19 +122,27 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, scheduler: TaskSchedule logWarning(s"Removing executor $executorId with no recent heartbeats: " + s"${now - lastSeenMs} ms exceeds timeout $executorTimeoutMs ms") scheduler.executorLost(executorId, SlaveLost("Executor heartbeat " + - "timed out after ${now - lastSeenMs} ms")) + s"timed out after ${now - lastSeenMs} ms")) if (sc.supportDynamicAllocation) { - sc.killExecutor(executorId) + // Asynchronously kill the executor to avoid blocking the current thread + killExecutorThread.submit(new Runnable { + override def run(): Unit = sc.killExecutor(executorId) + }) } executorLastSeen.remove(executorId) } } } - override def postStop(): Unit = { + override def onStop(): Unit = { if (timeoutCheckingTask != null) { - timeoutCheckingTask.cancel() + timeoutCheckingTask.cancel(true) } - super.postStop() + timeoutCheckingThread.shutdownNow() + killExecutorThread.shutdownNow() } } + +object HeartbeatReceiver { + val ENDPOINT_NAME = "HeartbeatReceiver" +} diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index 09a9ccc226721..8de3a6c04df34 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -160,7 +160,7 @@ private[spark] class HttpServer( throw new ServerStateException("Server is not started") } else { val scheme = if (securityManager.fileServerSSLOptions.enabled) "https" else "http" - s"$scheme://${Utils.localIpAddress}:$port" + s"$scheme://${Utils.localHostNameForURI()}:$port" } } } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index c9426c5de23a2..d65c94e410662 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -21,13 +21,11 @@ import java.io._ import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.{HashSet, HashMap, Map} -import scala.concurrent.Await +import scala.collection.mutable.{HashSet, Map} import scala.collection.JavaConversions._ +import scala.reflect.ClassTag -import akka.actor._ -import akka.pattern.ask - +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.MetadataFetchFailedException 
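For reference, the bookkeeping that the new `receiveAndReply` and `expireDeadHosts` above maintain reduces to a last-seen timestamp per executor. A self-contained sketch with plain Scala collections (this is not the actual endpoint, just the idea):

```scala
import scala.collection.mutable

// Record a timestamp on every heartbeat; on each ExpireDeadHosts tick, report and forget
// the executors whose last heartbeat is older than the timeout.
class HeartbeatBook(executorTimeoutMs: Long) {
  private val executorLastSeen = new mutable.HashMap[String, Long]

  def recordHeartbeat(executorId: String, now: Long = System.currentTimeMillis()): Unit =
    executorLastSeen(executorId) = now

  def expireDeadHosts(now: Long = System.currentTimeMillis()): Seq[String] = {
    val expired = executorLastSeen.collect {
      case (executorId, lastSeenMs) if now - lastSeenMs > executorTimeoutMs => executorId
    }.toSeq
    expired.foreach(executorLastSeen.remove)
    expired // callers would mark these as lost and kill them asynchronously
  }
}
```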
import org.apache.spark.storage.BlockManagerId @@ -38,14 +36,15 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage -/** Actor class for MapOutputTrackerMaster */ -private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf) - extends Actor with ActorLogReceive with Logging { +/** RpcEndpoint class for MapOutputTrackerMaster */ +private[spark] class MapOutputTrackerMasterEndpoint( + override val rpcEnv: RpcEnv, tracker: MapOutputTrackerMaster, conf: SparkConf) + extends RpcEndpoint with Logging { val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case GetMapOutputStatuses(shuffleId: Int) => - val hostPort = sender.path.address.hostPort + val hostPort = context.sender.address.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId) val serializedSize = mapOutputStatuses.size @@ -53,19 +52,19 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster val msg = s"Map output statuses were $serializedSize bytes which " + s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)." - /* For SPARK-1244 we'll opt for just logging an error and then throwing an exception. - * Note that on exception the actor will just restart. A bigger refactoring (SPARK-1239) - * will ultimately remove this entire code path. */ + /* For SPARK-1244 we'll opt for just logging an error and then sending it to the sender. + * A bigger refactoring (SPARK-1239) will ultimately remove this entire code path. */ val exception = new SparkException(msg) logError(msg, exception) - throw exception + context.sendFailure(exception) + } else { + context.reply(mapOutputStatuses) } - sender ! mapOutputStatuses case StopMapOutputTracker => - logInfo("MapOutputTrackerActor stopped!") - sender ! true - context.stop(self) + logInfo("MapOutputTrackerMasterEndpoint stopped!") + context.reply(true) + stop() } } @@ -75,12 +74,9 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster * (driver and executor) use different HashMap to store its metadata. */ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging { - private val timeout = AkkaUtils.askTimeout(conf) - private val retryAttempts = AkkaUtils.numRetries(conf) - private val retryIntervalMs = AkkaUtils.retryWaitMs(conf) - /** Set to the MapOutputTrackerActor living on the driver. */ - var trackerActor: ActorRef = _ + /** Set to the MapOutputTrackerMasterEndpoint living on the driver. */ + var trackerEndpoint: RpcEndpointRef = _ /** * This HashMap has different behavior for the driver and the executors. @@ -105,12 +101,12 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging private val fetching = new HashSet[Int] /** - * Send a message to the trackerActor and get its result within a default timeout, or + * Send a message to the trackerEndpoint and get its result within a default timeout, or * throw a SparkException if this fails. 
*/ - protected def askTracker(message: Any): Any = { + protected def askTracker[T: ClassTag](message: Any): T = { try { - AkkaUtils.askWithReply(message, trackerActor, retryAttempts, retryIntervalMs, timeout) + trackerEndpoint.askWithReply[T](message) } catch { case e: Exception => logError("Error communicating with MapOutputTracker", e) @@ -118,9 +114,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } - /** Send a one-way message to the trackerActor, to which we expect it to reply with true. */ + /** Send a one-way message to the trackerEndpoint, to which we expect it to reply with true. */ protected def sendTracker(message: Any) { - val response = askTracker(message) + val response = askTracker[Boolean](message) if (response != true) { throw new SparkException( "Error reply received from MapOutputTracker. Expecting true, got " + response.toString) @@ -157,11 +153,10 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging if (fetchedStatuses == null) { // We won the race to fetch the output locs; do so - logInfo("Doing the fetch; tracker actor = " + trackerActor) + logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) // This try-finally prevents hangs due to timeouts: try { - val fetchedBytes = - askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]] + val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) logInfo("Got the output locations") mapStatuses.put(shuffleId, fetchedStatuses) @@ -328,7 +323,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) override def stop() { sendTracker(StopMapOutputTracker) mapStatuses.clear() - trackerActor = null + trackerEndpoint = null metadataCleaner.cancel() cachedSerializedStatuses.clear() } @@ -350,17 +345,22 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr private[spark] object MapOutputTracker extends Logging { + val ENDPOINT_NAME = "MapOutputTracker" + // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will // generally be pretty compressible because many map outputs will be on the same hostname. def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = { val out = new ByteArrayOutputStream val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) - // Since statuses can be modified in parallel, sync on it - statuses.synchronized { - objOut.writeObject(statuses) + Utils.tryWithSafeFinally { + // Since statuses can be modified in parallel, sync on it + statuses.synchronized { + objOut.writeObject(statuses) + } + } { + objOut.close() } - objOut.close() out.toByteArray } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 0c123c96b8d7b..390e631647bd6 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -174,6 +174,42 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { getOption(key).getOrElse(defaultValue) } + /** + * Get a time parameter as seconds; throws a NoSuchElementException if it's not set. If no + * suffix is provided then seconds are assumed. 
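`serializeMapStatuses` above now routes stream cleanup through `Utils.tryWithSafeFinally`. A self-contained sketch of both pieces, assuming only the JDK; the helper below is an approximation of the idea, not Spark's exact implementation:

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import java.util.zip.GZIPOutputStream

object SerializationSketch {
  // Approximate "safe finally": if the body already threw, a failure inside the cleanup
  // block must not replace the original exception.
  def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = {
    var original: Throwable = null
    try {
      block
    } catch {
      case t: Throwable =>
        original = t
        throw t
    } finally {
      try finallyBlock catch {
        case t: Throwable if original != null => original.addSuppressed(t)
      }
    }
  }

  // GZIP-wrapped Java serialization, mirroring how map output statuses are shipped;
  // the payload compresses well because many statuses refer to the same hostnames.
  def serialize(statuses: Array[AnyRef]): Array[Byte] = {
    val out = new ByteArrayOutputStream()
    val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
    tryWithSafeFinally {
      statuses.synchronized {
        objOut.writeObject(statuses)
      }
    } {
      objOut.close()
    }
    out.toByteArray
  }
}
```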
+ * @throws NoSuchElementException + */ + def getTimeAsSeconds(key: String): Long = { + Utils.timeStringAsSeconds(get(key)) + } + + /** + * Get a time parameter as seconds, falling back to a default if not set. If no + * suffix is provided then seconds are assumed. + * + */ + def getTimeAsSeconds(key: String, defaultValue: String): Long = { + Utils.timeStringAsSeconds(get(key, defaultValue)) + } + + /** + * Get a time parameter as milliseconds; throws a NoSuchElementException if it's not set. If no + * suffix is provided then milliseconds are assumed. + * @throws NoSuchElementException + */ + def getTimeAsMs(key: String): Long = { + Utils.timeStringAsMs(get(key)) + } + + /** + * Get a time parameter as milliseconds, falling back to a default if not set. If no + * suffix is provided then milliseconds are assumed. + */ + def getTimeAsMs(key: String, defaultValue: String): Long = { + Utils.timeStringAsMs(get(key, defaultValue)) + } + + /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { Option(settings.get(key)) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a70be16f77eeb..e106c5c4bef60 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -23,7 +23,7 @@ import java.io._ import java.lang.reflect.Constructor import java.net.URI import java.util.{Arrays, Properties, UUID} -import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} import java.util.UUID.randomUUID import scala.collection.{Map, Set} @@ -31,8 +31,7 @@ import scala.collection.JavaConversions._ import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} - -import akka.actor.Props +import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -48,12 +47,14 @@ import org.apache.mesos.MesosNativeLibrary import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} -import org.apache.spark.executor.TriggerThreadDump +import org.apache.spark.executor.{ExecutorEndpoint, TriggerThreadDump} import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} import org.apache.spark.io.CompressionCodec +import org.apache.spark.metrics.MetricsSystem import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend, SimrSchedulerBackend} @@ -95,10 +96,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val startTime = System.currentTimeMillis() - @volatile private var stopped: Boolean = false + private val stopped: AtomicBoolean = new AtomicBoolean(false) private def assertNotStopped(): Unit = { - if (stopped) { + if (stopped.get()) { throw new IllegalStateException("Cannot call methods on a stopped SparkContext") } } @@ -193,8 +194,42 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // log out Spark Version in Spark driver log logInfo(s"Running Spark version $SPARK_VERSION") - private[spark] val conf = config.clone() 
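The new `getTimeAsSeconds`/`getTimeAsMs` getters defer to `Utils.timeStringAsSeconds`/`timeStringAsMs`, so suffixed durations are accepted. An illustrative spark-shell style snippet; the keys are real settings but the values are arbitrary, and the exact suffix set (e.g. "ms", "s", "m"/"min", "h") is whatever those parsers accept:

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf().set("spark.network.timeout", "5min")

conf.getTimeAsSeconds("spark.network.timeout")                 // 300: the suffix pins the unit
conf.getTimeAsMs("spark.network.timeout")                      // 300000
conf.getTimeAsSeconds("spark.network.timeoutInterval", "60s")  // 60: key unset, default parsed
```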
- conf.validateSettings() + /* ------------------------------------------------------------------------------------- * + | Private variables. These variables keep the internal state of the context, and are | + | not accessible by the outside world. They're mutable since we want to initialize all | + | of them to some neutral value ahead of time, so that calling "stop()" while the | + | constructor is still running is safe. | + * ------------------------------------------------------------------------------------- */ + + private var _conf: SparkConf = _ + private var _eventLogDir: Option[URI] = None + private var _eventLogCodec: Option[String] = None + private var _env: SparkEnv = _ + private var _metadataCleaner: MetadataCleaner = _ + private var _jobProgressListener: JobProgressListener = _ + private var _statusTracker: SparkStatusTracker = _ + private var _progressBar: Option[ConsoleProgressBar] = None + private var _ui: Option[SparkUI] = None + private var _hadoopConfiguration: Configuration = _ + private var _executorMemory: Int = _ + private var _schedulerBackend: SchedulerBackend = _ + private var _taskScheduler: TaskScheduler = _ + private var _heartbeatReceiver: RpcEndpointRef = _ + @volatile private var _dagScheduler: DAGScheduler = _ + private var _applicationId: String = _ + private var _eventLogger: Option[EventLoggingListener] = None + private var _executorAllocationManager: Option[ExecutorAllocationManager] = None + private var _cleaner: Option[ContextCleaner] = None + private var _listenerBusStarted: Boolean = false + private var _jars: Seq[String] = _ + private var _files: Seq[String] = _ + + /* ------------------------------------------------------------------------------------- * + | Accessors and public fields. These provide access to the internal state of the | + | context. | + * ------------------------------------------------------------------------------------- */ + + private[spark] def conf: SparkConf = _conf /** * Return a copy of this SparkContext's configuration. 
The configuration ''cannot'' be @@ -202,63 +237,24 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def getConf: SparkConf = conf.clone() - if (!conf.contains("spark.master")) { - throw new SparkException("A master URL must be set in your configuration") - } - if (!conf.contains("spark.app.name")) { - throw new SparkException("An application name must be set in your configuration") - } - - if (conf.getBoolean("spark.logConf", false)) { - logInfo("Spark configuration:\n" + conf.toDebugString) - } - - // Set Spark driver host and port system properties - conf.setIfMissing("spark.driver.host", Utils.localHostName()) - conf.setIfMissing("spark.driver.port", "0") - - val jars: Seq[String] = - conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten - - val files: Seq[String] = - conf.getOption("spark.files").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten - - val master = conf.get("spark.master") - val appName = conf.get("spark.app.name") + def jars: Seq[String] = _jars + def files: Seq[String] = _files + def master: String = _conf.get("spark.master") + def appName: String = _conf.get("spark.app.name") - private[spark] val isEventLogEnabled = conf.getBoolean("spark.eventLog.enabled", false) - private[spark] val eventLogDir: Option[String] = { - if (isEventLogEnabled) { - Some(conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR).stripSuffix("/")) - } else { - None - } - } - private[spark] val eventLogCodec: Option[String] = { - val compress = conf.getBoolean("spark.eventLog.compress", false) - if (compress && isEventLogEnabled) { - Some(CompressionCodec.getCodecName(conf)).map(CompressionCodec.getShortName) - } else { - None - } - } + private[spark] def isEventLogEnabled: Boolean = _conf.getBoolean("spark.eventLog.enabled", false) + private[spark] def eventLogDir: Option[URI] = _eventLogDir + private[spark] def eventLogCodec: Option[String] = _eventLogCodec // Generate the random name for a temp folder in Tachyon // Add a timestamp as the suffix here to make it more safe val tachyonFolderName = "spark-" + randomUUID.toString() - conf.set("spark.tachyonStore.folderName", tachyonFolderName) - val isLocal = (master == "local" || master.startsWith("local[")) - - if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") + def isLocal: Boolean = (master == "local" || master.startsWith("local[")) // An asynchronous listener bus for Spark events private[spark] val listenerBus = new LiveListenerBus - conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER) - - // Create the Spark execution environment (cache, map output tracker, etc) - // This function allows components created by SparkEnv to be mocked in unit tests: private[spark] def createSparkEnv( conf: SparkConf, @@ -267,8 +263,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli SparkEnv.createDriverEnv(conf, isLocal, listenerBus) } - private[spark] val env = createSparkEnv(conf, isLocal, listenerBus) - SparkEnv.set(env) + private[spark] def env: SparkEnv = _env // Used to store a URL for each static file/jar together with the file's local timestamp private[spark] val addedFiles = HashMap[String, Long]() @@ -276,35 +271,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Keeps track of all persisted RDDs private[spark] val persistentRdds = new TimeStampedWeakValueHashMap[Int, RDD[_]] - private[spark] val metadataCleaner = - new 
MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf) + private[spark] def metadataCleaner: MetadataCleaner = _metadataCleaner + private[spark] def jobProgressListener: JobProgressListener = _jobProgressListener + def statusTracker: SparkStatusTracker = _statusTracker - private[spark] val jobProgressListener = new JobProgressListener(conf) - listenerBus.addListener(jobProgressListener) + private[spark] def progressBar: Option[ConsoleProgressBar] = _progressBar - val statusTracker = new SparkStatusTracker(this) - - private[spark] val progressBar: Option[ConsoleProgressBar] = - if (conf.getBoolean("spark.ui.showConsoleProgress", true) && !log.isInfoEnabled) { - Some(new ConsoleProgressBar(this)) - } else { - None - } - - // Initialize the Spark UI - private[spark] val ui: Option[SparkUI] = - if (conf.getBoolean("spark.ui.enabled", true)) { - Some(SparkUI.createLiveUI(this, conf, listenerBus, jobProgressListener, - env.securityManager,appName)) - } else { - // For tests, do not enable the UI - None - } - - // Bind the UI before starting the task scheduler to communicate - // the bound port to the cluster manager properly - ui.foreach(_.bind()) + private[spark] def ui: Option[SparkUI] = _ui /** * A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. @@ -312,127 +286,248 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you * plan to set some global configurations for all Hadoop RDDs. */ - val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf) + def hadoopConfiguration: Configuration = _hadoopConfiguration + + private[spark] def executorMemory: Int = _executorMemory + + // Environment variables to pass to our executors. + private[spark] val executorEnvs = HashMap[String, String]() + + // Set SPARK_USER for user who is running SparkContext. + val sparkUser = Utils.getCurrentUserName() - // Add each JAR given through the constructor - if (jars != null) { - jars.foreach(addJar) + private[spark] def schedulerBackend: SchedulerBackend = _schedulerBackend + private[spark] def schedulerBackend_=(sb: SchedulerBackend): Unit = { + _schedulerBackend = sb } - if (files != null) { - files.foreach(addFile) + private[spark] def taskScheduler: TaskScheduler = _taskScheduler + private[spark] def taskScheduler_=(ts: TaskScheduler): Unit = { + _taskScheduler = ts } + private[spark] def dagScheduler: DAGScheduler = _dagScheduler + private[spark] def dagScheduler_=(ds: DAGScheduler): Unit = { + _dagScheduler = ds + } + + def applicationId: String = _applicationId + + def metricsSystem: MetricsSystem = if (_env != null) _env.metricsSystem else null + + private[spark] def eventLogger: Option[EventLoggingListener] = _eventLogger + + private[spark] def executorAllocationManager: Option[ExecutorAllocationManager] = + _executorAllocationManager + + private[spark] def cleaner: Option[ContextCleaner] = _cleaner + + private[spark] var checkpointDir: Option[String] = None + + // Thread Local variable that can be used by users to pass information down the stack + private val localProperties = new InheritableThreadLocal[Properties] { + override protected def childValue(parent: Properties): Properties = new Properties(parent) + override protected def initialValue(): Properties = new Properties() + } + + /* ------------------------------------------------------------------------------------- * + | Initialization. 
This code initializes the context in a manner that is exception-safe. | + | All internal fields holding state are initialized here, and any error prompts the | + | stop() method to be called. | + * ------------------------------------------------------------------------------------- */ + private def warnSparkMem(value: String): String = { logWarning("Using SPARK_MEM to set amount of memory to use per executor process is " + "deprecated, please use spark.executor.memory instead.") value } - private[spark] val executorMemory = conf.getOption("spark.executor.memory") - .orElse(Option(System.getenv("SPARK_EXECUTOR_MEMORY"))) - .orElse(Option(System.getenv("SPARK_MEM")).map(warnSparkMem)) - .map(Utils.memoryStringToMb) - .getOrElse(512) + try { + _conf = config.clone() + _conf.validateSettings() - // Environment variables to pass to our executors. - private[spark] val executorEnvs = HashMap[String, String]() + if (!_conf.contains("spark.master")) { + throw new SparkException("A master URL must be set in your configuration") + } + if (!_conf.contains("spark.app.name")) { + throw new SparkException("An application name must be set in your configuration") + } - // Convert java options to env vars as a work around - // since we can't set env vars directly in sbt. - for { (envKey, propKey) <- Seq(("SPARK_TESTING", "spark.testing")) - value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} { - executorEnvs(envKey) = value - } - Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => - executorEnvs("SPARK_PREPEND_CLASSES") = v - } - // The Mesos scheduler backend relies on this environment variable to set executor memory. - // TODO: Set this only in the Mesos scheduler. - executorEnvs("SPARK_EXECUTOR_MEMORY") = executorMemory + "m" - executorEnvs ++= conf.getExecutorEnv + if (_conf.getBoolean("spark.logConf", false)) { + logInfo("Spark configuration:\n" + _conf.toDebugString) + } - // Set SPARK_USER for user who is running SparkContext. 
- val sparkUser = Utils.getCurrentUserName() - executorEnvs("SPARK_USER") = sparkUser - - // Create and start the scheduler - private[spark] var (schedulerBackend, taskScheduler) = - SparkContext.createTaskScheduler(this, master) - private val heartbeatReceiver = env.actorSystem.actorOf( - Props(new HeartbeatReceiver(this, taskScheduler)), "HeartbeatReceiver") - @volatile private[spark] var dagScheduler: DAGScheduler = _ - try { - dagScheduler = new DAGScheduler(this) - } catch { - case e: Exception => { - try { - stop() - } finally { - throw new SparkException("Error while constructing DAGScheduler", e) + // Set Spark driver host and port system properties + _conf.setIfMissing("spark.driver.host", Utils.localHostName()) + _conf.setIfMissing("spark.driver.port", "0") + + _conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER) + + _jars =_conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten + _files = _conf.getOption("spark.files").map(_.split(",")).map(_.filter(_.size != 0)) + .toSeq.flatten + + _eventLogDir = + if (isEventLogEnabled) { + val unresolvedDir = conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR) + .stripSuffix("/") + Some(Utils.resolveURI(unresolvedDir)) + } else { + None + } + + _eventLogCodec = { + val compress = _conf.getBoolean("spark.eventLog.compress", false) + if (compress && isEventLogEnabled) { + Some(CompressionCodec.getCodecName(_conf)).map(CompressionCodec.getShortName) + } else { + None } } - } - // start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's - // constructor - taskScheduler.start() + _conf.set("spark.tachyonStore.folderName", tachyonFolderName) - val applicationId: String = taskScheduler.applicationId() - conf.set("spark.app.id", applicationId) + if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") - env.blockManager.initialize(applicationId) + // Create the Spark execution environment (cache, map output tracker, etc) + _env = createSparkEnv(_conf, isLocal, listenerBus) + SparkEnv.set(_env) - val metricsSystem = env.metricsSystem + _metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, _conf) - // The metrics system for Driver need to be set spark.app.id to app ID. - // So it should start after we get app ID from the task scheduler and set spark.app.id. - metricsSystem.start() - // Attach the driver metrics servlet handler to the web ui after the metrics system is started. - metricsSystem.getServletHandlers.foreach(handler => ui.foreach(_.attachHandler(handler))) + _jobProgressListener = new JobProgressListener(_conf) + listenerBus.addListener(jobProgressListener) - // Optionally log Spark events - private[spark] val eventLogger: Option[EventLoggingListener] = { - if (isEventLogEnabled) { - val logger = - new EventLoggingListener(applicationId, eventLogDir.get, conf, hadoopConfiguration) - logger.start() - listenerBus.addListener(logger) - Some(logger) - } else None - } + _statusTracker = new SparkStatusTracker(this) - // Optionally scale number of executors dynamically based on workload. Exposed for testing. 
- private val dynamicAllocationEnabled = conf.getBoolean("spark.dynamicAllocation.enabled", false) - private val dynamicAllocationTesting = conf.getBoolean("spark.dynamicAllocation.testing", false) - private[spark] val executorAllocationManager: Option[ExecutorAllocationManager] = - if (dynamicAllocationEnabled) { - assert(supportDynamicAllocation, - "Dynamic allocation of executors is currently only supported in YARN mode") - Some(new ExecutorAllocationManager(this, listenerBus, conf)) - } else { - None + _progressBar = + if (_conf.getBoolean("spark.ui.showConsoleProgress", true) && !log.isInfoEnabled) { + Some(new ConsoleProgressBar(this)) + } else { + None + } + + _ui = + if (conf.getBoolean("spark.ui.enabled", true)) { + Some(SparkUI.createLiveUI(this, _conf, listenerBus, _jobProgressListener, + _env.securityManager,appName)) + } else { + // For tests, do not enable the UI + None + } + // Bind the UI before starting the task scheduler to communicate + // the bound port to the cluster manager properly + _ui.foreach(_.bind()) + + _hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(_conf) + + // Add each JAR given through the constructor + if (jars != null) { + jars.foreach(addJar) } - executorAllocationManager.foreach(_.start()) - private[spark] val cleaner: Option[ContextCleaner] = { - if (conf.getBoolean("spark.cleaner.referenceTracking", true)) { - Some(new ContextCleaner(this)) - } else { - None + if (files != null) { + files.foreach(addFile) } - } - cleaner.foreach(_.start()) - setupAndStartListenerBus() - postEnvironmentUpdate() - postApplicationStart() + _executorMemory = _conf.getOption("spark.executor.memory") + .orElse(Option(System.getenv("SPARK_EXECUTOR_MEMORY"))) + .orElse(Option(System.getenv("SPARK_MEM")) + .map(warnSparkMem)) + .map(Utils.memoryStringToMb) + .getOrElse(512) + + // Convert java options to env vars as a work around + // since we can't set env vars directly in sbt. + for { (envKey, propKey) <- Seq(("SPARK_TESTING", "spark.testing")) + value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} { + executorEnvs(envKey) = value + } + Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => + executorEnvs("SPARK_PREPEND_CLASSES") = v + } + // The Mesos scheduler backend relies on this environment variable to set executor memory. + // TODO: Set this only in the Mesos scheduler. + executorEnvs("SPARK_EXECUTOR_MEMORY") = executorMemory + "m" + executorEnvs ++= _conf.getExecutorEnv + executorEnvs("SPARK_USER") = sparkUser + + // We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will + // retrieve "HeartbeatReceiver" in the constructor. (SPARK-6640) + _heartbeatReceiver = env.rpcEnv.setupEndpoint( + HeartbeatReceiver.ENDPOINT_NAME, new HeartbeatReceiver(this)) + + // Create and start the scheduler + val (sched, ts) = SparkContext.createTaskScheduler(this, master) + _schedulerBackend = sched + _taskScheduler = ts + _dagScheduler = new DAGScheduler(this) + _heartbeatReceiver.send(TaskSchedulerIsSet) + + // start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's + // constructor + _taskScheduler.start() + + _applicationId = _taskScheduler.applicationId() + _conf.set("spark.app.id", _applicationId) + _env.blockManager.initialize(_applicationId) + + // The metrics system for Driver need to be set spark.app.id to app ID. + // So it should start after we get app ID from the task scheduler and set spark.app.id. 
+ metricsSystem.start() + // Attach the driver metrics servlet handler to the web ui after the metrics system is started. + metricsSystem.getServletHandlers.foreach(handler => ui.foreach(_.attachHandler(handler))) + + _eventLogger = + if (isEventLogEnabled) { + val logger = + new EventLoggingListener(_applicationId, _eventLogDir.get, _conf, _hadoopConfiguration) + logger.start() + listenerBus.addListener(logger) + Some(logger) + } else { + None + } - private[spark] var checkpointDir: Option[String] = None + // Optionally scale number of executors dynamically based on workload. Exposed for testing. + val dynamicAllocationEnabled = _conf.getBoolean("spark.dynamicAllocation.enabled", false) + _executorAllocationManager = + if (dynamicAllocationEnabled) { + assert(supportDynamicAllocation, + "Dynamic allocation of executors is currently only supported in YARN mode") + Some(new ExecutorAllocationManager(this, listenerBus, _conf)) + } else { + None + } + _executorAllocationManager.foreach(_.start()) - // Thread Local variable that can be used by users to pass information down the stack - private val localProperties = new InheritableThreadLocal[Properties] { - override protected def childValue(parent: Properties): Properties = new Properties(parent) + _cleaner = + if (_conf.getBoolean("spark.cleaner.referenceTracking", true)) { + Some(new ContextCleaner(this)) + } else { + None + } + _cleaner.foreach(_.start()) + + setupAndStartListenerBus() + postEnvironmentUpdate() + postApplicationStart() + + // Post init + _taskScheduler.postStartHook() + _env.metricsSystem.registerSource(new DAGSchedulerSource(dagScheduler)) + _env.metricsSystem.registerSource(new BlockManagerSource(_env.blockManager)) + } catch { + case NonFatal(e) => + logError("Error initializing SparkContext.", e) + try { + stop() + } catch { + case NonFatal(inner) => + logError("Error stopping SparkContext after init error.", inner) + } finally { + throw e + } } /** @@ -446,10 +541,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli if (executorId == SparkContext.DRIVER_IDENTIFIER) { Some(Utils.getThreadDump()) } else { - val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get - val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem) - Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef, - AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf))) + val (host, port) = env.blockManager.master.getRpcHostPortForExecutor(executorId).get + val endpointRef = env.rpcEnv.setupEndpointRef( + SparkEnv.executorActorSystemName, + RpcAddress(host, port), + ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME) + Some(endpointRef.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump)) } } catch { case e: Exception => @@ -474,9 +571,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Spark fair scheduler pool. 
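The try/catch wrapping the whole initialization above follows a construct-or-clean-up pattern: every field starts at a neutral value so `stop()` is safe to call at any point, and a failed step triggers cleanup before the original error is rethrown. A minimal illustration (an illustrative class, not Spark code):

```scala
import scala.util.control.NonFatal

// All state starts at a neutral value, so stop() can run no matter how far
// initialization got; if a step throws, clean up and rethrow the original error.
class ExceptionSafeInit {
  private var resource: java.io.Closeable = _

  def stop(): Unit = {
    if (resource != null) {
      resource.close()
      resource = null
    }
  }

  try {
    resource = new java.io.ByteArrayOutputStream() // stand-in for the real initialization steps
    // ... further steps that may throw ...
  } catch {
    case NonFatal(e) =>
      try {
        stop()
      } catch {
        case NonFatal(inner) => e.addSuppressed(inner)
      } finally {
        throw e
      }
  }
}
```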
*/ def setLocalProperty(key: String, value: String) { - if (localProperties.get() == null) { - localProperties.set(new Properties()) - } if (value == null) { localProperties.get.remove(key) } else { @@ -537,19 +631,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, null) } - // Post init - taskScheduler.postStartHook() - - private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler) - private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager) - - private def initDriverMetrics() { - SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource) - SparkEnv.get.metricsSystem.registerSource(blockManagerSource) - } - - initDriverMetrics() - // Methods for creating RDDs /** Distribute a local Scala collection to form an RDD. @@ -1138,8 +1219,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Return whether dynamically adjusting the amount of resources allocated to * this application is supported. This is currently only available for YARN. */ - private[spark] def supportDynamicAllocation = - master.contains("yarn") || dynamicAllocationTesting + private[spark] def supportDynamicAllocation = + master.contains("yarn") || _conf.getBoolean("spark.dynamicAllocation.testing", false) /** * :: DeveloperApi :: @@ -1156,7 +1237,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * This is currently only supported in YARN mode. Return whether the request is received. */ private[spark] override def requestTotalExecutors(numExecutors: Int): Boolean = { - assert(master.contains("yarn") || dynamicAllocationTesting, + assert(supportDynamicAllocation, "Requesting executors is currently only supported in YARN mode") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => @@ -1392,32 +1473,46 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli addedJars.clear() } - /** Shut down the SparkContext. */ + // Shut down the SparkContext. def stop() { - SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { - if (!stopped) { - stopped = true - postApplicationEnd() - ui.foreach(_.stop()) - env.metricsSystem.report() - metadataCleaner.cancel() - cleaner.foreach(_.stop()) - dagScheduler.stop() - dagScheduler = null - listenerBus.stop() - eventLogger.foreach(_.stop()) - env.actorSystem.stop(heartbeatReceiver) - progressBar.foreach(_.stop()) - taskScheduler = null - // TODO: Cache.stop()? - env.stop() - SparkEnv.set(null) - logInfo("Successfully stopped SparkContext") - SparkContext.clearActiveContext() - } else { - logInfo("SparkContext already stopped") - } + // Use the stopping variable to ensure no contention for the stop scenario. + // Still track the stopped variable for use elsewhere in the code. 
+ if (!stopped.compareAndSet(false, true)) { + logInfo("SparkContext already stopped.") + return + } + + postApplicationEnd() + _ui.foreach(_.stop()) + if (env != null) { + env.metricsSystem.report() + } + if (metadataCleaner != null) { + metadataCleaner.cancel() + } + _cleaner.foreach(_.stop()) + _executorAllocationManager.foreach(_.stop()) + if (_dagScheduler != null) { + _dagScheduler.stop() + _dagScheduler = null } + if (_listenerBusStarted) { + listenerBus.stop() + _listenerBusStarted = false + } + _eventLogger.foreach(_.stop()) + if (env != null && _heartbeatReceiver != null) { + env.rpcEnv.stop(_heartbeatReceiver) + } + _progressBar.foreach(_.stop()) + _taskScheduler = null + // TODO: Cache.stop()? + if (_env != null) { + _env.stop() + SparkEnv.set(null) + } + SparkContext.clearActiveContext() + logInfo("Successfully stopped SparkContext") } @@ -1479,7 +1574,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli partitions: Seq[Int], allowLocal: Boolean, resultHandler: (Int, U) => Unit) { - if (stopped) { + if (stopped.get()) { throw new IllegalStateException("SparkContext has been shutdown") } val callSite = getCallSite @@ -1740,6 +1835,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } listenerBus.start(this) + _listenerBusStarted = true } /** Post the application start event */ @@ -1892,7 +1988,17 @@ object SparkContext extends Logging { private[spark] val SPARK_JOB_INTERRUPT_ON_CANCEL = "spark.job.interruptOnCancel" - private[spark] val DRIVER_IDENTIFIER = "<driver>" + /** + * Executor id for the driver. In earlier versions of Spark, this was `<driver>`, but this was + * changed to `driver` because the angle brackets caused escaping issues in URLs and XML (see + * SPARK-6716 for more details). + */ + private[spark] val DRIVER_IDENTIFIER = "driver" + + /** + * Legacy version of DRIVER_IDENTIFIER, retained for backwards-compatibility. + */ + private[spark] val LEGACY_DRIVER_IDENTIFIER = "<driver>" // The following deprecated objects have already been copied to `object AccumulatorParam` to // make the compiler find them automatically.
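`stop()` above relies on `AtomicBoolean.compareAndSet` so that the teardown body runs at most once even under concurrent calls. A small standalone illustration of that guard (names are illustrative):

```scala
import java.util.concurrent.atomic.AtomicBoolean

class StoppableService {
  private val stopped = new AtomicBoolean(false)

  def stop(): Unit = {
    // Exactly one caller wins the flip from false to true; everyone else returns early.
    if (!stopped.compareAndSet(false, true)) {
      return
    }
    // ... release resources exactly once ...
  }

  def assertNotStopped(): Unit = {
    if (stopped.get()) {
      throw new IllegalStateException("Cannot call methods on a stopped service")
    }
  }
}
```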
They are duplicate codes only for backward @@ -2133,7 +2239,7 @@ object SparkContext extends Logging { master match { case "local" => val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) - val backend = new LocalBackend(scheduler, 1) + val backend = new LocalBackend(sc.getConf, scheduler, 1) scheduler.initialize(backend) (backend, scheduler) @@ -2145,7 +2251,7 @@ object SparkContext extends Logging { throw new SparkException(s"Asked to run locally with $threadCount threads") } val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) - val backend = new LocalBackend(scheduler, threadCount) + val backend = new LocalBackend(sc.getConf, scheduler, threadCount) scheduler.initialize(backend) (backend, scheduler) @@ -2155,7 +2261,7 @@ object SparkContext extends Logging { // local[N, M] means exactly N threads with M failures val threadCount = if (threads == "*") localCpuCount else threads.toInt val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true) - val backend = new LocalBackend(scheduler, threadCount) + val backend = new LocalBackend(sc.getConf, scheduler, threadCount) scheduler.initialize(backend) (backend, scheduler) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 2a0c7e756dd3a..0171488e09562 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConversions._ import scala.collection.mutable import scala.util.Properties -import akka.actor._ import com.google.common.collect.MapMaker import org.apache.spark.annotation.DeveloperApi @@ -34,12 +33,14 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv} +import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus} -import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorActor +import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{RpcUtils, Utils} /** * :: DeveloperApi :: @@ -54,7 +55,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} @DeveloperApi class SparkEnv ( val executorId: String, - val actorSystem: ActorSystem, + private[spark] val rpcEnv: RpcEnv, val serializer: Serializer, val closureSerializer: Serializer, val cacheManager: CacheManager, @@ -71,6 +72,9 @@ class SparkEnv ( val outputCommitCoordinator: OutputCommitCoordinator, val conf: SparkConf) extends Logging { + // TODO Remove actorSystem + val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + private[spark] var isStopped = false private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() @@ -91,7 +95,8 @@ class SparkEnv ( blockManager.master.stop() metricsSystem.stop() outputCommitCoordinator.stop() - actorSystem.shutdown() + rpcEnv.shutdown() + // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut // down, but let's call it anyway in case it gets 
fixed in a later release // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. @@ -236,16 +241,15 @@ object SparkEnv extends Logging { val securityManager = new SecurityManager(conf) // Create the ActorSystem for Akka and get the port it binds to. - val (actorSystem, boundPort) = { - val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName - AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager) - } + val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName + val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager) + val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem // Figure out which port Akka actually bound to in case the original port is 0 or occupied. if (isDriver) { - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.address.port.toString) } else { - conf.set("spark.executor.port", boundPort.toString) + conf.set("spark.executor.port", rpcEnv.address.port.toString) } // Create an instance of the class with the given name, possibly initializing it with our conf @@ -281,12 +285,14 @@ object SparkEnv extends Logging { val closureSerializer = instantiateClassFromConf[Serializer]( "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer") - def registerOrLookup(name: String, newActor: => Actor): ActorRef = { + def registerOrLookupEndpoint( + name: String, endpointCreator: => RpcEndpoint): + RpcEndpointRef = { if (isDriver) { logInfo("Registering " + name) - actorSystem.actorOf(Props(newActor), name = name) + rpcEnv.setupEndpoint(name, endpointCreator) } else { - AkkaUtils.makeDriverRef(name, conf, actorSystem) + RpcUtils.makeDriverRef(name, conf, rpcEnv) } } @@ -298,9 +304,9 @@ object SparkEnv extends Logging { // Have to assign trackerActor after initialization as MapOutputTrackerActor // requires the MapOutputTracker itself - mapOutputTracker.trackerActor = registerOrLookup( - "MapOutputTracker", - new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) + mapOutputTracker.trackerEndpoint = registerOrLookupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint( + rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) // Let the user specify short names for shuffle managers val shortShuffleMgrNames = Map( @@ -320,12 +326,13 @@ object SparkEnv extends Logging { new NioBlockTransferService(conf, securityManager) } - val blockManagerMaster = new BlockManagerMaster(registerOrLookup( - "BlockManagerMaster", - new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver) + val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint( + BlockManagerMaster.DRIVER_ENDPOINT_NAME, + new BlockManagerMasterEndpoint(rpcEnv, isLocal, conf, listenerBus)), + conf, isDriver) // NB: blockManager is not valid until initialize() is called later. 
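As the comment above notes, the requested RPC port may be 0 or already occupied, so the actually bound port is read back from `rpcEnv.address.port` and written into the configuration. The underlying idea is the usual ephemeral-port trick; a standalone REPL-style illustration with a plain `ServerSocket`:

```scala
import java.net.ServerSocket

// Bind to port 0 so the OS picks a free port, then ask the socket which port it got,
// analogous to reading rpcEnv.address.port after RpcEnv.create.
val socket = new ServerSocket(0)
val boundPort = socket.getLocalPort
println(s"bound to ephemeral port $boundPort")
socket.close()
```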
- val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, + val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster, serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) @@ -377,13 +384,13 @@ object SparkEnv extends Logging { val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse { new OutputCommitCoordinator(conf) } - val outputCommitCoordinatorActor = registerOrLookup("OutputCommitCoordinator", - new OutputCommitCoordinatorActor(outputCommitCoordinator)) - outputCommitCoordinator.coordinatorActor = Some(outputCommitCoordinatorActor) + val outputCommitCoordinatorRef = registerOrLookupEndpoint("OutputCommitCoordinator", + new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator)) + outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef) val envInstance = new SparkEnv( executorId, - actorSystem, + rpcEnv, serializer, closureSerializer, cacheManager, diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 6eb4537d10477..2ec42d3aea169 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -26,7 +26,6 @@ import org.apache.hadoop.mapred._ import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.fs.Path -import org.apache.spark.executor.CommitDeniedException import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD @@ -104,55 +103,8 @@ class SparkHadoopWriter(@transient jobConf: JobConf) } def commit() { - val taCtxt = getTaskContext() - val cmtr = getOutputCommitter() - - // Called after we have decided to commit - def performCommit(): Unit = { - try { - cmtr.commitTask(taCtxt) - logInfo (s"$taID: Committed") - } catch { - case e: IOException => - logError("Error committing the output of task: " + taID.value, e) - cmtr.abortTask(taCtxt) - throw e - } - } - - // First, check whether the task's output has already been committed by some other attempt - if (cmtr.needsTaskCommit(taCtxt)) { - // The task output needs to be committed, but we don't know whether some other task attempt - // might be racing to commit the same output partition. Therefore, coordinate with the driver - // in order to determine whether this attempt can commit (see SPARK-4879). 
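The commit logic removed just below is consolidated into `SparkHadoopMapRedUtil.commitTask`, including the decision of whether a task must first ask the driver's `OutputCommitCoordinator`. A hedged sketch of that decision, using only the two settings the removed block reads:

```scala
import org.apache.spark.SparkConf

// Coordination is only needed when several attempts of the same task may race to commit,
// which normally requires speculation; the second, undocumented key is the escape hatch
// carried over from the old code path.
def shouldCoordinateWithDriver(conf: SparkConf): Boolean = {
  val speculationEnabled = conf.getBoolean("spark.speculation", false)
  conf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", speculationEnabled)
}
```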
- val shouldCoordinateWithDriver: Boolean = { - val sparkConf = SparkEnv.get.conf - // We only need to coordinate with the driver if there are multiple concurrent task - // attempts, which should only occur if speculation is enabled - val speculationEnabled = sparkConf.getBoolean("spark.speculation", false) - // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs - sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", speculationEnabled) - } - if (shouldCoordinateWithDriver) { - val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator - val canCommit = outputCommitCoordinator.canCommit(jobID, splitID, attemptID) - if (canCommit) { - performCommit() - } else { - val msg = s"$taID: Not committed because the driver did not authorize commit" - logInfo(msg) - // We need to abort the task so that the driver can reschedule new attempts, if necessary - cmtr.abortTask(taCtxt) - throw new CommitDeniedException(msg, jobID, splitID, attemptID) - } - } else { - // Speculation is disabled or a user has chosen to manually bypass the commit coordination - performCommit() - } - } else { - // Some other attempt committed the output, so we do nothing and signal success - logInfo(s"No need to commit output of task because needsTaskCommit=false: ${taID.value}") - } + SparkHadoopMapRedUtil.commitTask( + getOutputCommitter(), getTaskContext(), jobID, splitID, attemptID) } def commitJob() { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index a023712be1166..8441bb3a3047e 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -661,7 +661,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) */ def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = { import scala.collection.JavaConverters._ - def fn = (x: V) => f.call(x).asScala + def fn: (V) => Iterable[U] = (x: V) => f.call(x).asScala implicit val ctag: ClassTag[U] = fakeClassTag fromRDD(rdd.flatMapValues(fn)) } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 18ccd625fc8d1..db4e996feb31c 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -192,7 +192,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) */ def sortBy[S](f: JFunction[T, S], ascending: Boolean, numPartitions: Int): JavaRDD[T] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x) + def fn: (T) => S = (x: T) => f.call(x) import com.google.common.collect.Ordering // shadows scala.math.Ordering implicit val ordering = Ordering.natural().asInstanceOf[Ordering[S]] implicit val ctag: ClassTag[S] = fakeClassTag diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 8da42934a7d96..8bf0627fc420d 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -17,8 +17,9 @@ package org.apache.spark.api.java -import java.util.{Comparator, List => JList, Iterator => JIterator} +import java.{lang => jl} import java.lang.{Iterable => JIterable, Long => JLong} +import java.util.{Comparator, List => JList, Iterator => JIterator} import 
scala.collection.JavaConversions._ import scala.collection.JavaConverters._ @@ -93,7 +94,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * of the original partition. */ def mapPartitionsWithIndex[R]( - f: JFunction2[java.lang.Integer, java.util.Iterator[T], java.util.Iterator[R]], + f: JFunction2[jl.Integer, java.util.Iterator[T], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))), preservesPartitioning)(fakeClassTag))(fakeClassTag) @@ -109,7 +110,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to all elements of this RDD. */ def mapToPair[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { - def cm = implicitly[ClassTag[(K2, V2)]] + def cm: ClassTag[(K2, V2)] = implicitly[ClassTag[(K2, V2)]] new JavaPairRDD(rdd.map[(K2, V2)](f)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -119,7 +120,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def flatMap[U](f: FlatMapFunction[T, U]): JavaRDD[U] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala + def fn: (T) => Iterable[U] = (x: T) => f.call(x).asScala JavaRDD.fromRDD(rdd.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -129,8 +130,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def flatMapToDouble(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala - new JavaDoubleRDD(rdd.flatMap(fn).map((x: java.lang.Double) => x.doubleValue())) + def fn: (T) => Iterable[jl.Double] = (x: T) => f.call(x).asScala + new JavaDoubleRDD(rdd.flatMap(fn).map((x: jl.Double) => x.doubleValue())) } /** @@ -139,8 +140,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala - def cm = implicitly[ClassTag[(K2, V2)]] + def fn: (T) => Iterable[(K2, V2)] = (x: T) => f.call(x).asScala + def cm: ClassTag[(K2, V2)] = implicitly[ClassTag[(K2, V2)]] JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -148,7 +149,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to each partition of this RDD. 
*/ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[U] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } JavaRDD.fromRDD(rdd.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -157,7 +160,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U], preservesPartitioning: Boolean): JavaRDD[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[U] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } JavaRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -166,8 +171,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to each partition of this RDD. */ def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) - new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: java.lang.Double) => x.doubleValue())) + def fn: (Iterator[T]) => Iterator[jl.Double] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } + new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: jl.Double) => x.doubleValue())) } /** @@ -175,7 +182,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]): JavaPairRDD[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[(K2, V2)] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -184,7 +193,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]], preservesPartitioning: Boolean): JavaDoubleRDD = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[jl.Double] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } new JavaDoubleRDD(rdd.mapPartitions(fn, preservesPartitioning) .map(x => x.doubleValue())) } @@ -194,7 +205,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2], preservesPartitioning: Boolean): JavaPairRDD[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[(K2, V2)] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } JavaPairRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -277,8 +290,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def zipPartitions[U, V]( other: JavaRDDLike[U, _], f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V]): JavaRDD[V] = { - def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator( - f.call(asJavaIterator(x), asJavaIterator(y)).iterator()) + def fn: (Iterator[T], Iterator[U]) => Iterator[V] = { 
+ (x: Iterator[T], y: Iterator[U]) => asScalaIterator( + f.call(asJavaIterator(x), asJavaIterator(y)).iterator()) + } JavaRDD.fromRDD( rdd.zipPartitions(other.rdd)(fn)(other.classTag, fakeClassTag[V]))(fakeClassTag[V]) } @@ -441,8 +456,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final * combine step happens locally on the master, equivalent to running a single reduce task. */ - def countByValue(): java.util.Map[T, java.lang.Long] = - mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2))))) + def countByValue(): java.util.Map[T, jl.Long] = + mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new jl.Long(x._2))))) /** * (Experimental) Approximate version of countByValue(). diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 19f4c95fcad74..b1ffba4c546bf 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -605,7 +605,6 @@ private[spark] object PythonRDD extends Logging { */ private def serveIterator[T](items: Iterator[T], threadName: String): Int = { val serverSocket = new ServerSocket(0, 1) - serverSocket.setReuseAddress(true) // Close the socket if no connection in 3 seconds serverSocket.setSoTimeout(3000) @@ -615,9 +614,9 @@ private[spark] object PythonRDD extends Logging { try { val sock = serverSocket.accept() val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) - try { + Utils.tryWithSafeFinally { writeIteratorToStream(items, out) - } finally { + } { out.close() } } catch { @@ -863,9 +862,9 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial val file = File.createTempFile("broadcast", "", dir) path = file.getAbsolutePath val out = new FileOutputStream(file) - try { + Utils.tryWithSafeFinally { Utils.copyStream(in, out) - } finally { + } { out.close() } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala new file mode 100644 index 0000000000000..3a2c94bd9d875 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.api.r + +import java.io.{DataOutputStream, File, FileOutputStream, IOException} +import java.net.{InetSocketAddress, ServerSocket} +import java.util.concurrent.TimeUnit + +import io.netty.bootstrap.ServerBootstrap +import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup} +import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.socket.SocketChannel +import io.netty.channel.socket.nio.NioServerSocketChannel +import io.netty.handler.codec.LengthFieldBasedFrameDecoder +import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder} + +import org.apache.spark.Logging + +/** + * Netty-based backend server that is used to communicate between R and Java. + */ +private[spark] class RBackend { + + private[this] var channelFuture: ChannelFuture = null + private[this] var bootstrap: ServerBootstrap = null + private[this] var bossGroup: EventLoopGroup = null + + def init(): Int = { + bossGroup = new NioEventLoopGroup(2) + val workerGroup = bossGroup + val handler = new RBackendHandler(this) + + bootstrap = new ServerBootstrap() + .group(bossGroup, workerGroup) + .channel(classOf[NioServerSocketChannel]) + + bootstrap.childHandler(new ChannelInitializer[SocketChannel]() { + def initChannel(ch: SocketChannel): Unit = { + ch.pipeline() + .addLast("encoder", new ByteArrayEncoder()) + .addLast("frameDecoder", + // maxFrameLength = 2G + // lengthFieldOffset = 0 + // lengthFieldLength = 4 + // lengthAdjustment = 0 + // initialBytesToStrip = 4, i.e. strip out the length field itself + new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) + .addLast("decoder", new ByteArrayDecoder()) + .addLast("handler", handler) + } + }) + + channelFuture = bootstrap.bind(new InetSocketAddress(0)) + channelFuture.syncUninterruptibly() + channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort() + } + + def run(): Unit = { + channelFuture.channel.closeFuture().syncUninterruptibly() + } + + def close(): Unit = { + if (channelFuture != null) { + // close is a local operation and should finish within milliseconds; timeout just to be safe + channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS) + channelFuture = null + } + if (bootstrap != null && bootstrap.group() != null) { + bootstrap.group().shutdownGracefully() + } + if (bootstrap != null && bootstrap.childGroup() != null) { + bootstrap.childGroup().shutdownGracefully() + } + bootstrap = null + } + +} + +private[spark] object RBackend extends Logging { + def main(args: Array[String]): Unit = { + if (args.length < 1) { + System.err.println("Usage: RBackend ") + System.exit(-1) + } + val sparkRBackend = new RBackend() + try { + // bind to random port + val boundPort = sparkRBackend.init() + val serverSocket = new ServerSocket(0, 1) + val listenPort = serverSocket.getLocalPort() + + // tell the R process via temporary file + val path = args(0) + val f = new File(path + ".tmp") + val dos = new DataOutputStream(new FileOutputStream(f)) + dos.writeInt(boundPort) + dos.writeInt(listenPort) + dos.close() + f.renameTo(new File(path)) + + // wait for the end of stdin, then exit + new Thread("wait for socket to close") { + setDaemon(true) + override def run(): Unit = { + // any un-catched exception will also shutdown JVM + val buf = new Array[Byte](1024) + // shutdown JVM if R does not connect back in 10 seconds + serverSocket.setSoTimeout(10000) + try { + val inSocket = serverSocket.accept() + serverSocket.close() + // wait for the end of socket, closed if R 
process die + inSocket.getInputStream().read(buf) + } finally { + sparkRBackend.close() + System.exit(0) + } + } + }.start() + + sparkRBackend.run() + } catch { + case e: IOException => + logError("Server shutting down: failed with exception ", e) + sparkRBackend.close() + System.exit(1) + } + System.exit(0) + } +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala new file mode 100644 index 0000000000000..0075d963711f1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +import scala.collection.mutable.HashMap + +import io.netty.channel.ChannelHandler.Sharable +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.Logging +import org.apache.spark.api.r.SerDe._ + +/** + * Handler for RBackend + * TODO: This is marked as sharable to get a handle to RBackend. Is it safe to re-use + * this across connections ? + */ +@Sharable +private[r] class RBackendHandler(server: RBackend) + extends SimpleChannelInboundHandler[Array[Byte]] with Logging { + + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { + val bis = new ByteArrayInputStream(msg) + val dis = new DataInputStream(bis) + + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + // First bit is isStatic + val isStatic = readBoolean(dis) + val objId = readString(dis) + val methodName = readString(dis) + val numArgs = readInt(dis) + + if (objId == "SparkRHandler") { + methodName match { + case "stopBackend" => + writeInt(dos, 0) + writeType(dos, "void") + server.close() + case "rm" => + try { + val t = readObjectType(dis) + assert(t == 'c') + val objToRemove = readString(dis) + JVMObjectTracker.remove(objToRemove) + writeInt(dos, 0) + writeObject(dos, null) + } catch { + case e: Exception => + logError(s"Removing $objId failed", e) + writeInt(dos, -1) + } + case _ => dos.writeInt(-1) + } + } else { + handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) + } + + val reply = bos.toByteArray + ctx.write(reply) + } + + override def channelReadComplete(ctx: ChannelHandlerContext): Unit = { + ctx.flush() + } + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + // Close the connection when an exception is raised. 
+ cause.printStackTrace() + ctx.close() + } + + def handleMethodCall( + isStatic: Boolean, + objId: String, + methodName: String, + numArgs: Int, + dis: DataInputStream, + dos: DataOutputStream): Unit = { + var obj: Object = null + try { + val cls = if (isStatic) { + Class.forName(objId) + } else { + JVMObjectTracker.get(objId) match { + case None => throw new IllegalArgumentException("Object not found " + objId) + case Some(o) => + obj = o + o.getClass + } + } + + val args = readArgs(numArgs, dis) + + val methods = cls.getMethods + val selectedMethods = methods.filter(m => m.getName == methodName) + if (selectedMethods.length > 0) { + val methods = selectedMethods.filter { x => + matchMethod(numArgs, args, x.getParameterTypes) + } + if (methods.isEmpty) { + logWarning(s"cannot find matching method ${cls}.$methodName. " + + s"Candidates are:") + selectedMethods.foreach { method => + logWarning(s"$methodName(${method.getParameterTypes.mkString(",")})") + } + throw new Exception(s"No matched method found for $cls.$methodName") + } + val ret = methods.head.invoke(obj, args:_*) + + // Write status bit + writeInt(dos, 0) + writeObject(dos, ret.asInstanceOf[AnyRef]) + } else if (methodName == "") { + // methodName should be "" for constructor + val ctor = cls.getConstructors.filter { x => + matchMethod(numArgs, args, x.getParameterTypes) + }.head + + val obj = ctor.newInstance(args:_*) + + writeInt(dos, 0) + writeObject(dos, obj.asInstanceOf[AnyRef]) + } else { + throw new IllegalArgumentException("invalid method " + methodName + " for object " + objId) + } + } catch { + case e: Exception => + logError(s"$methodName on $objId failed", e) + writeInt(dos, -1) + } + } + + // Read a number of arguments from the data input stream + def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { + (0 until numArgs).map { arg => + readObject(dis) + }.toArray + } + + // Checks if the arguments passed in args matches the parameter types. + // NOTE: Currently we do exact match. We may add type conversions later. + def matchMethod( + numArgs: Int, + args: Array[java.lang.Object], + parameterTypes: Array[Class[_]]): Boolean = { + if (parameterTypes.length != numArgs) { + return false + } + + for (i <- 0 to numArgs - 1) { + val parameterType = parameterTypes(i) + var parameterWrapperType = parameterType + + // Convert native parameters to Object types as args is Array[Object] here + if (parameterType.isPrimitive) { + parameterWrapperType = parameterType match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => parameterType + } + } + if (!parameterWrapperType.isInstance(args(i))) { + return false + } + } + true + } +} + +/** + * Helper singleton that tracks JVM objects returned to R. + * This is useful for referencing these objects in RPC calls. + */ +private[r] object JVMObjectTracker { + + // TODO: This map should be thread-safe if we want to support multiple + // connections at the same time + private[this] val objMap = new HashMap[String, Object] + + // TODO: We support only one connection now, so an integer is fine. + // Investigate using use atomic integer in the future. 
+ private[this] var objCounter: Int = 0 + + def getObject(id: String): Object = { + objMap(id) + } + + def get(id: String): Option[Object] = { + objMap.get(id) + } + + def put(obj: Object): String = { + val objId = objCounter.toString + objCounter = objCounter + 1 + objMap.put(objId, obj) + objId + } + + def remove(id: String): Option[Object] = { + objMap.remove(id) + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala new file mode 100644 index 0000000000000..5fa4d483b8342 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -0,0 +1,450 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io._ +import java.net.ServerSocket +import java.util.{Map => JMap} + +import scala.collection.JavaConversions._ +import scala.io.Source +import scala.reflect.ClassTag +import scala.util.Try + +import org.apache.spark._ +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils + +private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( + parent: RDD[T], + numPartitions: Int, + func: Array[Byte], + deserializer: String, + serializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Broadcast[Object]]) + extends RDD[U](parent) with Logging { + override def getPartitions: Array[Partition] = parent.partitions + + override def compute(partition: Partition, context: TaskContext): Iterator[U] = { + + // The parent may be also an RRDD, so we should launch it first. + val parentIterator = firstParent[T].iterator(partition, context) + + // we expect two connections + val serverSocket = new ServerSocket(0, 2) + val listenPort = serverSocket.getLocalPort() + + // The stdout/stderr is shared by multiple tasks, because we use one daemon + // to launch child process as worker. + val errThread = RRDD.createRWorker(rLibDir, listenPort) + + // We use two sockets to separate input and output, then it's easy to manage + // the lifecycle of them to avoid deadlock. 
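+ // The R worker connects back twice on the same port: the first accepted socket carries
+ // the task input to R and the second returns the computed output.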
+ // TODO: optimize it to use one socket + + // the socket used to send out the input of task + serverSocket.setSoTimeout(10000) + val inSocket = serverSocket.accept() + startStdinThread(inSocket.getOutputStream(), parentIterator, partition.index) + + // the socket used to receive the output of task + val outSocket = serverSocket.accept() + val inputStream = new BufferedInputStream(outSocket.getInputStream) + val dataStream = openDataStream(inputStream) + serverSocket.close() + + try { + + return new Iterator[U] { + def next(): U = { + val obj = _nextObj + if (hasNext) { + _nextObj = read() + } + obj + } + + var _nextObj = read() + + def hasNext(): Boolean = { + val hasMore = (_nextObj != null) + if (!hasMore) { + dataStream.close() + } + hasMore + } + } + } catch { + case e: Exception => + throw new SparkException("R computation failed with\n " + errThread.getLines()) + } + } + + /** + * Start a thread to write RDD data to the R process. + */ + private def startStdinThread[T]( + output: OutputStream, + iter: Iterator[T], + partition: Int): Unit = { + + val env = SparkEnv.get + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val stream = new BufferedOutputStream(output, bufferSize) + + new Thread("writer for R") { + override def run(): Unit = { + try { + SparkEnv.set(env) + val dataOut = new DataOutputStream(stream) + dataOut.writeInt(partition) + + SerDe.writeString(dataOut, deserializer) + SerDe.writeString(dataOut, serializer) + + dataOut.writeInt(packageNames.length) + dataOut.write(packageNames) + + dataOut.writeInt(func.length) + dataOut.write(func) + + dataOut.writeInt(broadcastVars.length) + broadcastVars.foreach { broadcast => + // TODO(shivaram): Read a Long in R to avoid this cast + dataOut.writeInt(broadcast.id.toInt) + // TODO: Pass a byte array from R to avoid this cast ? + val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]] + dataOut.writeInt(broadcastByteArr.length) + dataOut.write(broadcastByteArr) + } + + dataOut.writeInt(numPartitions) + + if (!iter.hasNext) { + dataOut.writeInt(0) + } else { + dataOut.writeInt(1) + } + + val printOut = new PrintStream(stream) + + def writeElem(elem: Any): Unit = { + if (deserializer == SerializationFormats.BYTE) { + val elemArr = elem.asInstanceOf[Array[Byte]] + dataOut.writeInt(elemArr.length) + dataOut.write(elemArr) + } else if (deserializer == SerializationFormats.ROW) { + dataOut.write(elem.asInstanceOf[Array[Byte]]) + } else if (deserializer == SerializationFormats.STRING) { + printOut.println(elem) + } + } + + for (elem <- iter) { + elem match { + case (key, value) => + writeElem(key) + writeElem(value) + case _ => + writeElem(elem) + } + } + stream.flush() + } catch { + // TODO: We should propogate this error to the task thread + case e: Exception => + logError("R Writer thread got an exception", e) + } finally { + Try(output.close()) + } + } + }.start() + } + + protected def openDataStream(input: InputStream): Closeable + + protected def read(): U +} + +/** + * Form an RDD[(Int, Array[Byte])] from key-value pairs returned from R. + * This is used by SparkR's shuffle operations. 
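+ * Each element read back from the R worker is a (hashed key, serialized value) pair,
+ * i.e. an (Int, Array[Byte]) tuple.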
+ */ +private class PairwiseRRDD[T: ClassTag]( + parent: RDD[T], + numPartitions: Int, + hashFunc: Array[Byte], + deserializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Object]) + extends BaseRRDD[T, (Int, Array[Byte])]( + parent, numPartitions, hashFunc, deserializer, + SerializationFormats.BYTE, packageNames, rLibDir, + broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + private var dataStream: DataInputStream = _ + + override protected def openDataStream(input: InputStream): Closeable = { + dataStream = new DataInputStream(input) + dataStream + } + + override protected def read(): (Int, Array[Byte]) = { + try { + val length = dataStream.readInt() + + length match { + case length if length == 2 => + val hashedKey = dataStream.readInt() + val contentPairsLength = dataStream.readInt() + val contentPairs = new Array[Byte](contentPairsLength) + dataStream.readFully(contentPairs) + (hashedKey, contentPairs) + case _ => null // End of input + } + } catch { + case eof: EOFException => { + throw new SparkException("R worker exited unexpectedly (crashed)", eof) + } + } + } + + lazy val asJavaPairRDD : JavaPairRDD[Int, Array[Byte]] = JavaPairRDD.fromRDD(this) +} + +/** + * An RDD that stores serialized R objects as Array[Byte]. + */ +private class RRDD[T: ClassTag]( + parent: RDD[T], + func: Array[Byte], + deserializer: String, + serializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Object]) + extends BaseRRDD[T, Array[Byte]]( + parent, -1, func, deserializer, serializer, packageNames, rLibDir, + broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + private var dataStream: DataInputStream = _ + + override protected def openDataStream(input: InputStream): Closeable = { + dataStream = new DataInputStream(input) + dataStream + } + + override protected def read(): Array[Byte] = { + try { + val length = dataStream.readInt() + + length match { + case length if length > 0 => + val obj = new Array[Byte](length) + dataStream.readFully(obj, 0, length) + obj + case _ => null + } + } catch { + case eof: EOFException => { + throw new SparkException("R worker exited unexpectedly (crashed)", eof) + } + } + } + + lazy val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) +} + +/** + * An RDD that stores R objects as Array[String]. 
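+ * Each element corresponds to one line of text emitted by the R worker.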
+ */ +private class StringRRDD[T: ClassTag]( + parent: RDD[T], + func: Array[Byte], + deserializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Object]) + extends BaseRRDD[T, String]( + parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, rLibDir, + broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + private var dataStream: BufferedReader = _ + + override protected def openDataStream(input: InputStream): Closeable = { + dataStream = new BufferedReader(new InputStreamReader(input)) + dataStream + } + + override protected def read(): String = { + try { + dataStream.readLine() + } catch { + case e: IOException => { + throw new SparkException("R worker exited unexpectedly (crashed)", e) + } + } + } + + lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this) +} + +private[r] class BufferedStreamThread( + in: InputStream, + name: String, + errBufferSize: Int) extends Thread(name) with Logging { + val lines = new Array[String](errBufferSize) + var lineIdx = 0 + override def run() { + for (line <- Source.fromInputStream(in).getLines) { + synchronized { + lines(lineIdx) = line + lineIdx = (lineIdx + 1) % errBufferSize + } + logInfo(line) + } + } + + def getLines(): String = synchronized { + (0 until errBufferSize).filter { x => + lines((x + lineIdx) % errBufferSize) != null + }.map { x => + lines((x + lineIdx) % errBufferSize) + }.mkString("\n") + } +} + +private[r] object RRDD { + // Because forking processes from Java is expensive, we prefer to launch + // a single R daemon (daemon.R) and tell it to fork new workers for our tasks. + // This daemon currently only works on UNIX-based systems now, so we should + // also fall back to launching workers (worker.R) directly. + private[this] var errThread: BufferedStreamThread = _ + private[this] var daemonChannel: DataOutputStream = _ + + def createSparkContext( + master: String, + appName: String, + sparkHome: String, + jars: Array[String], + sparkEnvirMap: JMap[Object, Object], + sparkExecutorEnvMap: JMap[Object, Object]): JavaSparkContext = { + + val sparkConf = new SparkConf().setAppName(appName) + .setSparkHome(sparkHome) + .setJars(jars) + + // Override `master` if we have a user-specified value + if (master != "") { + sparkConf.setMaster(master) + } else { + // If conf has no master set it to "local" to maintain + // backwards compatibility + sparkConf.setIfMissing("spark.master", "local") + } + + for ((name, value) <- sparkEnvirMap) { + sparkConf.set(name.asInstanceOf[String], value.asInstanceOf[String]) + } + for ((name, value) <- sparkExecutorEnvMap) { + sparkConf.setExecutorEnv(name.asInstanceOf[String], value.asInstanceOf[String]) + } + + new JavaSparkContext(sparkConf) + } + + /** + * Start a thread to print the process's stderr to ours + */ + private def startStdoutThread(proc: Process): BufferedStreamThread = { + val BUFFER_SIZE = 100 + val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE) + thread.setDaemon(true) + thread.start() + thread + } + + private def createRProcess(rLibDir: String, port: Int, script: String): BufferedStreamThread = { + val rCommand = "Rscript" + val rOptions = "--vanilla" + val rExecScript = rLibDir + "/SparkR/worker/" + script + val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript)) + // Unset the R_TESTS environment variable for workers. 
+ // This is set by R CMD check as startup.Rs + // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R) + // and confuses worker script which tries to load a non-existent file + pb.environment().put("R_TESTS", "") + pb.environment().put("SPARKR_RLIBDIR", rLibDir) + pb.environment().put("SPARKR_WORKER_PORT", port.toString) + pb.redirectErrorStream(true) // redirect stderr into stdout + val proc = pb.start() + val errThread = startStdoutThread(proc) + errThread + } + + /** + * ProcessBuilder used to launch worker R processes. + */ + def createRWorker(rLibDir: String, port: Int): BufferedStreamThread = { + val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true) + if (!Utils.isWindows && useDaemon) { + synchronized { + if (daemonChannel == null) { + // we expect one connections + val serverSocket = new ServerSocket(0, 1) + val daemonPort = serverSocket.getLocalPort + errThread = createRProcess(rLibDir, daemonPort, "daemon.R") + // the socket used to send out the input of task + serverSocket.setSoTimeout(10000) + val sock = serverSocket.accept() + daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + serverSocket.close() + } + try { + daemonChannel.writeInt(port) + daemonChannel.flush() + } catch { + case e: IOException => + // daemon process died + daemonChannel.close() + daemonChannel = null + errThread = null + // fail the current task, retry by scheduler + throw e + } + errThread + } + } else { + createRProcess(rLibDir, port, "worker.R") + } + } + + /** + * Create an RRDD given a sequence of byte arrays. Used to create RRDD when `parallelize` is + * called from R. + */ + def createRDDFromArray(jsc: JavaSparkContext, arr: Array[Array[Byte]]): JavaRDD[Array[Byte]] = { + JavaRDD.fromRDD(jsc.sc.parallelize(arr, arr.length)) + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala new file mode 100644 index 0000000000000..ccb2a371f4e48 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -0,0 +1,340 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.api.r + +import java.io.{DataInputStream, DataOutputStream} +import java.sql.{Date, Time} + +import scala.collection.JavaConversions._ + +/** + * Utility functions to serialize, deserialize objects to / from R + */ +private[spark] object SerDe { + + // Type mapping from R to Java + // + // NULL -> void + // integer -> Int + // character -> String + // logical -> Boolean + // double, numeric -> Double + // raw -> Array[Byte] + // Date -> Date + // POSIXlt/POSIXct -> Time + // + // list[T] -> Array[T], where T is one of above mentioned types + // environment -> Map[String, T], where T is a native type + // jobj -> Object, where jobj is an object created in the backend + + def readObjectType(dis: DataInputStream): Char = { + dis.readByte().toChar + } + + def readObject(dis: DataInputStream): Object = { + val dataType = readObjectType(dis) + readTypedObject(dis, dataType) + } + + def readTypedObject( + dis: DataInputStream, + dataType: Char): Object = { + dataType match { + case 'n' => null + case 'i' => new java.lang.Integer(readInt(dis)) + case 'd' => new java.lang.Double(readDouble(dis)) + case 'b' => new java.lang.Boolean(readBoolean(dis)) + case 'c' => readString(dis) + case 'e' => readMap(dis) + case 'r' => readBytes(dis) + case 'l' => readList(dis) + case 'D' => readDate(dis) + case 't' => readTime(dis) + case 'j' => JVMObjectTracker.getObject(readString(dis)) + case _ => throw new IllegalArgumentException(s"Invalid type $dataType") + } + } + + def readBytes(in: DataInputStream): Array[Byte] = { + val len = readInt(in) + val out = new Array[Byte](len) + val bytesRead = in.readFully(out) + out + } + + def readInt(in: DataInputStream): Int = { + in.readInt() + } + + def readDouble(in: DataInputStream): Double = { + in.readDouble() + } + + def readString(in: DataInputStream): String = { + val len = in.readInt() + val asciiBytes = new Array[Byte](len) + in.readFully(asciiBytes) + assert(asciiBytes(len - 1) == 0) + val str = new String(asciiBytes.dropRight(1).map(_.toChar)) + str + } + + def readBoolean(in: DataInputStream): Boolean = { + val intVal = in.readInt() + if (intVal == 0) false else true + } + + def readDate(in: DataInputStream): Date = { + Date.valueOf(readString(in)) + } + + def readTime(in: DataInputStream): Time = { + val t = in.readDouble() + new Time((t * 1000L).toLong) + } + + def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { + val len = readInt(in) + (0 until len).map(_ => readBytes(in)).toArray + } + + def readIntArr(in: DataInputStream): Array[Int] = { + val len = readInt(in) + (0 until len).map(_ => readInt(in)).toArray + } + + def readDoubleArr(in: DataInputStream): Array[Double] = { + val len = readInt(in) + (0 until len).map(_ => readDouble(in)).toArray + } + + def readBooleanArr(in: DataInputStream): Array[Boolean] = { + val len = readInt(in) + (0 until len).map(_ => readBoolean(in)).toArray + } + + def readStringArr(in: DataInputStream): Array[String] = { + val len = readInt(in) + (0 until len).map(_ => readString(in)).toArray + } + + def readList(dis: DataInputStream): Array[_] = { + val arrType = readObjectType(dis) + arrType match { + case 'i' => readIntArr(dis) + case 'c' => readStringArr(dis) + case 'd' => readDoubleArr(dis) + case 'b' => readBooleanArr(dis) + case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x)) + case 'r' => readBytesArr(dis) + case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") + } + } + + def readMap(in: DataInputStream): java.util.Map[Object, Object] = 
{ + val len = readInt(in) + if (len > 0) { + val keysType = readObjectType(in) + val keysLen = readInt(in) + val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType)) + + val valuesType = readObjectType(in) + val valuesLen = readInt(in) + val values = (0 until valuesLen).map(_ => readTypedObject(in, valuesType)) + mapAsJavaMap(keys.zip(values).toMap) + } else { + new java.util.HashMap[Object, Object]() + } + } + + // Methods to write out data from Java to R + // + // Type mapping from Java to R + // + // void -> NULL + // Int -> integer + // String -> character + // Boolean -> logical + // Double -> double + // Long -> double + // Array[Byte] -> raw + // Date -> Date + // Time -> POSIXct + // + // Array[T] -> list() + // Object -> jobj + + def writeType(dos: DataOutputStream, typeStr: String): Unit = { + typeStr match { + case "void" => dos.writeByte('n') + case "character" => dos.writeByte('c') + case "double" => dos.writeByte('d') + case "integer" => dos.writeByte('i') + case "logical" => dos.writeByte('b') + case "date" => dos.writeByte('D') + case "time" => dos.writeByte('t') + case "raw" => dos.writeByte('r') + case "list" => dos.writeByte('l') + case "jobj" => dos.writeByte('j') + case _ => throw new IllegalArgumentException(s"Invalid type $typeStr") + } + } + + def writeObject(dos: DataOutputStream, value: Object): Unit = { + if (value == null) { + writeType(dos, "void") + } else { + value.getClass.getName match { + case "java.lang.String" => + writeType(dos, "character") + writeString(dos, value.asInstanceOf[String]) + case "long" | "java.lang.Long" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Long].toDouble) + case "double" | "java.lang.Double" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Double]) + case "int" | "java.lang.Integer" => + writeType(dos, "integer") + writeInt(dos, value.asInstanceOf[Int]) + case "boolean" | "java.lang.Boolean" => + writeType(dos, "logical") + writeBoolean(dos, value.asInstanceOf[Boolean]) + case "java.sql.Date" => + writeType(dos, "date") + writeDate(dos, value.asInstanceOf[Date]) + case "java.sql.Time" => + writeType(dos, "time") + writeTime(dos, value.asInstanceOf[Time]) + case "[B" => + writeType(dos, "raw") + writeBytes(dos, value.asInstanceOf[Array[Byte]]) + // TODO: Types not handled right now include + // byte, char, short, float + + // Handle arrays + case "[Ljava.lang.String;" => + writeType(dos, "list") + writeStringArr(dos, value.asInstanceOf[Array[String]]) + case "[I" => + writeType(dos, "list") + writeIntArr(dos, value.asInstanceOf[Array[Int]]) + case "[J" => + writeType(dos, "list") + writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble)) + case "[D" => + writeType(dos, "list") + writeDoubleArr(dos, value.asInstanceOf[Array[Double]]) + case "[Z" => + writeType(dos, "list") + writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]]) + case "[[B" => + writeType(dos, "list") + writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]]) + case otherName => + // Handle array of objects + if (otherName.startsWith("[L")) { + val objArr = value.asInstanceOf[Array[Object]] + writeType(dos, "list") + writeType(dos, "jobj") + dos.writeInt(objArr.length) + objArr.foreach(o => writeJObj(dos, o)) + } else { + writeType(dos, "jobj") + writeJObj(dos, value) + } + } + } + } + + def writeInt(out: DataOutputStream, value: Int): Unit = { + out.writeInt(value) + } + + def writeDouble(out: DataOutputStream, value: Double): Unit = { + out.writeDouble(value) + } + + def writeBoolean(out: 
DataOutputStream, value: Boolean): Unit = { + val intValue = if (value) 1 else 0 + out.writeInt(intValue) + } + + def writeDate(out: DataOutputStream, value: Date): Unit = { + writeString(out, value.toString) + } + + def writeTime(out: DataOutputStream, value: Time): Unit = { + out.writeDouble(value.getTime.toDouble / 1000.0) + } + + + // NOTE: Only works for ASCII right now + def writeString(out: DataOutputStream, value: String): Unit = { + val len = value.length + out.writeInt(len + 1) // For the \0 + out.writeBytes(value) + out.writeByte(0) + } + + def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { + out.writeInt(value.length) + out.write(value) + } + + def writeJObj(out: DataOutputStream, value: Object): Unit = { + val objId = JVMObjectTracker.put(value) + writeString(out, objId) + } + + def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { + writeType(out, "integer") + out.writeInt(value.length) + value.foreach(v => out.writeInt(v)) + } + + def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = { + writeType(out, "double") + out.writeInt(value.length) + value.foreach(v => out.writeDouble(v)) + } + + def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = { + writeType(out, "logical") + out.writeInt(value.length) + value.foreach(v => writeBoolean(out, v)) + } + + def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = { + writeType(out, "character") + out.writeInt(value.length) + value.foreach(v => writeString(out, v)) + } + + def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { + writeType(out, "raw") + out.writeInt(value.length) + value.foreach(v => writeBytes(out, v)) + } +} + +private[r] object SerializationFormats { + val BYTE = "byte" + val STRING = "string" + val ROW = "row" +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 74ccfa6d3c9a3..4457c75e8b0fc 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -165,7 +165,7 @@ private[broadcast] object HttpBroadcast extends Logging { private def write(id: Long, value: Any) { val file = getFile(id) val fileOutputStream = new FileOutputStream(file) - try { + Utils.tryWithSafeFinally { val out: OutputStream = { if (compress) { compressionCodec.compressedOutputStream(fileOutputStream) @@ -175,10 +175,13 @@ private[broadcast] object HttpBroadcast extends Logging { } val ser = SparkEnv.get.serializer.newInstance() val serOut = ser.serializeStream(out) - serOut.writeObject(value) - serOut.close() + Utils.tryWithSafeFinally { + serOut.writeObject(value) + } { + serOut.close() + } files += file - } finally { + } { fileOutputStream.close() } } @@ -212,9 +215,11 @@ private[broadcast] object HttpBroadcast extends Logging { } val ser = SparkEnv.get.serializer.newInstance() val serIn = ser.deserializeStream(in) - val obj = serIn.readObject[T]() - serIn.close() - obj + Utils.tryWithSafeFinally { + serIn.readObject[T]() + } { + serIn.close() + } } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala index 3d0d68de8f495..ae99432f5ce86 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala @@ -17,15 +17,18 @@ package 
org.apache.spark.deploy +import java.net.URI + private[spark] class ApplicationDescription( val name: String, val maxCores: Option[Int], - val memoryPerSlave: Int, + val memoryPerExecutorMB: Int, val command: Command, var appUiUrl: String, - val eventLogDir: Option[String] = None, + val eventLogDir: Option[URI] = None, // short name of compression codec used when writing event logs, if any (e.g. lzf) - val eventLogCodec: Option[String] = None) + val eventLogCodec: Option[String] = None, + val coresPerExecutor: Option[Int] = None) extends Serializable { val user = System.getProperty("user.name", "") @@ -33,13 +36,13 @@ private[spark] class ApplicationDescription( def copy( name: String = name, maxCores: Option[Int] = maxCores, - memoryPerSlave: Int = memoryPerSlave, + memoryPerExecutorMB: Int = memoryPerExecutorMB, command: Command = command, appUiUrl: String = appUiUrl, - eventLogDir: Option[String] = eventLogDir, + eventLogDir: Option[URI] = eventLogDir, eventLogCodec: Option[String] = eventLogCodec): ApplicationDescription = new ApplicationDescription( - name, maxCores, memoryPerSlave, command, appUiUrl, eventLogDir, eventLogCodec) + name, maxCores, memoryPerExecutorMB, command, appUiUrl, eventLogDir, eventLogCodec) override def toString: String = "ApplicationDescription(" + name + ")" } diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 65238af2caa24..8d13b2a2cd4f3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -89,7 +89,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) /* Find out driver status then exit the JVM */ def pollAndReportStatus(driverId: String) { - println(s"... waiting before polling master for driver state") + println("... waiting before polling master for driver state") Thread.sleep(5000) println("... polling master for driver state") val statusFuture = (masterActor ? 
RequestDriverStatus(driverId))(timeout) diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index dfc5b97e6a6c8..2954f932b4f41 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -46,7 +46,7 @@ private[deploy] object JsonProtocol { ("name" -> obj.desc.name) ~ ("cores" -> obj.desc.maxCores) ~ ("user" -> obj.desc.user) ~ - ("memoryperslave" -> obj.desc.memoryPerSlave) ~ + ("memoryperslave" -> obj.desc.memoryPerExecutorMB) ~ ("submitdate" -> obj.submitDate.toString) ~ ("state" -> obj.state.toString) ~ ("duration" -> obj.duration) @@ -55,7 +55,7 @@ private[deploy] object JsonProtocol { def writeApplicationDescription(obj: ApplicationDescription): JObject = { ("name" -> obj.name) ~ ("cores" -> obj.maxCores) ~ - ("memoryperslave" -> obj.memoryPerSlave) ~ + ("memoryperslave" -> obj.memoryPerExecutorMB) ~ ("user" -> obj.user) ~ ("command" -> obj.command.toString) } diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 3ab425aab84c8..f0e77c2ba982b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -53,7 +53,7 @@ class LocalSparkCluster( /* Start the Master */ val (masterSystem, masterPort, _, _) = Master.startSystemAndActor(localHostname, 0, 0, _conf) masterActorSystems += masterSystem - val masterUrl = "spark://" + localHostname + ":" + masterPort + val masterUrl = "spark://" + Utils.localHostNameForURI() + ":" + masterPort val masters = Array(masterUrl) /* Start the Workers */ diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala new file mode 100644 index 0000000000000..e99779f299785 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.io._ +import java.util.concurrent.{Semaphore, TimeUnit} + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.fs.Path + +import org.apache.spark.api.r.RBackend +import org.apache.spark.util.RedirectThread + +/** + * Main class used to launch SparkR applications using spark-submit. It executes R as a + * subprocess and then has it connect back to the JVM to access system properties etc. 
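+ * The port of the backend server is passed to the R process through the
+ * EXISTING_SPARKR_BACKEND_PORT environment variable.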
+ */ +object RRunner { + def main(args: Array[String]): Unit = { + val rFile = PythonRunner.formatPath(args(0)) + + val otherArgs = args.slice(1, args.length) + + // Time to wait for SparkR backend to initialize in seconds + val backendTimeout = sys.env.getOrElse("SPARKR_BACKEND_TIMEOUT", "120").toInt + val rCommand = "Rscript" + + // Check if the file path exists. + // If not, change directory to current working directory for YARN cluster mode + val rF = new File(rFile) + val rFileNormalized = if (!rF.exists()) { + new Path(rFile).getName + } else { + rFile + } + + // Launch a SparkR backend server for the R process to connect to; this will let it see our + // Java system properties etc. + val sparkRBackend = new RBackend() + @volatile var sparkRBackendPort = 0 + val initialized = new Semaphore(0) + val sparkRBackendThread = new Thread("SparkR backend") { + override def run() { + sparkRBackendPort = sparkRBackend.init() + initialized.release() + sparkRBackend.run() + } + } + + sparkRBackendThread.start() + // Wait for RBackend initialization to finish + if (initialized.tryAcquire(backendTimeout, TimeUnit.SECONDS)) { + // Launch R + val returnCode = try { + val builder = new ProcessBuilder(Seq(rCommand, rFileNormalized) ++ otherArgs) + val env = builder.environment() + env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString) + val sparkHome = System.getenv("SPARK_HOME") + env.put("R_PROFILE_USER", + Seq(sparkHome, "R", "lib", "SparkR", "profile", "general.R").mkString(File.separator)) + builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize + val process = builder.start() + + new RedirectThread(process.getInputStream, System.out, "redirect R output").start() + + process.waitFor() + } finally { + sparkRBackend.close() + } + System.exit(returnCode) + } else { + System.err.println("SparkR backend did not initialize in " + backendTimeout + " seconds") + System.exit(-1) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index c2568eb4b60ac..cfaebf9ea5050 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -24,11 +24,10 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.fs.FileSystem.Statistics import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} -import org.apache.hadoop.security.Credentials -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.mapreduce.JobContext +import org.apache.hadoop.security.{Credentials, UserGroupInformation} -import org.apache.spark.{Logging, SparkContext, SparkConf, SparkException} +import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.Utils @@ -201,6 +200,37 @@ class SparkHadoopUtil extends Logging { val baseStatus = fs.getFileStatus(basePath) if (baseStatus.isDir) recurse(basePath) else Array(baseStatus) } + + private val HADOOP_CONF_PATTERN = "(\\$\\{hadoopconf-[^\\}\\$\\s]+\\})".r.unanchored + + /** + * Substitute variables by looking them up in Hadoop configs. Only variables that match the + * ${hadoopconf- .. } pattern are substituted. 
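+ * For example, "${hadoopconf-yarn.resourcemanager.hostname}" is replaced with the value of
+ * "yarn.resourcemanager.hostname" from the given Hadoop configuration. If the key is not set,
+ * the text is returned unchanged.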
+ */ + def substituteHadoopVariables(text: String, hadoopConf: Configuration): String = { + text match { + case HADOOP_CONF_PATTERN(matched) => { + logDebug(text + " matched " + HADOOP_CONF_PATTERN) + val key = matched.substring(13, matched.length() - 1) // remove ${hadoopconf- .. } + val eval = Option[String](hadoopConf.get(key)) + .map { value => + logDebug("Substituted " + matched + " with " + value) + text.replace(matched, value) + } + if (eval.isEmpty) { + // The variable was not found in Hadoop configs, so return text as is. + text + } else { + // Continue to substitute more variables. + substituteHadoopVariables(eval.get, hadoopConf) + } + } + case _ => { + logDebug(text + " didn't match " + HADOOP_CONF_PATTERN) + text + } + } + } } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 660307d19eab4..296a0764b8baf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -77,6 +77,7 @@ object SparkSubmit { // Special primary resource names that represent shells rather than application jars. private val SPARK_SHELL = "spark-shell" private val PYSPARK_SHELL = "pyspark-shell" + private val SPARKR_SHELL = "sparkr-shell" private val CLASS_NOT_FOUND_EXIT_STATUS = 101 @@ -284,6 +285,13 @@ object SparkSubmit { } } + // Require all R files to be local + if (args.isR && !isYarnCluster) { + if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { + printErrorAndExit(s"Only local R files are supported: $args.primaryResource") + } + } + // The following modes are not supported or applicable (clusterManager, deployMode) match { case (MESOS, CLUSTER) => @@ -291,6 +299,9 @@ object SparkSubmit { case (STANDALONE, CLUSTER) if args.isPython => printErrorAndExit("Cluster deploy mode is currently not supported for python " + "applications on standalone clusters.") + case (STANDALONE, CLUSTER) if args.isR => + printErrorAndExit("Cluster deploy mode is currently not supported for R " + + "applications on standalone clusters.") case (_, CLUSTER) if isShell(args.primaryResource) => printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.") case (_, CLUSTER) if isSqlShell(args.mainClass) => @@ -317,11 +328,32 @@ object SparkSubmit { } } - // In yarn-cluster mode for a python app, add primary resource and pyFiles to files - // that can be distributed with the job - if (args.isPython && isYarnCluster) { - args.files = mergeFileLists(args.files, args.primaryResource) - args.files = mergeFileLists(args.files, args.pyFiles) + // If we're running a R app, set the main class to our specific R runner + if (args.isR && deployMode == CLIENT) { + if (args.primaryResource == SPARKR_SHELL) { + args.mainClass = "org.apache.spark.api.r.RBackend" + } else { + // If a R file is provided, add it to the child arguments and list of files to deploy. + // Usage: RRunner
[app arguments] + args.mainClass = "org.apache.spark.deploy.RRunner" + args.childArgs = ArrayBuffer(args.primaryResource) ++ args.childArgs + args.files = mergeFileLists(args.files, args.primaryResource) + } + } + + if (isYarnCluster) { + // In yarn-cluster mode for a python app, add primary resource and pyFiles to files + // that can be distributed with the job + if (args.isPython) { + args.files = mergeFileLists(args.files, args.primaryResource) + args.files = mergeFileLists(args.files, args.pyFiles) + } + + // In yarn-cluster mode for a R app, add primary resource to files + // that can be distributed with the job + if (args.isR) { + args.files = mergeFileLists(args.files, args.primaryResource) + } } // Special flag to avoid deprecation warnings at the client @@ -374,6 +406,8 @@ object SparkSubmit { OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"), // Other options + OptionAssigner(args.executorCores, STANDALONE, ALL_DEPLOY_MODES, + sysProp = "spark.executor.cores"), OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.memory"), OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, @@ -405,8 +439,8 @@ object SparkSubmit { // Add the application jar automatically so the user doesn't have to call sc.addJar // For YARN cluster mode, the jar is already distributed on each node as "app.jar" - // For python files, the primary resource is already distributed as a regular file - if (!isYarnCluster && !args.isPython) { + // For python and R files, the primary resource is already distributed as a regular file + if (!isYarnCluster && !args.isPython && !args.isR) { var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty) if (isUserJar(args.primaryResource)) { jars = jars ++ Seq(args.primaryResource) @@ -447,6 +481,10 @@ object SparkSubmit { childArgs += ("--py-files", pyFilesNames) } childArgs += ("--class", "org.apache.spark.deploy.PythonRunner") + } else if (args.isR) { + val mainFile = new Path(args.primaryResource).getName + childArgs += ("--primary-r-file", mainFile) + childArgs += ("--class", "org.apache.spark.deploy.RRunner") } else { if (args.primaryResource != SPARK_INTERNAL) { childArgs += ("--jar", args.primaryResource) @@ -591,15 +629,15 @@ object SparkSubmit { /** * Return whether the given primary resource represents a user jar. */ - private def isUserJar(primaryResource: String): Boolean = { - !isShell(primaryResource) && !isPython(primaryResource) && !isInternal(primaryResource) + private[deploy] def isUserJar(res: String): Boolean = { + !isShell(res) && !isPython(res) && !isInternal(res) && !isR(res) } /** * Return whether the given primary resource represents a shell. */ - private[deploy] def isShell(primaryResource: String): Boolean = { - primaryResource == SPARK_SHELL || primaryResource == PYSPARK_SHELL + private[deploy] def isShell(res: String): Boolean = { + (res == SPARK_SHELL || res == PYSPARK_SHELL || res == SPARKR_SHELL) } /** @@ -619,12 +657,19 @@ object SparkSubmit { /** * Return whether the given primary resource requires running python. */ - private[deploy] def isPython(primaryResource: String): Boolean = { - primaryResource.endsWith(".py") || primaryResource == PYSPARK_SHELL + private[deploy] def isPython(res: String): Boolean = { + res != null && res.endsWith(".py") || res == PYSPARK_SHELL + } + + /** + * Return whether the given primary resource requires running R. 
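+ * This is true for resources ending in ".R" and for the special "sparkr-shell" resource.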
+ */ + private[deploy] def isR(res: String): Boolean = { + res != null && res.endsWith(".R") || res == SPARKR_SHELL } - private[deploy] def isInternal(primaryResource: String): Boolean = { - primaryResource == SPARK_INTERNAL + private[deploy] def isInternal(res: String): Boolean = { + res == SPARK_INTERNAL } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 6eb73c43470a5..faa8780288ea3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -59,6 +59,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var verbose: Boolean = false var isPython: Boolean = false var pyFiles: String = null + var isR: Boolean = false var action: SparkSubmitAction = null val sparkProperties: HashMap[String, String] = new HashMap[String, String]() var proxyUser: String = null @@ -158,7 +159,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S .getOrElse(sparkProperties.get("spark.executor.instances").orNull) // Try to set main class from JAR if no --class argument is given - if (mainClass == null && !isPython && primaryResource != null) { + if (mainClass == null && !isPython && !isR && primaryResource != null) { val uri = new URI(primaryResource) val uriScheme = uri.getScheme() @@ -211,9 +212,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S printUsageAndExit(-1) } if (primaryResource == null) { - SparkSubmit.printErrorAndExit("Must specify a primary resource (JAR or Python file)") + SparkSubmit.printErrorAndExit("Must specify a primary resource (JAR or Python or R file)") } - if (mainClass == null && !isPython) { + if (mainClass == null && SparkSubmit.isUserJar(primaryResource)) { SparkSubmit.printErrorAndExit("No main class set in JAR; please specify one with --class") } if (pyFiles != null && !isPython) { @@ -414,6 +415,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S opt } isPython = SparkSubmit.isPython(opt) + isR = SparkSubmit.isR(opt) false } @@ -480,10 +482,13 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | Spark standalone and Mesos only: | --total-executor-cores NUM Total cores for all executors. | + | Spark standalone and YARN only: + | --executor-cores NUM Number of cores per executor. (Default: 1 in YARN mode, + | or all available cores on the worker in standalone mode) + | | YARN-only: | --driver-cores NUM Number of cores used by the driver, only in cluster mode | (Default: 1). - | --executor-cores NUM Number of cores per executor (Default: 1). | --queue QUEUE_NAME The YARN queue to submit to (Default: "default"). | --num-executors NUM Number of executors to launch (Default: 2). 
| --archives ARCHIVES Comma separated list of archives to be extracted into the diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index c1c4812f17fbe..40835b9550586 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -46,7 +46,7 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) val conf = new SparkConf - val (actorSystem, _) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0, + val (actorSystem, _) = AkkaUtils.createActorSystem("spark", Utils.localHostName(), 0, conf = conf, securityManager = new SecurityManager(conf)) val desc = new ApplicationDescription("TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), Seq(), Seq()), "ignored") diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 80c9c13ddec1e..9d40d8c8fd7a8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -118,7 +118,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis if (!fs.exists(path)) { var msg = s"Log directory specified does not exist: $logDir." if (logDir == DEFAULT_LOG_DIR) { - msg += " Did you configure the correct one through spark.fs.history.logDirectory?" + msg += " Did you configure the correct one through spark.history.fs.logDirectory?" } throw new IllegalArgumentException(msg) } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 6e432d63c6b5a..3781b4e8c12bd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -90,6 +90,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") ++ appTable + } else if (requestedIncomplete) { +

          <h4>No incomplete applications found!</h4>
        } else {
          <h4>No completed applications found!</h4> ++

Did you specify the correct logging directory? diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index bc5b293379f2b..f59d550d4f3b3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -75,9 +75,11 @@ private[deploy] class ApplicationInfo( } } - private[master] def addExecutor(worker: WorkerInfo, cores: Int, useID: Option[Int] = None): - ExecutorDesc = { - val exec = new ExecutorDesc(newExecutorId(useID), this, worker, cores, desc.memoryPerSlave) + private[master] def addExecutor( + worker: WorkerInfo, + cores: Int, + useID: Option[Int] = None): ExecutorDesc = { + val exec = new ExecutorDesc(newExecutorId(useID), this, worker, cores, desc.memoryPerExecutorMB) executors(exec.id) = exec coresGranted += cores exec diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index 32499b3a784a1..f459ed5b3a1a1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -24,6 +24,7 @@ import scala.reflect.ClassTag import akka.serialization.Serialization import org.apache.spark.Logging +import org.apache.spark.util.Utils /** @@ -59,9 +60,9 @@ private[master] class FileSystemPersistenceEngine( val serializer = serialization.findSerializerFor(value) val serialized = serializer.toBinary(value) val out = new FileOutputStream(file) - try { + Utils.tryWithSafeFinally { out.write(serialized) - } finally { + } { out.close() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 9a5d5877da86d..c5a6b1beac9be 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -524,52 +524,28 @@ private[master] class Master( } /** - * Can an app use the given worker? True if the worker has enough memory and we haven't already - * launched an executor for the app on it (right now the standalone backend doesn't like having - * two executors on the same worker). - */ - private def canUse(app: ApplicationInfo, worker: WorkerInfo): Boolean = { - worker.memoryFree >= app.desc.memoryPerSlave && !worker.hasExecutor(app) - } - - /** - * Schedule the currently available resources among waiting apps. This method will be called - * every time a new app joins or resource availability changes. + * Schedule executors to be launched on the workers. + * + * There are two modes of launching executors. The first attempts to spread out an application's + * executors on as many workers as possible, while the second does the opposite (i.e. launch them + * on as few workers as possible). The former is usually better for data locality purposes and is + * the default. + * + * The number of cores assigned to each executor is configurable. When this is explicitly set, + * multiple executors from the same application may be launched on the same worker if the worker + * has enough cores and memory. Otherwise, each executor grabs all the cores available on the + * worker by default, in which case only one executor may be launched on each worker. 
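To make the "spread out" mode described above concrete, here is a small standalone sketch of the round-robin core assignment it performs before the assigned cores are carved into executors. The `assignCores` helper and the worker capacities are made-up for illustration; this is not the Master's real code path.

```
// Standalone sketch of the "spread out" assignment: cores are handed out one at a
// time, round-robin, across usable workers, capped by each worker's free cores.
object SpreadOutSketch {
  def assignCores(coresNeeded: Int, workerFreeCores: Array[Int]): Array[Int] = {
    val assigned = new Array[Int](workerFreeCores.length)
    var toAssign = math.min(coresNeeded, workerFreeCores.sum)
    var pos = 0
    while (toAssign > 0) {
      if (workerFreeCores(pos) - assigned(pos) > 0) {
        assigned(pos) += 1
        toAssign -= 1
      }
      pos = (pos + 1) % workerFreeCores.length
    }
    assigned
  }

  def main(args: Array[String]): Unit = {
    // Three workers with 8, 4 and 2 free cores; the app still needs 10 cores.
    println(assignCores(10, Array(8, 4, 2)).toList) // List(4, 4, 2)
  }
}
```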
*/ - private def schedule() { - if (state != RecoveryState.ALIVE) { return } - - // First schedule drivers, they take strict precedence over applications - // Randomization helps balance drivers - val shuffledAliveWorkers = Random.shuffle(workers.toSeq.filter(_.state == WorkerState.ALIVE)) - val numWorkersAlive = shuffledAliveWorkers.size - var curPos = 0 - - for (driver <- waitingDrivers.toList) { // iterate over a copy of waitingDrivers - // We assign workers to each waiting driver in a round-robin fashion. For each driver, we - // start from the last worker that was assigned a driver, and continue onwards until we have - // explored all alive workers. - var launched = false - var numWorkersVisited = 0 - while (numWorkersVisited < numWorkersAlive && !launched) { - val worker = shuffledAliveWorkers(curPos) - numWorkersVisited += 1 - if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) { - launchDriver(worker, driver) - waitingDrivers -= driver - launched = true - } - curPos = (curPos + 1) % numWorkersAlive - } - } - + private def startExecutorsOnWorkers(): Unit = { // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app // in the queue, then the second app, etc. if (spreadOutApps) { - // Try to spread out each app among all the nodes, until it has all its cores + // Try to spread out each app among all the workers, until it has all its cores for (app <- waitingApps if app.coresLeft > 0) { val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) - .filter(canUse(app, _)).sortBy(_.coresFree).reverse + .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB && + worker.coresFree >= app.desc.coresPerExecutor.getOrElse(1)) + .sortBy(_.coresFree).reverse val numUsable = usableWorkers.length val assigned = new Array[Int](numUsable) // Number of cores to give on each node var toAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum) @@ -582,32 +558,61 @@ private[master] class Master( pos = (pos + 1) % numUsable } // Now that we've decided how many cores to give on each node, let's actually give them - for (pos <- 0 until numUsable) { - if (assigned(pos) > 0) { - val exec = app.addExecutor(usableWorkers(pos), assigned(pos)) - launchExecutor(usableWorkers(pos), exec) - app.state = ApplicationState.RUNNING - } + for (pos <- 0 until numUsable if assigned(pos) > 0) { + allocateWorkerResourceToExecutors(app, assigned(pos), usableWorkers(pos)) } } } else { - // Pack each app into as few nodes as possible until we've assigned all its cores + // Pack each app into as few workers as possible until we've assigned all its cores for (worker <- workers if worker.coresFree > 0 && worker.state == WorkerState.ALIVE) { for (app <- waitingApps if app.coresLeft > 0) { - if (canUse(app, worker)) { - val coresToUse = math.min(worker.coresFree, app.coresLeft) - if (coresToUse > 0) { - val exec = app.addExecutor(worker, coresToUse) - launchExecutor(worker, exec) - app.state = ApplicationState.RUNNING - } - } + allocateWorkerResourceToExecutors(app, app.coresLeft, worker) + } + } + } + } + + /** + * Allocate a worker's resources to one or more executors. 
+ * @param app the info of the application which the executors belong to + * @param coresToAllocate cores on this worker to be allocated to this application + * @param worker the worker info + */ + private def allocateWorkerResourceToExecutors( + app: ApplicationInfo, + coresToAllocate: Int, + worker: WorkerInfo): Unit = { + val memoryPerExecutor = app.desc.memoryPerExecutorMB + val coresPerExecutor = app.desc.coresPerExecutor.getOrElse(coresToAllocate) + var coresLeft = coresToAllocate + while (coresLeft >= coresPerExecutor && worker.memoryFree >= memoryPerExecutor) { + val exec = app.addExecutor(worker, coresPerExecutor) + coresLeft -= coresPerExecutor + launchExecutor(worker, exec) + app.state = ApplicationState.RUNNING + } + } + + /** + * Schedule the currently available resources among waiting apps. This method will be called + * every time a new app joins or resource availability changes. + */ + private def schedule(): Unit = { + if (state != RecoveryState.ALIVE) { return } + // Drivers take strict precedence over executors + val shuffledWorkers = Random.shuffle(workers) // Randomization helps balance drivers + for (worker <- shuffledWorkers if worker.state == WorkerState.ALIVE) { + for (driver <- waitingDrivers) { + if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) { + launchDriver(worker, driver) + waitingDrivers -= driver } } } + startExecutorsOnWorkers() } - private def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc) { + private def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc): Unit = { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) worker.actor ! LaunchExecutor(masterUrl, diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 761aa8f7b1ef6..273f077bd8f57 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -94,7 +94,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")

            <li>
              <strong>Executor Memory:</strong>
-             {Utils.megabytesToString(app.desc.memoryPerSlave)}
+             {Utils.megabytesToString(app.desc.memoryPerExecutorMB)}
            </li>
            <li><strong>Submit Date:</strong> {app.submitDate}</li>
            <li><strong>State:</strong> {app.state}</li>
  • diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 45412a35e9a7d..399f07399a0aa 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -208,8 +208,8 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {app.coresGranted} - - {Utils.megabytesToString(app.desc.memoryPerSlave)} + + {Utils.megabytesToString(app.desc.memoryPerExecutorMB)} {UIUtils.formatDate(app.submitDate)} {app.desc.user} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala index 420442f7564cc..b8fd406fb6f9a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -27,6 +27,7 @@ import com.fasterxml.jackson.core.JsonProcessingException import com.google.common.base.Charsets import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} +import org.apache.spark.util.Utils /** * A client that submits applications to the standalone Master using a REST protocol. @@ -148,8 +149,11 @@ private[deploy] class StandaloneRestClient extends Logging { conn.setRequestProperty("charset", "utf-8") conn.setDoOutput(true) val out = new DataOutputStream(conn.getOutputStream) - out.write(json.getBytes(Charsets.UTF_8)) - out.close() + Utils.tryWithSafeFinally { + out.write(json.getBytes(Charsets.UTF_8)) + } { + out.close() + } readResponse(conn) } @@ -241,7 +245,7 @@ private[deploy] class StandaloneRestClient extends Logging { } } else { val failMessage = Option(submitResponse.message).map { ": " + _ }.getOrElse("") - logError("Application submission failed" + failMessage) + logError(s"Application submission failed$failMessage") } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index e0948e16ef354..ef7a703bffe67 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -24,14 +24,14 @@ import scala.collection.JavaConversions._ import akka.actor.ActorRef import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files -import org.apache.hadoop.fs.{FileUtil, Path} +import org.apache.hadoop.fs.Path -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{Logging, SparkConf, SecurityManager} import org.apache.spark.deploy.{DriverDescription, SparkHadoopUtil} import org.apache.spark.deploy.DeployMessages.DriverStateChanged import org.apache.spark.deploy.master.DriverState import org.apache.spark.deploy.master.DriverState.DriverState -import org.apache.spark.util.{Clock, SystemClock} +import org.apache.spark.util.{Utils, Clock, SystemClock} /** * Manages the execution of one driver, including automatically restarting the driver on failure. 
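Several hunks in this patch (the persistence engine and REST client above, the checkpoint and pair-RDD writers further down) replace plain `try`/`finally` with `Utils.tryWithSafeFinally`. A minimal re-implementation of that idea, assuming only the standard library, might look like the sketch below; it is not Spark's own `Utils` method, and `payload.bin` is a placeholder path.

```
import java.io.{DataOutputStream, FileOutputStream}

// Sketch of the try-with-safe-finally idea: if both the body and the close block
// throw, keep the body's exception (attaching the close failure as suppressed)
// instead of losing it.
object SafeFinallySketch {
  def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = {
    var cause: Throwable = null
    try {
      block
    } catch {
      case t: Throwable =>
        cause = t
        throw t
    } finally {
      if (cause != null) {
        try finallyBlock catch { case t: Throwable => cause.addSuppressed(t) }
      } else {
        finallyBlock
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val out = new DataOutputStream(new FileOutputStream("payload.bin"))
    tryWithSafeFinally {
      out.write("hello".getBytes("UTF-8"))
    } {
      out.close()
    }
  }
}
```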
@@ -44,7 +44,8 @@ private[deploy] class DriverRunner( val sparkHome: File, val driverDesc: DriverDescription, val worker: ActorRef, - val workerUrl: String) + val workerUrl: String, + val securityManager: SecurityManager) extends Logging { @volatile private var process: Option[Process] = None @@ -136,12 +137,9 @@ private[deploy] class DriverRunner( * Will throw an exception if there are errors downloading the jar. */ private def downloadUserJar(driverDir: File): String = { - val jarPath = new Path(driverDesc.jarUrl) val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - val jarFileSystem = jarPath.getFileSystem(hadoopConf) - val destPath = new File(driverDir.getAbsolutePath, jarPath.getName) val jarFileName = jarPath.getName val localJarFile = new File(driverDir, jarFileName) @@ -149,7 +147,14 @@ private[deploy] class DriverRunner( if (!localJarFile.exists()) { // May already exist if running multiple workers on one node logInfo(s"Copying user jar $jarPath to $destPath") - FileUtil.copy(jarFileSystem, jarPath, destPath, false, hadoopConf) + Utils.fetchFile( + driverDesc.jarUrl, + driverDir, + conf, + securityManager, + hadoopConf, + System.currentTimeMillis(), + useCache = false) } if (!localJarFile.exists()) { // Verify copy succeeded diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index deef6ef9043c6..d1a12b01e78f7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -19,10 +19,9 @@ package org.apache.spark.deploy.worker import java.io.File -import akka.actor._ - import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.util.{AkkaUtils, ChildFirstURLClassLoader, MutableURLClassLoader, Utils} +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} /** * Utility object for launching driver programs such that they share fate with the Worker process. 
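As a rough illustration of the fate-sharing idea mentioned above (a child process exiting as soon as its link to the parent drops), the sketch below watches a plain socket. The host and port are placeholders, and this is deliberately simpler than how `WorkerWatcher` is actually wired up over the RPC layer.

```
import java.io.IOException
import java.net.Socket

// Generic fate-sharing sketch: keep a connection open to the parent process and
// exit as soon as it is lost, so orphaned children do not linger.
object FateSharingSketch {
  def watch(host: String, port: Int): Unit = {
    val socket = new Socket(host, port)
    val in = socket.getInputStream
    try {
      // read() returns -1 (or throws) once the parent goes away
      while (in.read() != -1) {}
    } catch { case _: IOException => }
    System.exit(-1)
  }
}
```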
@@ -39,9 +38,9 @@ object DriverWrapper { */ case workerUrl :: userJar :: mainClass :: extraArgs => val conf = new SparkConf() - val (actorSystem, _) = AkkaUtils.createActorSystem("Driver", + val rpcEnv = RpcEnv.create("Driver", Utils.localHostName(), 0, conf, new SecurityManager(conf)) - actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher") + rpcEnv.setupEndpoint("workerWatcher", new WorkerWatcher(rpcEnv, workerUrl)) val currentLoader = Thread.currentThread.getContextClassLoader val userJarUrl = new File(userJar).toURI().toURL() @@ -58,7 +57,7 @@ object DriverWrapper { val mainMethod = clazz.getMethod("main", classOf[Array[String]]) mainMethod.invoke(null, extraArgs.toArray[String]) - actorSystem.shutdown() + rpcEnv.shutdown() case _ => System.err.println("Usage: DriverWrapper [options]") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 83e24a7a1f80c..7d5acabb95a48 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -50,7 +50,7 @@ private[deploy] class ExecutorRunner( val workerUrl: String, conf: SparkConf, val appLocalDirs: Seq[String], - var state: ExecutorState.Value) + @volatile var state: ExecutorState.Value) extends Logging { private val fullId = appId + "/" + execId diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index c4c24a7866aa3..3ee2eb69e8a4e 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -436,7 +436,8 @@ private[worker] class Worker( sparkHome, driverDesc.copy(command = Worker.maybeUpdateSSLSettings(driverDesc.command, conf)), self, - akkaUrl) + akkaUrl, + securityMgr) drivers(driverId) = driver driver.start() diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index e0790274d7d3e..83fb991891a41 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -17,58 +17,63 @@ package org.apache.spark.deploy.worker -import akka.actor.{Actor, Address, AddressFromURIString} -import akka.remote.{AssociatedEvent, AssociationErrorEvent, AssociationEvent, DisassociatedEvent, RemotingLifecycleEvent} - import org.apache.spark.Logging import org.apache.spark.deploy.DeployMessages.SendHeartbeat -import org.apache.spark.util.ActorLogReceive +import org.apache.spark.rpc._ /** * Actor which connects to a worker process and terminates the JVM if the connection is severed. * Provides fate sharing between a worker and its associated child processes. */ -private[spark] class WorkerWatcher(workerUrl: String) - extends Actor with ActorLogReceive with Logging { - - override def preStart() { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) +private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String) + extends RpcEndpoint with Logging { + override def onStart() { logInfo(s"Connecting to worker $workerUrl") - val worker = context.actorSelection(workerUrl) - worker ! 
SendHeartbeat // need to send a message here to initiate connection + if (!isTesting) { + rpcEnv.asyncSetupEndpointRefByURI(workerUrl) + } } // Used to avoid shutting down JVM during tests + // In the normal case, exitNonZero will call `System.exit(-1)` to shutdown the JVM. In the unit + // test, the user should call `setTesting(true)` so that `exitNonZero` will set `isShutDown` to + // true rather than calling `System.exit`. The user can check `isShutDown` to know if + // `exitNonZero` is called. private[deploy] var isShutDown = false private[deploy] def setTesting(testing: Boolean) = isTesting = testing private var isTesting = false // Lets us filter events only from the worker's actor system - private val expectedHostPort = AddressFromURIString(workerUrl).hostPort - private def isWorker(address: Address) = address.hostPort == expectedHostPort + private val expectedAddress = RpcAddress.fromURIString(workerUrl) + private def isWorker(address: RpcAddress) = expectedAddress == address private def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => - logInfo(s"Successfully connected to $workerUrl") + override def receive: PartialFunction[Any, Unit] = { + case e => logWarning(s"Received unexpected message: $e") + } - case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) - if isWorker(remoteAddress) => - // These logs may not be seen if the worker (and associated pipe) has died - logError(s"Could not initialize connection to worker $workerUrl. Exiting.") - logError(s"Error was: $cause") - exitNonZero() + override def onConnected(remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { + logInfo(s"Successfully connected to $workerUrl") + } + } - case DisassociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { // This log message will never be seen logError(s"Lost connection to worker actor $workerUrl. Exiting.") exitNonZero() + } + } - case e: AssociationEvent => - // pass through association events relating to other remote actor systems - - case e => logWarning(s"Received unexpected actor system event: $e") + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { + // These logs may not be seen if the worker (and associated pipe) has died + logError(s"Could not initialize connection to worker $workerUrl. 
Exiting.") + logError(s"Error was: $cause") + exitNonZero() + } } } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index b5205d4e997ae..8300f9f2190b9 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -21,39 +21,45 @@ import java.net.URL import java.nio.ByteBuffer import scala.collection.mutable -import scala.concurrent.Await +import scala.util.{Failure, Success} -import akka.actor.{Actor, ActorSelection, Props} -import akka.pattern.Patterns -import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent} - -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv} +import org.apache.spark.rpc._ +import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{SignalLogger, Utils} private[spark] class CoarseGrainedExecutorBackend( + override val rpcEnv: RpcEnv, driverUrl: String, executorId: String, hostPort: String, cores: Int, userClassPath: Seq[URL], env: SparkEnv) - extends Actor with ActorLogReceive with ExecutorBackend with Logging { + extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging { Utils.checkHostPort(hostPort, "Expected hostport") var executor: Executor = null - var driver: ActorSelection = null + @volatile var driver: Option[RpcEndpointRef] = None - override def preStart() { + override def onStart() { + import scala.concurrent.ExecutionContext.Implicits.global logInfo("Connecting to driver: " + driverUrl) - driver = context.actorSelection(driverUrl) - driver ! RegisterExecutor(executorId, hostPort, cores, extractLogUrls) - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => + driver = Some(ref) + ref.sendWithReply[RegisteredExecutor.type]( + RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls)) + } onComplete { + case Success(msg) => Utils.tryLogNonFatalError { + Option(self).foreach(_.send(msg)) // msg must be RegisteredExecutor + } + case Failure(e) => logError(s"Cannot register with driver: $driverUrl", e) + } } def extractLogUrls: Map[String, String] = { @@ -62,7 +68,7 @@ private[spark] class CoarseGrainedExecutorBackend( .map(e => (e._1.substring(prefix.length).toLowerCase, e._2)) } - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receive: PartialFunction[Any, Unit] = { case RegisteredExecutor => logInfo("Successfully registered with driver") val (hostname, _) = Utils.parseHostPort(hostPort) @@ -92,23 +98,28 @@ private[spark] class CoarseGrainedExecutorBackend( executor.killTask(taskId, interruptThread) } - case x: DisassociatedEvent => - if (x.remoteAddress == driver.anchorPath.address) { - logError(s"Driver $x disassociated! 
Shutting down.") - System.exit(1) - } else { - logWarning(s"Received irrelevant DisassociatedEvent $x") - } - case StopExecutor => logInfo("Driver commanded a shutdown") executor.stop() - context.stop(self) - context.system.shutdown() + stop() + rpcEnv.shutdown() + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (driver.exists(_.address == remoteAddress)) { + logError(s"Driver $remoteAddress disassociated! Shutting down.") + System.exit(1) + } else { + logWarning(s"An unknown ($remoteAddress) driver disconnected.") + } } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { - driver ! StatusUpdate(executorId, taskId, state, data) + val msg = StatusUpdate(executorId, taskId, state, data) + driver match { + case Some(driverRef) => driverRef.send(msg) + case None => logWarning(s"Drop $msg because has not yet connected to driver") + } } } @@ -132,16 +143,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Bootstrap to fetch the driver's Spark properties. val executorConf = new SparkConf val port = executorConf.getInt("spark.executor.port", 0) - val (fetcher, _) = AkkaUtils.createActorSystem( + val fetcher = RpcEnv.create( "driverPropsFetcher", hostname, port, executorConf, new SecurityManager(executorConf)) - val driver = fetcher.actorSelection(driverUrl) - val timeout = AkkaUtils.askTimeout(executorConf) - val fut = Patterns.ask(driver, RetrieveSparkProps, timeout) - val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] ++ + val driver = fetcher.setupEndpointRefByURI(driverUrl) + val props = driver.askWithReply[Seq[(String, String)]](RetrieveSparkProps) ++ Seq[(String, String)](("spark.app.id", appId)) fetcher.shutdown() @@ -162,16 +171,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val boundPort = env.conf.getInt("spark.executor.port", 0) assert(boundPort != 0) - // Start the CoarseGrainedExecutorBackend actor. + // Start the CoarseGrainedExecutorBackend endpoint. 
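The actor-to-RPC migration above boils down to a small lifecycle: register with the driver in `onStart`, handle messages in `receive`, and shut down in `onDisconnected`. The sketch below uses a hypothetical `EndpointSketch` trait as a stand-in for Spark's private rpc API, purely to illustrate that shape; message types and addresses are invented for the example.

```
// Stand-in endpoint lifecycle, illustrative only (not Spark's rpc package).
trait EndpointSketch {
  def onStart(): Unit = {}
  def receive: PartialFunction[Any, Unit]
  def onDisconnected(remoteAddress: String): Unit = {}
}

class BackendSketch(driverAddress: String) extends EndpointSketch {
  case object RegisteredWithDriver
  case class LaunchTask(taskId: Long)

  override def onStart(): Unit =
    println(s"connecting to driver at $driverAddress")

  override def receive: PartialFunction[Any, Unit] = {
    case RegisteredWithDriver => println("successfully registered with driver")
    case LaunchTask(id)       => println(s"running task $id")
    case other                => println(s"unexpected message: $other")
  }

  override def onDisconnected(remoteAddress: String): Unit =
    if (remoteAddress == driverAddress) {
      println("driver disconnected; exiting")
      // the real backend calls System.exit(1) here
    }
}
```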
val sparkHostPort = hostname + ":" + boundPort - env.actorSystem.actorOf( - Props(classOf[CoarseGrainedExecutorBackend], - driverUrl, executorId, sparkHostPort, cores, userClassPath, env), - name = "Executor") + env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( + env.rpcEnv, driverUrl, executorId, sparkHostPort, cores, userClassPath, env)) workerUrl.foreach { url => - env.actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher") + env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } - env.actorSystem.awaitTermination() + env.rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index bf3135ef081c1..1b5fdeba28ee2 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -21,14 +21,12 @@ import java.io.File import java.lang.management.ManagementFactory import java.net.URL import java.nio.ByteBuffer -import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.{ConcurrentHashMap, Executors, TimeUnit} import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal -import akka.actor.Props - import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task} @@ -62,8 +60,6 @@ private[spark] class Executor( private val conf = env.conf - @volatile private var isStopped = false - // No ip or host:port - just hostname Utils.checkHost(executorHostname, "Expected executed slave to be a hostname") // must not have port specified. @@ -88,9 +84,9 @@ private[spark] class Executor( env.blockManager.initialize(conf.getAppId) } - // Create an actor for receiving RPCs from the driver - private val executorActor = env.actorSystem.actorOf( - Props(new ExecutorActor(executorId)), "ExecutorActor") + // Create an RpcEndpoint for receiving RPCs from the driver + private val executorEndpoint = env.rpcEnv.setupEndpoint( + ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME, new ExecutorEndpoint(env.rpcEnv, executorId)) // Whether to load classes in user jars before those in Spark jars private val userClassPathFirst: Boolean = { @@ -116,6 +112,10 @@ private[spark] class Executor( // Maintains the list of running tasks. private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] + // Executor for the heartbeat task. + private val heartbeater = Executors.newSingleThreadScheduledExecutor( + Utils.namedThreadFactory("driver-heartbeater")) + startDriverHeartbeater() def launchTask( @@ -139,8 +139,9 @@ private[spark] class Executor( def stop(): Unit = { env.metricsSystem.report() - env.actorSystem.stop(executorActor) - isStopped = true + env.rpcEnv.stop(executorEndpoint) + heartbeater.shutdown() + heartbeater.awaitTermination(10, TimeUnit.SECONDS) threadPool.shutdown() if (!isLocal) { env.stop() @@ -391,11 +392,8 @@ private[spark] class Executor( } } - private val timeout = AkkaUtils.lookupTimeout(conf) - private val retryAttempts = AkkaUtils.numRetries(conf) - private val retryIntervalMs = AkkaUtils.retryWaitMs(conf) private val heartbeatReceiverRef = - AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem) + RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) /** Reports heartbeat and metrics for active tasks to the driver. 
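The heartbeat refactor in this hunk replaces a hand-rolled sleeping thread with a scheduled executor. A self-contained sketch of that pattern, assuming nothing beyond `java.util.concurrent` (the interval and callback are caller-supplied placeholders):

```
import java.util.concurrent.{Executors, ThreadFactory, TimeUnit}
import scala.util.control.NonFatal

// Heartbeats fired at a fixed rate from a single daemon thread, with a random
// initial delay so many executors do not heartbeat in lockstep.
object HeartbeatSketch {
  private val heartbeater = Executors.newSingleThreadScheduledExecutor(new ThreadFactory {
    override def newThread(r: Runnable): Thread = {
      val t = new Thread(r, "driver-heartbeater")
      t.setDaemon(true)
      t
    }
  })

  def start(intervalMs: Long)(sendHeartbeat: () => Unit): Unit = {
    val initialDelay = intervalMs + (math.random * intervalMs).toLong
    val task = new Runnable {
      override def run(): Unit =
        try sendHeartbeat() catch { case NonFatal(e) => e.printStackTrace() }
    }
    heartbeater.scheduleAtFixedRate(task, initialDelay, intervalMs, TimeUnit.MILLISECONDS)
  }

  def stop(): Unit = {
    heartbeater.shutdown()
    heartbeater.awaitTermination(10, TimeUnit.SECONDS)
  }
}
```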
*/ private def reportHeartBeat(): Unit = { @@ -426,8 +424,7 @@ private[spark] class Executor( val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) try { - val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef, - retryAttempts, retryIntervalMs, timeout) + val response = heartbeatReceiverRef.askWithReply[HeartbeatResponse](message) if (response.reregisterBlockManager) { logWarning("Told to re-register on heartbeat") env.blockManager.reregister() @@ -438,23 +435,17 @@ private[spark] class Executor( } /** - * Starts a thread to report heartbeat and partial metrics for active tasks to driver. - * This thread stops running when the executor is stopped. + * Schedules a task to report heartbeat and partial metrics for active tasks to driver. */ private def startDriverHeartbeater(): Unit = { - val interval = conf.getInt("spark.executor.heartbeatInterval", 10000) - val thread = new Thread() { - override def run() { - // Sleep a random interval so the heartbeats don't end up in sync - Thread.sleep(interval + (math.random * interval).asInstanceOf[Int]) - while (!isStopped) { - reportHeartBeat() - Thread.sleep(interval) - } - } + val intervalMs = conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s") + + // Wait a random interval so the heartbeats don't end up in sync + val initialDelay = intervalMs + (math.random * intervalMs).asInstanceOf[Int] + + val heartbeatTask = new Runnable() { + override def run(): Unit = Utils.logUncaughtExceptions(reportHeartBeat()) } - thread.setDaemon(true) - thread.setName("driver-heartbeater") - thread.start() + heartbeater.scheduleAtFixedRate(heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS) } } diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala similarity index 67% rename from core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala rename to core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala index 3e47d13f7545d..cf362f8464735 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala @@ -17,10 +17,8 @@ package org.apache.spark.executor -import akka.actor.Actor -import org.apache.spark.Logging - -import org.apache.spark.util.{Utils, ActorLogReceive} +import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} +import org.apache.spark.util.Utils /** * Driver -> Executor message to trigger a thread dump. @@ -28,14 +26,18 @@ import org.apache.spark.util.{Utils, ActorLogReceive} private[spark] case object TriggerThreadDump /** - * Actor that runs inside of executors to enable driver -> executor RPC. + * [[RpcEndpoint]] that runs inside of executors to enable driver -> executor RPC. */ private[spark] -class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging { +class ExecutorEndpoint(override val rpcEnv: RpcEnv, executorId: String) extends RpcEndpoint { - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case TriggerThreadDump => - sender ! 
Utils.getThreadDump() + context.reply(Utils.getThreadDump()) } } + +object ExecutorEndpoint { + val EXECUTOR_ENDPOINT_NAME = "ExecutorEndpoint" +} diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 87c2aa481095d..818f7a4c8d422 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -17,9 +17,15 @@ package org.apache.spark.mapred +import java.io.IOException import java.lang.reflect.Modifier -import org.apache.hadoop.mapred.{TaskAttemptID, JobID, JobConf, JobContext, TaskAttemptContext} +import org.apache.hadoop.mapred._ +import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext} +import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter} + +import org.apache.spark.executor.CommitDeniedException +import org.apache.spark.{Logging, SparkEnv, TaskContext} private[spark] trait SparkHadoopMapRedUtil { @@ -65,3 +71,86 @@ trait SparkHadoopMapRedUtil { } } } + +object SparkHadoopMapRedUtil extends Logging { + /** + * Commits a task output. Before committing the task output, we need to know whether some other + * task attempt might be racing to commit the same output partition. Therefore, coordinate with + * the driver in order to determine whether this attempt can commit (please see SPARK-4879 for + * details). + * + * Output commit coordinator is only contacted when the following two configurations are both set + * to `true`: + * + * - `spark.speculation` + * - `spark.hadoop.outputCommitCoordination.enabled` + */ + def commitTask( + committer: MapReduceOutputCommitter, + mrTaskContext: MapReduceTaskAttemptContext, + jobId: Int, + splitId: Int, + attemptId: Int): Unit = { + + val mrTaskAttemptID = mrTaskContext.getTaskAttemptID + + // Called after we have decided to commit + def performCommit(): Unit = { + try { + committer.commitTask(mrTaskContext) + logInfo(s"$mrTaskAttemptID: Committed") + } catch { + case cause: IOException => + logError(s"Error committing the output of task: $mrTaskAttemptID", cause) + committer.abortTask(mrTaskContext) + throw cause + } + } + + // First, check whether the task's output has already been committed by some other attempt + if (committer.needsTaskCommit(mrTaskContext)) { + val shouldCoordinateWithDriver: Boolean = { + val sparkConf = SparkEnv.get.conf + // We only need to coordinate with the driver if there are multiple concurrent task + // attempts, which should only occur if speculation is enabled + val speculationEnabled = sparkConf.getBoolean("spark.speculation", defaultValue = false) + // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs + sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", speculationEnabled) + } + + if (shouldCoordinateWithDriver) { + val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator + val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, attemptId) + + if (canCommit) { + performCommit() + } else { + val message = + s"$mrTaskAttemptID: Not committed because the driver did not authorize commit" + logInfo(message) + // We need to abort the task so that the driver can reschedule new attempts, if necessary + committer.abortTask(mrTaskContext) + throw new CommitDeniedException(message, jobId, splitId, attemptId) + } + } else { + // Speculation is disabled or a user has chosen to manually bypass 
the commit coordination + performCommit() + } + } else { + // Some other attempt committed the output, so we do nothing and signal success + logInfo(s"No need to commit output of task because needsTaskCommit=false: $mrTaskAttemptID") + } + } + + def commitTask( + committer: MapReduceOutputCommitter, + mrTaskContext: MapReduceTaskAttemptContext, + sparkTaskContext: TaskContext): Unit = { + commitTask( + committer, + mrTaskContext, + sparkTaskContext.stageId(), + sparkTaskContext.partitionId(), + sparkTaskContext.attemptNumber()) + } +} diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 04eb2bf9ba4ab..6b898bd4bfc1b 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -181,7 +181,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, buffer.get(bytes) bytes.foreach(x => print(x + " ")) buffer.position(curPosition) - print(" (" + bytes.size + ")") + print(" (" + bytes.length + ")") } def printBuffer(buffer: ByteBuffer, position: Int, length: Int) { diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 741fe3e1ea750..5a74c13b38bf7 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -82,14 +82,15 @@ private[nio] class ConnectionManager( new HashedWheelTimer(Utils.namedThreadFactory("AckTimeoutMonitor")) private val ackTimeout = - conf.getInt("spark.core.connection.ack.wait.timeout", conf.getInt("spark.network.timeout", 120)) + conf.getTimeAsSeconds("spark.core.connection.ack.wait.timeout", + conf.get("spark.network.timeout", "120s")) // Get the thread counts from the Spark Configuration. - // + // // Even though the ThreadPoolExecutor constructor takes both a minimum and maximum value, // we only query for the minimum value because we are using LinkedBlockingDeque. - // - // The JavaDoc for ThreadPoolExecutor points out that when using a LinkedBlockingDeque (which is + // + // The JavaDoc for ThreadPoolExecutor points out that when using a LinkedBlockingDeque (which is // an unbounded queue) no more than corePoolSize threads will ever be created, so only the "min" // parameter is necessary. private val handlerThreadCount = conf.getInt("spark.core.connection.handler.threads.min", 20) @@ -988,6 +989,7 @@ private[nio] class ConnectionManager( def stop() { ackTimeoutMonitor.stop() + selector.wakeup() selectorThread.interrupt() selectorThread.join() selector.close() diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 646df283ac069..3406a7e97e368 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -45,7 +45,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi } result }, - Range(0, self.partitions.size), + Range(0, self.partitions.length), (index: Int, data: Long) => totalCount.addAndGet(data), totalCount.get()) } @@ -54,8 +54,8 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi * Returns a future for retrieving all elements of this RDD. 
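For the asynchronous actions touched in this hunk, a short usage sketch may help: `collectAsync` and `countAsync` hand back futures that can be awaited or cancelled instead of blocking the caller. The local master URL and app name below are placeholders.

```
import org.apache.spark.{SparkConf, SparkContext}
import scala.concurrent.Await
import scala.concurrent.duration._

// Usage sketch for the asynchronous RDD actions: both calls return immediately
// and the results are awaited later.
object AsyncActionSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("async-sketch"))
    val rdd = sc.parallelize(1 to 1000, numSlices = 4)

    val pending = rdd.collectAsync()   // FutureAction[Seq[Int]]
    val counted = rdd.countAsync()     // FutureAction[Long]

    val all = Await.result(pending, 30.seconds)
    println(s"collected ${all.size} elements, count = ${Await.result(counted, 30.seconds)}")
    sc.stop()
  }
}
```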
*/ def collectAsync(): FutureAction[Seq[T]] = { - val results = new Array[Array[T]](self.partitions.size) - self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size), + val results = new Array[Array[T]](self.partitions.length) + self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.length), (index, data) => results(index) = data, results.flatten.toSeq) } @@ -111,7 +111,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi */ def foreachAsync(f: T => Unit): FutureAction[Unit] = { val cleanF = self.context.clean(f) - self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.size), + self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.length), (index, data) => Unit, Unit) } @@ -119,7 +119,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi * Applies a function f to each partition of this RDD. */ def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = { - self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size), + self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.length), (index, data) => Unit, Unit) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index fffa1911f5bc2..71578d1210fde 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -36,7 +36,7 @@ class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds override def getPartitions: Array[Partition] = { assertValid() - (0 until blockIds.size).map(i => { + (0 until blockIds.length).map(i => { new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition] }).toArray } diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala index 9059eb13bb5d8..c1d6971787572 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -53,11 +53,11 @@ class CartesianRDD[T: ClassTag, U: ClassTag]( extends RDD[Pair[T, U]](sc, Nil) with Serializable { - val numPartitionsInRdd2 = rdd2.partitions.size + val numPartitionsInRdd2 = rdd2.partitions.length override def getPartitions: Array[Partition] = { // create the cross product split - val array = new Array[Partition](rdd1.partitions.size * rdd2.partitions.size) + val array = new Array[Partition](rdd1.partitions.length * rdd2.partitions.length) for (s1 <- rdd1.partitions; s2 <- rdd2.partitions) { val idx = s1.index * numPartitionsInRdd2 + s2.index array(idx) = new CartesianPartition(idx, rdd1, rdd2, s1.index, s2.index) diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 1c13e2c372845..0d130dd4c7a60 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.util.Utils private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} @@ -48,7 +49,7 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) if (fs.exists(cpath)) { val 
dirContents = fs.listStatus(cpath).map(_.getPath) val partitionFiles = dirContents.filter(_.getName.startsWith("part-")).map(_.toString).sorted - val numPart = partitionFiles.size + val numPart = partitionFiles.length if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) { throw new SparkException("Invalid checkpoint directory: " + checkpointPath) @@ -112,8 +113,11 @@ private[spark] object CheckpointRDD extends Logging { } val serializer = env.serializer.newInstance() val serializeStream = serializer.serializeStream(fileOutputStream) - serializeStream.writeAll(iterator) - serializeStream.close() + Utils.tryWithSafeFinally { + serializeStream.writeAll(iterator) + } { + serializeStream.close() + } if (!fs.rename(tempOutputPath, finalOutputPath)) { if (!fs.exists(finalOutputPath)) { diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 07398a6fa62f6..7021a339e879b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -99,7 +99,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: override def getPartitions: Array[Partition] = { val array = new Array[Partition](part.numPartitions) - for (i <- 0 until array.size) { + for (i <- 0 until array.length) { // Each CoGroupPartition will have a dependency per contributing RDD array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) => // Assume each RDD contributed a single dependency, and get it @@ -120,7 +120,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: val sparkConf = SparkEnv.get.conf val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true) val split = s.asInstanceOf[CoGroupPartition] - val numRdds = split.deps.size + val numRdds = split.deps.length // A list of (rdd iterator, dependency number) pairs val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)] diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 5117ccfabfcc2..0c1b02c07d09f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -166,7 +166,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: // determines the tradeoff between load-balancing the partitions sizes and their locality // e.g. balanceSlack=0.10 means that it allows up to 10% imbalance in favor of locality - val slack = (balanceSlack * prev.partitions.size).toInt + val slack = (balanceSlack * prev.partitions.length).toInt var noLocality = true // if true if no preferredLocations exists for parent RDD diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 71e6e300fec5f..843a893235e56 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.StatCounter class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { /** Add up the elements in this RDD. 
*/ def sum(): Double = { - self.reduce(_ + _) + self.fold(0.0)(_ + _) } /** @@ -70,7 +70,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { @Experimental def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new MeanEvaluator(self.partitions.size, confidence) + val evaluator = new MeanEvaluator(self.partitions.length, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) } @@ -81,7 +81,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { @Experimental def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new SumEvaluator(self.partitions.size, confidence) + val evaluator = new SumEvaluator(self.partitions.length, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) } diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index 6fdfdb734d1b8..6afe50161dacd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -56,7 +56,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, * order of the keys). */ // TODO: this currently doesn't work on P other than Tuple2! - def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size) + def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.length) : RDD[(K, V)] = { val part = new RangePartitioner(numPartitions, self, ascending) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 6b4f097ea9ae5..05351ba4ff76b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -823,7 +823,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * RDD will be <= us. */ def subtractByKey[W: ClassTag](other: RDD[(K, W)]): RDD[(K, V)] = - subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size))) + subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.length))) /** Return an RDD with the pairs from `this` whose keys are not in `other`. 
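The `sum()` change above swaps `reduce` for `fold(0.0)` so that an empty RDD of doubles sums to `0.0` instead of throwing. A small sketch of the difference, run against a local placeholder SparkContext:

```
import org.apache.spark.{SparkConf, SparkContext}

// Why sum() moves from reduce to fold: reduce on an empty RDD throws (there is no
// element to start from), while fold(0.0)(_ + _) returns the zero value.
object EmptySumSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sum-sketch"))
    val empty = sc.parallelize(Seq.empty[Double], numSlices = 2)

    println(empty.fold(0.0)(_ + _))   // 0.0
    println(empty.sum())              // 0.0 with the patched implementation

    // empty.reduce(_ + _)            // would throw UnsupportedOperationException
    sc.stop()
  }
}
```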
*/ def subtractByKey[W: ClassTag](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] = @@ -995,7 +995,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] require(writer != null, "Unable to obtain RecordWriter") var recordsWritten = 0L - try { + Utils.tryWithSafeFinally { while (iter.hasNext) { val pair = iter.next() writer.write(pair._1, pair._2) @@ -1004,7 +1004,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) recordsWritten += 1 } - } finally { + } { writer.close(hadoopContext) } committer.commitTask(hadoopContext) @@ -1068,7 +1068,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.setup(context.stageId, context.partitionId, taskAttemptId) writer.open() var recordsWritten = 0L - try { + + Utils.tryWithSafeFinally { while (iter.hasNext) { val record = iter.next() writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) @@ -1077,7 +1078,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) recordsWritten += 1 } - } finally { + } { writer.close() } writer.commit() diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index ddbfd5624e741..d80d94a588346 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -316,7 +316,7 @@ abstract class RDD[T: ClassTag]( /** * Return a new RDD containing the distinct elements in this RDD. */ - def distinct(): RDD[T] = distinct(partitions.size) + def distinct(): RDD[T] = distinct(partitions.length) /** * Return a new RDD that has exactly numPartitions partitions. @@ -488,7 +488,7 @@ abstract class RDD[T: ClassTag]( def sortBy[K]( f: (T) => K, ascending: Boolean = true, - numPartitions: Int = this.partitions.size) + numPartitions: Int = this.partitions.length) (implicit ord: Ordering[K], ctag: ClassTag[K]): RDD[T] = this.keyBy[K](f) .sortByKey(ascending, numPartitions) @@ -852,7 +852,7 @@ abstract class RDD[T: ClassTag]( * RDD will be <= us. */ def subtract(other: RDD[T]): RDD[T] = - subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size))) + subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.length))) /** * Return an RDD with the elements from `this` that are not in `other`. @@ -986,14 +986,14 @@ abstract class RDD[T: ClassTag]( combOp: (U, U) => U, depth: Int = 2): U = { require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") - if (partitions.size == 0) { + if (partitions.length == 0) { return Utils.clone(zeroValue, context.env.closureSerializer.newInstance()) } val cleanSeqOp = context.clean(seqOp) val cleanCombOp = context.clean(combOp) val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it))) - var numPartitions = partiallyAggregated.partitions.size + var numPartitions = partiallyAggregated.partitions.length val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation. 
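The tree-aggregation comment above refers to `treeAggregate`, which pre-combines partial results over `depth` levels so the driver merges far fewer partitions than a plain aggregate would. A hedged usage sketch with placeholder data, master URL, and app name:

```
import org.apache.spark.{SparkConf, SparkContext}

// Usage sketch: compute (sum, count) with two levels of intermediate combination.
object TreeAggregateSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[4]").setAppName("tree-agg"))
    val data = sc.parallelize(1 to 1000000, numSlices = 100)

    val (sum, count) = data.treeAggregate((0L, 0L))(
      seqOp = { case ((s, c), x) => (s + x, c + 1) },
      combOp = { case ((s1, c1), (s2, c2)) => (s1 + s2, c1 + c2) },
      depth = 2)

    println(s"mean = ${sum.toDouble / count}")
    sc.stop()
  }
}
```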
while (numPartitions > scale + numPartitions / scale) { @@ -1026,7 +1026,7 @@ abstract class RDD[T: ClassTag]( } result } - val evaluator = new CountEvaluator(partitions.size, confidence) + val evaluator = new CountEvaluator(partitions.length, confidence) sc.runApproximateJob(this, countElements, evaluator, timeout) } @@ -1061,7 +1061,7 @@ abstract class RDD[T: ClassTag]( } map } - val evaluator = new GroupedCountEvaluator[T](partitions.size, confidence) + val evaluator = new GroupedCountEvaluator[T](partitions.length, confidence) sc.runApproximateJob(this, countPartition, evaluator, timeout) } @@ -1140,7 +1140,7 @@ abstract class RDD[T: ClassTag]( * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. */ def zipWithUniqueId(): RDD[(T, Long)] = { - val n = this.partitions.size.toLong + val n = this.partitions.length.toLong this.mapPartitionsWithIndex { case (k, iter) => iter.zipWithIndex.map { case (item, i) => (item, i * n + k) @@ -1243,7 +1243,7 @@ abstract class RDD[T: ClassTag]( queue ++= util.collection.Utils.takeOrdered(items, num)(ord) Iterator.single(queue) } - if (mapRDDs.partitions.size == 0) { + if (mapRDDs.partitions.length == 0) { Array.empty } else { mapRDDs.reduce { (queue1, queue2) => @@ -1489,7 +1489,7 @@ abstract class RDD[T: ClassTag]( } // The first RDD in the dependency stack has no parents, so no need for a +- def firstDebugString(rdd: RDD[_]): Seq[String] = { - val partitionStr = "(" + rdd.partitions.size + ")" + val partitionStr = "(" + rdd.partitions.length + ")" val leftOffset = (partitionStr.length - 1) / 2 val nextPrefix = (" " * leftOffset) + "|" + (" " * (partitionStr.length - leftOffset)) @@ -1499,7 +1499,7 @@ abstract class RDD[T: ClassTag]( } ++ debugChildren(rdd, nextPrefix) } def shuffleDebugString(rdd: RDD[_], prefix: String = "", isLastChild: Boolean): Seq[String] = { - val partitionStr = "(" + rdd.partitions.size + ")" + val partitionStr = "(" + rdd.partitions.length + ")" val leftOffset = (partitionStr.length - 1) / 2 val thisPrefix = prefix.replaceAll("\\|\\s+$", "") val nextPrefix = ( diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index f67e5f1857979..1722c27e55003 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -21,7 +21,7 @@ import scala.reflect.ClassTag import org.apache.hadoop.fs.Path -import org.apache.spark.{Logging, Partition, SerializableWritable, SparkException} +import org.apache.spark._ import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask} /** @@ -83,7 +83,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) } // Create the output path for the checkpoint - val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id) + val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id).get val fs = path.getFileSystem(rdd.context.hadoopConfiguration) if (!fs.mkdirs(path)) { throw new SparkException("Failed to create checkpoint path " + path) @@ -92,12 +92,17 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) // Save to file, and reload it as an RDD val broadcastedConf = rdd.context.broadcast( new SerializableWritable(rdd.context.hadoopConfiguration)) - rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _) val newRDD = new CheckpointRDD[T](rdd.context, path.toString) - if (newRDD.partitions.size != 
rdd.partitions.size) { + if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) { + rdd.context.cleaner.foreach { cleaner => + cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id) + } + } + rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _) + if (newRDD.partitions.length != rdd.partitions.length) { throw new SparkException( - "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.size + ") has different " + - "number of partitions than original RDD " + rdd + "(" + rdd.partitions.size + ")") + "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.length + ") has different " + + "number of partitions than original RDD " + rdd + "(" + rdd.partitions.length + ")") } // Change the dependencies and partitions of the RDD @@ -130,5 +135,17 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) } } -// Used for synchronization -private[spark] object RDDCheckpointData +private[spark] object RDDCheckpointData { + def rddCheckpointDataPath(sc: SparkContext, rddId: Int): Option[Path] = { + sc.checkpointDir.map { dir => new Path(dir, "rdd-" + rddId) } + } + + def clearRDDCheckpointData(sc: SparkContext, rddId: Int): Unit = { + rddCheckpointDataPath(sc, rddId).foreach { path => + val fs = path.getFileSystem(sc.hadoopConfiguration) + if (fs.exists(path)) { + fs.delete(path, true) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index c27f435eb9c5a..e9d745588ee9a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -76,7 +76,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( override def getPartitions: Array[Partition] = { val array = new Array[Partition](part.numPartitions) - for (i <- 0 until array.size) { + for (i <- 0 until array.length) { // Each CoGroupPartition will depend on rdd1 and rdd2 array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) => dependencies(j) match { diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 4239e7e22af89..3986645350a82 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -63,7 +63,7 @@ class UnionRDD[T: ClassTag]( extends RDD[T](sc, Nil) { // Nil since we implement getDependencies override def getPartitions: Array[Partition] = { - val array = new Array[Partition](rdds.map(_.partitions.size).sum) + val array = new Array[Partition](rdds.map(_.partitions.length).sum) var pos = 0 for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) { array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index) @@ -76,8 +76,8 @@ class UnionRDD[T: ClassTag]( val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for (rdd <- rdds) { - deps += new RangeDependency(rdd, 0, pos, rdd.partitions.size) - pos += rdd.partitions.size + deps += new RangeDependency(rdd, 0, pos, rdd.partitions.length) + pos += rdd.partitions.length } deps } diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index d0be304762e1f..a96b6c3d23454 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -52,8 +52,8 @@ 
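The new `rddCheckpointDataPath` / `clearRDDCheckpointData` helpers above tie checkpoint files to the ContextCleaner. A minimal driver-side sketch of how that cleanup could be exercised, assuming a local master and a writable checkpoint directory (both illustrative):

```
import org.apache.spark.{SparkConf, SparkContext}

// With the flag below enabled, the cleaner may delete the "rdd-<id>" checkpoint
// directory once the checkpointed RDD becomes unreachable on the driver.
val conf = new SparkConf()
  .setAppName("checkpoint-cleanup-sketch")   // illustrative name
  .setMaster("local[2]")
  .set("spark.cleaner.referenceTracking.cleanCheckpoints", "true")

val sc = new SparkContext(conf)
sc.setCheckpointDir("/tmp/spark-checkpoints")   // assumption: any writable path

var rdd = sc.parallelize(1 to 1000).map(_ * 2)
rdd.checkpoint()
rdd.count()   // materializes the checkpoint under <checkpointDir>/rdd-<id>
rdd = null    // once the RDD is garbage collected, its checkpoint files may be removed
```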
private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag]( if (preservesPartitioning) firstParent[Any].partitioner else None override def getPartitions: Array[Partition] = { - val numParts = rdds.head.partitions.size - if (!rdds.forall(rdd => rdd.partitions.size == numParts)) { + val numParts = rdds.head.partitions.length + if (!rdds.forall(rdd => rdd.partitions.length == numParts)) { throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") } Array.tabulate[Partition](numParts) { i => diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index 8c43a559409f2..523aaf2b860b5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -41,7 +41,7 @@ class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, L /** The start index of each partition. */ @transient private val startIndices: Array[Long] = { - val n = prev.partitions.size + val n = prev.partitions.length if (n == 0) { Array[Long]() } else if (n == 1) { diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala new file mode 100644 index 0000000000000..f2c1c86af767e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -0,0 +1,429 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import java.net.URI + +import scala.concurrent.{Await, Future} +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.reflect.ClassTag + +import org.apache.spark.{Logging, SparkException, SecurityManager, SparkConf} +import org.apache.spark.util.{AkkaUtils, Utils} + +/** + * An RPC environment. [[RpcEndpoint]]s need to register itself with a name to [[RpcEnv]] to + * receives messages. Then [[RpcEnv]] will process messages sent from [[RpcEndpointRef]] or remote + * nodes, and deliver them to corresponding [[RpcEndpoint]]s. For uncaught exceptions caught by + * [[RpcEnv]], [[RpcEnv]] will use [[RpcCallContext.sendFailure]] to send exceptions back to the + * sender, or logging them if no such sender or `NotSerializableException`. + * + * [[RpcEnv]] also provides some methods to retrieve [[RpcEndpointRef]]s given name or uri. + */ +private[spark] abstract class RpcEnv(conf: SparkConf) { + + private[spark] val defaultLookupTimeout = AkkaUtils.lookupTimeout(conf) + + /** + * Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement + * [[RpcEndpoint.self]]. Return `null` if the corresponding [[RpcEndpointRef]] does not exist. 
+ */ + private[rpc] def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef + + /** + * Return the address that [[RpcEnv]] is listening to. + */ + def address: RpcAddress + + /** + * Register a [[RpcEndpoint]] with a name and return its [[RpcEndpointRef]]. [[RpcEnv]] does not + * guarantee thread-safety. + */ + def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef + + /** + * Retrieve the [[RpcEndpointRef]] represented by `uri` asynchronously. + */ + def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] + + /** + * Retrieve the [[RpcEndpointRef]] represented by `uri`. This is a blocking action. + */ + def setupEndpointRefByURI(uri: String): RpcEndpointRef = { + Await.result(asyncSetupEndpointRefByURI(uri), defaultLookupTimeout) + } + + /** + * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName` + * asynchronously. + */ + def asyncSetupEndpointRef( + systemName: String, address: RpcAddress, endpointName: String): Future[RpcEndpointRef] = { + asyncSetupEndpointRefByURI(uriOf(systemName, address, endpointName)) + } + + /** + * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`. + * This is a blocking action. + */ + def setupEndpointRef( + systemName: String, address: RpcAddress, endpointName: String): RpcEndpointRef = { + setupEndpointRefByURI(uriOf(systemName, address, endpointName)) + } + + /** + * Stop [[RpcEndpoint]] specified by `endpoint`. + */ + def stop(endpoint: RpcEndpointRef): Unit + + /** + * Shutdown this [[RpcEnv]] asynchronously. If need to make sure [[RpcEnv]] exits successfully, + * call [[awaitTermination()]] straight after [[shutdown()]]. + */ + def shutdown(): Unit + + /** + * Wait until [[RpcEnv]] exits. + * + * TODO do we need a timeout parameter? + */ + def awaitTermination(): Unit + + /** + * Create a URI used to create a [[RpcEndpointRef]]. Use this one to create the URI instead of + * creating it manually because different [[RpcEnv]] may have different formats. + */ + def uriOf(systemName: String, address: RpcAddress, endpointName: String): String +} + +private[spark] case class RpcEnvConfig( + conf: SparkConf, + name: String, + host: String, + port: Int, + securityManager: SecurityManager) + +/** + * A RpcEnv implementation must have a [[RpcEnvFactory]] implementation with an empty constructor + * so that it can be created via Reflection. + */ +private[spark] object RpcEnv { + + private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = { + // Add more RpcEnv implementations here + val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory") + val rpcEnvName = conf.get("spark.rpc", "akka") + val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) + Class.forName(rpcEnvFactoryClassName, true, Utils.getContextOrSparkClassLoader). + newInstance().asInstanceOf[RpcEnvFactory] + } + + def create( + name: String, + host: String, + port: Int, + conf: SparkConf, + securityManager: SecurityManager): RpcEnv = { + // Using Reflection to create the RpcEnv to avoid to depend on Akka directly + val config = RpcEnvConfig(conf, name, host, port, securityManager) + getRpcEnvFactory(conf).create(config) + } + +} + +/** + * A factory class to create the [[RpcEnv]]. It must have an empty constructor so that it can be + * created using Reflection. + */ +private[spark] trait RpcEnvFactory { + + def create(config: RpcEnvConfig): RpcEnv +} + +/** + * An end point for the RPC that defines what functions to trigger given a message. 
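As a rough sketch of how this internal API fits together (the endpoint name "tracker" and the port are made up, and `RpcEnv` is `private[spark]`, so this is illustrative rather than user-facing code):

```
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.RpcEnv

// The concrete implementation is chosen by the "spark.rpc" setting ("akka" by default)
// and instantiated reflectively through its RpcEnvFactory.
val conf = new SparkConf()
val rpcEnv = RpcEnv.create("driver", "localhost", 7077, conf, new SecurityManager(conf))

// An endpoint registered elsewhere can be looked up by the URI format of that implementation.
val uri = rpcEnv.uriOf("driver", rpcEnv.address, "tracker")
val ref = rpcEnv.setupEndpointRefByURI(uri)   // blocking variant of asyncSetupEndpointRefByURI
```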
+ *
+ * It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence.
+ *
+ * The life-cycle will be:
+ *
+ * constructor onStart receive* onStop
+ *
+ * Note: `receive` can be called concurrently. If you want `receive` to be thread-safe, please use
+ * [[ThreadSafeRpcEndpoint]].
+ *
+ * If any error is thrown from one of the [[RpcEndpoint]] methods except `onError`, `onError` will
+ * be invoked with the cause. If `onError` throws an error, [[RpcEnv]] will ignore it.
+ */
+private[spark] trait RpcEndpoint {
+
+  /**
+   * The [[RpcEnv]] that this [[RpcEndpoint]] is registered to.
+   */
+  val rpcEnv: RpcEnv
+
+  /**
+   * The [[RpcEndpointRef]] of this [[RpcEndpoint]]. `self` will become valid when `onStart` is
+   * called, and `self` will become `null` when `onStop` is called.
+   *
+   * Note: before `onStart` is called, [[RpcEndpoint]] has not yet been registered and there is no
+   * valid [[RpcEndpointRef]] for it, so don't call `self` before `onStart` is called.
+   */
+  final def self: RpcEndpointRef = {
+    require(rpcEnv != null, "rpcEnv has not been initialized")
+    rpcEnv.endpointRef(this)
+  }
+
+  /**
+   * Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply]]. If receiving an
+   * unmatched message, a [[SparkException]] will be thrown and sent to `onError`.
+   */
+  def receive: PartialFunction[Any, Unit] = {
+    case _ => throw new SparkException(self + " does not implement 'receive'")
+  }
+
+  /**
+   * Process messages from [[RpcEndpointRef.sendWithReply]]. If receiving an unmatched message,
+   * a [[SparkException]] will be thrown and sent to `onError`.
+   */
+  def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+    case _ => context.sendFailure(new SparkException(self + " won't reply anything"))
+  }
+
+  /**
+   * Invoked when any exception is thrown while handling messages.
+   *
+   * @param cause the uncaught exception
+   */
+  def onError(cause: Throwable): Unit = {
+    // By default, rethrow the cause and let RpcEnv handle it
+    throw cause
+  }
+
+  /**
+   * Invoked before [[RpcEndpoint]] starts to handle any message.
+   */
+  def onStart(): Unit = {
+    // By default, do nothing.
+  }
+
+  /**
+   * Invoked when [[RpcEndpoint]] is stopping.
+   */
+  def onStop(): Unit = {
+    // By default, do nothing.
+  }
+
+  /**
+   * Invoked when `remoteAddress` is connected to the current node.
+   */
+  def onConnected(remoteAddress: RpcAddress): Unit = {
+    // By default, do nothing.
+  }
+
+  /**
+   * Invoked when `remoteAddress` is lost.
+   */
+  def onDisconnected(remoteAddress: RpcAddress): Unit = {
+    // By default, do nothing.
+  }
+
+  /**
+   * Invoked when some network error happens in the connection between the current node and
+   * `remoteAddress`.
+   */
+  def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
+    // By default, do nothing.
+  }
+
+  /**
+   * A convenient method to stop [[RpcEndpoint]].
+   */
+  final def stop(): Unit = {
+    val _self = self
+    if (_self != null) {
+      rpcEnv.stop(self)
+    }
+  }
+}
+
+/**
+ * A trait that requires [[RpcEnv]] to send messages to it thread-safely.
+ *
+ * Thread-safety means processing of one message happens before processing of the next message by
+ * the same [[ThreadSafeRpcEndpoint]]. In other words, changes to internal fields of a
+ * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the
+ * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent.
+ *
+ * However, there is no guarantee that the same thread will be executing the same
+ * [[ThreadSafeRpcEndpoint]] for different messages.
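To make the contract above concrete, here is a minimal endpoint sketch; the `EchoEndpoint` name and the `Ping`/`Pong` messages are hypothetical, not part of this patch:

```
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}

case class Ping(id: Int)
case class Pong(id: Int)

// Messages to the same ThreadSafeRpcEndpoint are processed one at a time.
class EchoEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {

  override def onStart(): Unit = {
    // `self` is valid from this point on.
  }

  // Fire-and-forget messages (RpcEndpointRef.send) arrive here.
  override def receive: PartialFunction[Any, Unit] = {
    case Ping(id) => // side effects only, nothing is sent back
  }

  // Ask-style messages (RpcEndpointRef.sendWithReply / askWithReply) arrive here.
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case Ping(id) => context.reply(Pong(id))
  }

  override def onStop(): Unit = {
    // Release any resources here.
  }
}
```

Such an endpoint would be registered with something like `rpcEnv.setupEndpoint("echo", new EchoEndpoint(rpcEnv))`.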
+ */
+trait ThreadSafeRpcEndpoint extends RpcEndpoint
+
+/**
+ * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe.
+ */
+private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
+  extends Serializable with Logging {
+
+  private[this] val maxRetries = conf.getInt("spark.akka.num.retries", 3)
+  private[this] val retryWaitMs = conf.getLong("spark.akka.retry.wait", 3000)
+  private[this] val defaultAskTimeout = conf.getLong("spark.akka.askTimeout", 30) seconds
+
+  /**
+   * Return the address for the [[RpcEndpointRef]].
+   */
+  def address: RpcAddress
+
+  def name: String
+
+  /**
+   * Sends a one-way asynchronous message. Fire-and-forget semantics.
+   */
+  def send(message: Any): Unit
+
+  /**
+   * Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and return a `Future` to
+   * receive the reply within a default timeout.
+   *
+   * This method only sends the message once and never retries.
+   */
+  def sendWithReply[T: ClassTag](message: Any): Future[T] =
+    sendWithReply(message, defaultAskTimeout)
+
+  /**
+   * Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and return a `Future` to
+   * receive the reply within the specified timeout.
+   *
+   * This method only sends the message once and never retries.
+   */
+  def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T]
+
+  /**
+   * Send a message to the corresponding [[RpcEndpoint]] and get its result within a default
+   * timeout, or throw a SparkException if this fails even after the default number of retries.
+   * The default `timeout` will be used in every trial of calling `sendWithReply`. Because this
+   * method retries, the message handling on the receiver side should be idempotent.
+   *
+   * Note: this is a blocking action which may cost a lot of time, so don't call it in a message
+   * loop of [[RpcEndpoint]].
+   *
+   * @param message the message to send
+   * @tparam T type of the reply message
+   * @return the reply message from the corresponding [[RpcEndpoint]]
+   */
+  def askWithReply[T: ClassTag](message: Any): T = askWithReply(message, defaultAskTimeout)
+
+  /**
+   * Send a message to the corresponding [[RpcEndpoint.receive]] and get its result within a
+   * specified timeout, or throw a SparkException if this fails even after the specified number of
+   * retries. `timeout` will be used in every trial of calling `sendWithReply`. Because this method
+   * retries, the message handling on the receiver side should be idempotent.
+   *
+   * Note: this is a blocking action which may cost a lot of time, so don't call it in a message
+   * loop of [[RpcEndpoint]].
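Reusing the hypothetical `Ping`/`Pong` messages from the sketch above, the three calling styles of `RpcEndpointRef` would look roughly like this on the caller side (`ref` is assumed to come from `setupEndpoint` or `setupEndpointRefByURI`):

```
import scala.concurrent.Await
import scala.concurrent.duration._

ref.send(Ping(1))                              // one-way, no reply expected

val future = ref.sendWithReply[Pong](Ping(2))  // asynchronous, single attempt
val pong1 = Await.result(future, 30.seconds)

val pong2 = ref.askWithReply[Pong](Ping(3))    // blocking, retried up to spark.akka.num.retries
```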
+ * + * @param message the message to send + * @param timeout the timeout duration + * @tparam T type of the reply message + * @return the reply message from the corresponding [[RpcEndpoint]] + */ + def askWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): T = { + // TODO: Consider removing multiple attempts + var attempts = 0 + var lastException: Exception = null + while (attempts < maxRetries) { + attempts += 1 + try { + val future = sendWithReply[T](message, timeout) + val result = Await.result(future, timeout) + if (result == null) { + throw new SparkException("Actor returned null") + } + return result + } catch { + case ie: InterruptedException => throw ie + case e: Exception => + lastException = e + logWarning(s"Error sending message [message = $message] in $attempts attempts", e) + } + Thread.sleep(retryWaitMs) + } + + throw new SparkException( + s"Error sending message [message = $message]", lastException) + } + +} + +/** + * Represent a host with a port + */ +private[spark] case class RpcAddress(host: String, port: Int) { + // TODO do we need to add the type of RpcEnv in the address? + + val hostPort: String = host + ":" + port + + override val toString: String = hostPort +} + +private[spark] object RpcAddress { + + /** + * Return the [[RpcAddress]] represented by `uri`. + */ + def fromURI(uri: URI): RpcAddress = { + RpcAddress(uri.getHost, uri.getPort) + } + + /** + * Return the [[RpcAddress]] represented by `uri`. + */ + def fromURIString(uri: String): RpcAddress = { + fromURI(new java.net.URI(uri)) + } + + def fromSparkURL(sparkUrl: String): RpcAddress = { + val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) + RpcAddress(host, port) + } +} + +/** + * A callback that [[RpcEndpoint]] can use it to send back a message or failure. It's thread-safe + * and can be called in any thread. + */ +private[spark] trait RpcCallContext { + + /** + * Reply a message to the sender. If the sender is [[RpcEndpoint]], its [[RpcEndpoint.receive]] + * will be called. + */ + def reply(response: Any): Unit + + /** + * Report a failure to the sender. + */ + def sendFailure(e: Throwable): Unit + + /** + * The sender of this message. + */ + def sender: RpcEndpointRef +} diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala new file mode 100644 index 0000000000000..652e52f2b2e73 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -0,0 +1,325 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
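Before the Akka-based implementation below, a small illustration of the `RpcAddress` helpers just defined (all hosts and URLs are made up):

```
import org.apache.spark.rpc.RpcAddress

val a1 = RpcAddress("host1", 7077)   // a1.hostPort == "host1:7077"
val a2 = RpcAddress.fromSparkURL("spark://host2:7077")
val a3 = RpcAddress.fromURIString("akka.tcp://sparkDriver@host3:7078/user/tracker")
```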
+ */ + +package org.apache.spark.rpc.akka + +import java.util.concurrent.ConcurrentHashMap + +import scala.concurrent.Future +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.reflect.ClassTag +import scala.util.control.NonFatal + +import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Address} +import akka.event.Logging.Error +import akka.pattern.{ask => akkaAsk} +import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} +import org.apache.spark.{SparkException, Logging, SparkConf} +import org.apache.spark.rpc._ +import org.apache.spark.util.{ActorLogReceive, AkkaUtils} + +/** + * A RpcEnv implementation based on Akka. + * + * TODO Once we remove all usages of Akka in other place, we can move this file to a new project and + * remove Akka from the dependencies. + * + * @param actorSystem + * @param conf + * @param boundPort + */ +private[spark] class AkkaRpcEnv private[akka] ( + val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int) + extends RpcEnv(conf) with Logging { + + private val defaultAddress: RpcAddress = { + val address = actorSystem.asInstanceOf[ExtendedActorSystem].provider.getDefaultAddress + // In some test case, ActorSystem doesn't bind to any address. + // So just use some default value since they are only some unit tests + RpcAddress(address.host.getOrElse("localhost"), address.port.getOrElse(boundPort)) + } + + override val address: RpcAddress = defaultAddress + + /** + * A lookup table to search a [[RpcEndpointRef]] for a [[RpcEndpoint]]. We need it to make + * [[RpcEndpoint.self]] work. + */ + private val endpointToRef = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]() + + /** + * Need this map to remove `RpcEndpoint` from `endpointToRef` via a `RpcEndpointRef` + */ + private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]() + + private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = { + endpointToRef.put(endpoint, endpointRef) + refToEndpoint.put(endpointRef, endpoint) + } + + private def unregisterEndpoint(endpointRef: RpcEndpointRef): Unit = { + val endpoint = refToEndpoint.remove(endpointRef) + if (endpoint != null) { + endpointToRef.remove(endpoint) + } + } + + /** + * Retrieve the [[RpcEndpointRef]] of `endpoint`. + */ + override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointToRef.get(endpoint) + + override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { + @volatile var endpointRef: AkkaRpcEndpointRef = null + // Use lazy because the Actor needs to use `endpointRef`. + // So `actorRef` should be created after assigning `endpointRef`. 
+ lazy val actorRef = actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { + + assert(endpointRef != null) + + override def preStart(): Unit = { + // Listen for remote client network events + context.system.eventStream.subscribe(self, classOf[AssociationEvent]) + safelyCall(endpoint) { + endpoint.onStart() + } + } + + override def receiveWithLogging: Receive = { + case AssociatedEvent(_, remoteAddress, _) => + safelyCall(endpoint) { + endpoint.onConnected(akkaAddressToRpcAddress(remoteAddress)) + } + + case DisassociatedEvent(_, remoteAddress, _) => + safelyCall(endpoint) { + endpoint.onDisconnected(akkaAddressToRpcAddress(remoteAddress)) + } + + case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) => + safelyCall(endpoint) { + endpoint.onNetworkError(cause, akkaAddressToRpcAddress(remoteAddress)) + } + + case e: AssociationEvent => + // TODO ignore? + + case m: AkkaMessage => + logDebug(s"Received RPC message: $m") + safelyCall(endpoint) { + processMessage(endpoint, m, sender) + } + + case AkkaFailure(e) => + safelyCall(endpoint) { + throw e + } + + case message: Any => { + logWarning(s"Unknown message: $message") + } + + } + + override def postStop(): Unit = { + unregisterEndpoint(endpoint.self) + safelyCall(endpoint) { + endpoint.onStop() + } + } + + }), name = name) + endpointRef = new AkkaRpcEndpointRef(defaultAddress, actorRef, conf, initInConstructor = false) + registerEndpoint(endpoint, endpointRef) + // Now actorRef can be created safely + endpointRef.init() + endpointRef + } + + private def processMessage(endpoint: RpcEndpoint, m: AkkaMessage, _sender: ActorRef): Unit = { + val message = m.message + val needReply = m.needReply + val pf: PartialFunction[Any, Unit] = + if (needReply) { + endpoint.receiveAndReply(new RpcCallContext { + override def sendFailure(e: Throwable): Unit = { + _sender ! AkkaFailure(e) + } + + override def reply(response: Any): Unit = { + _sender ! AkkaMessage(response, false) + } + + // Some RpcEndpoints need to know the sender's address + override val sender: RpcEndpointRef = + new AkkaRpcEndpointRef(defaultAddress, _sender, conf) + }) + } else { + endpoint.receive + } + try { + pf.applyOrElse[Any, Unit](message, { message => + throw new SparkException(s"Unmatched message $message from ${_sender}") + }) + } catch { + case NonFatal(e) => + if (needReply) { + // If the sender asks a reply, we should send the error back to the sender + _sender ! AkkaFailure(e) + } else { + throw e + } + } + } + + /** + * Run `action` safely to avoid to crash the thread. If any non-fatal exception happens, it will + * call `endpoint.onError`. If `endpoint.onError` throws any non-fatal exception, just log it. + */ + private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { + try { + action + } catch { + case NonFatal(e) => { + try { + endpoint.onError(e) + } catch { + case NonFatal(e) => logError(s"Ignore error: ${e.getMessage}", e) + } + } + } + } + + private def akkaAddressToRpcAddress(address: Address): RpcAddress = { + RpcAddress(address.host.getOrElse(defaultAddress.host), + address.port.getOrElse(defaultAddress.port)) + } + + override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { + import actorSystem.dispatcher + actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout). 
+ map(new AkkaRpcEndpointRef(defaultAddress, _, conf)) + } + + override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = { + AkkaUtils.address( + AkkaUtils.protocol(actorSystem), systemName, address.host, address.port, endpointName) + } + + override def shutdown(): Unit = { + actorSystem.shutdown() + } + + override def stop(endpoint: RpcEndpointRef): Unit = { + require(endpoint.isInstanceOf[AkkaRpcEndpointRef]) + actorSystem.stop(endpoint.asInstanceOf[AkkaRpcEndpointRef].actorRef) + } + + override def awaitTermination(): Unit = { + actorSystem.awaitTermination() + } + + override def toString: String = s"${getClass.getSimpleName}($actorSystem)" +} + +private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { + + def create(config: RpcEnvConfig): RpcEnv = { + val (actorSystem, boundPort) = AkkaUtils.createActorSystem( + config.name, config.host, config.port, config.conf, config.securityManager) + actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor") + new AkkaRpcEnv(actorSystem, config.conf, boundPort) + } +} + +/** + * Monitor errors reported by Akka and log them. + */ +private[akka] class ErrorMonitor extends Actor with ActorLogReceive with Logging { + + override def preStart(): Unit = { + context.system.eventStream.subscribe(self, classOf[Error]) + } + + override def receiveWithLogging: Actor.Receive = { + case Error(cause: Throwable, _, _, message: String) => logError(message, cause) + } +} + +private[akka] class AkkaRpcEndpointRef( + @transient defaultAddress: RpcAddress, + @transient _actorRef: => ActorRef, + @transient conf: SparkConf, + @transient initInConstructor: Boolean = true) + extends RpcEndpointRef(conf) with Logging { + + lazy val actorRef = _actorRef + + override lazy val address: RpcAddress = { + val akkaAddress = actorRef.path.address + RpcAddress(akkaAddress.host.getOrElse(defaultAddress.host), + akkaAddress.port.getOrElse(defaultAddress.port)) + } + + override lazy val name: String = actorRef.path.name + + private[akka] def init(): Unit = { + // Initialize the lazy vals + actorRef + address + name + } + + if (initInConstructor) { + init() + } + + override def send(message: Any): Unit = { + actorRef ! AkkaMessage(message, false) + } + + override def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { + import scala.concurrent.ExecutionContext.Implicits.global + actorRef.ask(AkkaMessage(message, true))(timeout).flatMap { + case msg @ AkkaMessage(message, reply) => + if (reply) { + logError(s"Receive $msg but the sender cannot reply") + Future.failed(new SparkException(s"Receive $msg but the sender cannot reply")) + } else { + Future.successful(message) + } + case AkkaFailure(e) => + Future.failed(e) + }.mapTo[T] + } + + override def toString: String = s"${getClass.getSimpleName}($actorRef)" + +} + +/** + * A wrapper to `message` so that the receiver knows if the sender expects a reply. 
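Because the factory is resolved by name from `spark.rpc`, an alternative transport could in principle be plugged in without touching this code. A rough sketch, where `MyRpcEnvFactory` is entirely hypothetical:

```
import org.apache.spark.SparkConf
import org.apache.spark.rpc.{RpcEnv, RpcEnvConfig, RpcEnvFactory}

// Must have a no-arg constructor so it can be created reflectively.
class MyRpcEnvFactory extends RpcEnvFactory {
  override def create(config: RpcEnvConfig): RpcEnv = {
    // Build a custom RpcEnv from config.conf, config.host, config.port, ...
    ???
  }
}

// Either a registered short name ("akka") or a fully qualified class name is accepted.
val conf = new SparkConf().set("spark.rpc", "com.example.MyRpcEnvFactory")
```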
+ * @param message + * @param needReply if the sender expects a reply message + */ +private[akka] case class AkkaMessage(message: Any, needReply: Boolean) + +/** + * A reply with the failure error from the receiver to the sender + */ +private[akka] case class AkkaFailure(e: Throwable) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala index b755d8fb15757..50a69379412d2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala @@ -27,7 +27,7 @@ import org.apache.spark.util.CallSite */ private[spark] class ActiveJob( val jobId: Int, - val finalStage: Stage, + val finalStage: ResultStage, val func: (TaskContext, Iterator[_]) => _, val partitions: Array[Int], val callSite: CallSite, diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b405bd3338e7c..4a32f8936fb0e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -23,14 +23,11 @@ import java.util.concurrent.{TimeUnit, Executors} import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} -import scala.concurrent.Await import scala.concurrent.duration._ +import scala.language.existentials import scala.language.postfixOps import scala.util.control.NonFatal -import akka.pattern.ask -import akka.util.Timeout - import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics @@ -53,6 +50,10 @@ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task * a small number of times before cancelling the whole stage. * + * Here's a checklist to use when making or reviewing changes to this class: + * + * - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to + * include the new structure. This will help to catch memory leaks. */ private[spark] class DAGScheduler( @@ -83,7 +84,7 @@ class DAGScheduler( private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]] private[scheduler] val stageIdToStage = new HashMap[Int, Stage] - private[scheduler] val shuffleToMapStage = new HashMap[Int, Stage] + private[scheduler] val shuffleToMapStage = new HashMap[Int, ShuffleMapStage] private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob] // Stages we need to run whose parents aren't done @@ -114,6 +115,8 @@ class DAGScheduler( // stray messages to detect. private val failedEpoch = new HashMap[String, Long] + private [scheduler] val outputCommitCoordinator = env.outputCommitCoordinator + // A closure serializer that we reuse. // This is only safe because DAGScheduler runs in a single thread. private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() @@ -131,8 +134,6 @@ class DAGScheduler( private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) - private val outputCommitCoordinator = env.outputCommitCoordinator - // Called by TaskScheduler to report task's starting. 
def taskStarted(task: Task[_], taskInfo: TaskInfo) { eventProcessLoop.post(BeginEvent(task, taskInfo)) @@ -150,7 +151,7 @@ class DAGScheduler( result: Any, accumUpdates: Map[Long, Any], taskInfo: TaskInfo, - taskMetrics: TaskMetrics) { + taskMetrics: TaskMetrics): Unit = { eventProcessLoop.post( CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)) } @@ -165,26 +166,23 @@ class DAGScheduler( taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics) blockManagerId: BlockManagerId): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) - implicit val timeout = Timeout(600 seconds) - - Await.result( - blockManagerMaster.driverActor ? BlockManagerHeartbeat(blockManagerId), - timeout.duration).asInstanceOf[Boolean] + blockManagerMaster.driverEndpoint.askWithReply[Boolean]( + BlockManagerHeartbeat(blockManagerId), 600 seconds) } // Called by TaskScheduler when an executor fails. - def executorLost(execId: String) { + def executorLost(execId: String): Unit = { eventProcessLoop.post(ExecutorLost(execId)) } // Called by TaskScheduler when a host is added - def executorAdded(execId: String, host: String) { + def executorAdded(execId: String, host: String): Unit = { eventProcessLoop.post(ExecutorAdded(execId, host)) } // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or // cancellation of the job itself. - def taskSetFailed(taskSet: TaskSet, reason: String) { + def taskSetFailed(taskSet: TaskSet, reason: String): Unit = { eventProcessLoop.post(TaskSetFailed(taskSet, reason)) } @@ -210,40 +208,65 @@ class DAGScheduler( * The jobId value passed in will be used if the stage doesn't already exist with * a lower jobId (jobId always increases across jobs.) */ - private def getShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): Stage = { + private def getShuffleMapStage( + shuffleDep: ShuffleDependency[_, _, _], + jobId: Int): ShuffleMapStage = { shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => // We are going to register ancestor shuffle dependencies registerShuffleDependencies(shuffleDep, jobId) // Then register current shuffleDep - val stage = - newOrUsedStage( - shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId, - shuffleDep.rdd.creationSite) + val stage = newOrUsedShuffleStage(shuffleDep, jobId) shuffleToMapStage(shuffleDep.shuffleId) = stage - + stage } } /** - * Create a Stage -- either directly for use as a result stage, or as part of the (re)-creation - * of a shuffle map stage in newOrUsedStage. The stage will be associated with the provided - * jobId. Production of shuffle map stages should always use newOrUsedStage, not newStage - * directly. + * Helper function to eliminate some code re-use when creating new stages. + */ + private def getParentStagesAndId(rdd: RDD[_], jobId: Int): (List[Stage], Int) = { + val parentStages = getParentStages(rdd, jobId) + val id = nextStageId.getAndIncrement() + (parentStages, id) + } + + /** + * Create a ShuffleMapStage as part of the (re)-creation of a shuffle map stage in + * newOrUsedShuffleStage. The stage will be associated with the provided jobId. + * Production of shuffle map stages should always use newOrUsedShuffleStage, not + * newShuffleMapStage directly. 
*/ - private def newStage( + private def newShuffleMapStage( rdd: RDD[_], numTasks: Int, - shuffleDep: Option[ShuffleDependency[_, _, _]], + shuffleDep: ShuffleDependency[_, _, _], jobId: Int, - callSite: CallSite) - : Stage = - { - val parentStages = getParentStages(rdd, jobId) - val id = nextStageId.getAndIncrement() - val stage = new Stage(id, rdd, numTasks, shuffleDep, parentStages, jobId, callSite) + callSite: CallSite): ShuffleMapStage = { + val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId) + val stage: ShuffleMapStage = new ShuffleMapStage(id, rdd, numTasks, parentStages, + jobId, callSite, shuffleDep) + + stageIdToStage(id) = stage + updateJobIdStageIdMaps(jobId, stage) + stage + } + + /** + * Create a ResultStage -- either directly for use as a result stage, or as part of the + * (re)-creation of a shuffle map stage in newOrUsedShuffleStage. The stage will be associated + * with the provided jobId. + */ + private def newResultStage( + rdd: RDD[_], + numTasks: Int, + jobId: Int, + callSite: CallSite): ResultStage = { + val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId) + val stage: ResultStage = new ResultStage(id, rdd, numTasks, parentStages, jobId, callSite) + stageIdToStage(id) = stage updateJobIdStageIdMaps(jobId, stage) stage @@ -255,20 +278,17 @@ class DAGScheduler( * present in the MapOutputTracker, then the number and location of available outputs are * recovered from the MapOutputTracker */ - private def newOrUsedStage( - rdd: RDD[_], - numTasks: Int, + private def newOrUsedShuffleStage( shuffleDep: ShuffleDependency[_, _, _], - jobId: Int, - callSite: CallSite) - : Stage = - { - val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite) + jobId: Int): ShuffleMapStage = { + val rdd = shuffleDep.rdd + val numTasks = rdd.partitions.size + val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, jobId, rdd.creationSite) if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) val locs = MapOutputTracker.deserializeMapStatuses(serLocs) for (i <- 0 until locs.size) { - stage.outputLocs(i) = Option(locs(i)).toList // locs(i) will be null if missing + stage.outputLocs(i) = Option(locs(i)).toList // locs(i) will be null if missing } stage.numAvailableOutputs = locs.count(_ != null) } else { @@ -306,26 +326,23 @@ class DAGScheduler( } } waitingForVisit.push(rdd) - while (!waitingForVisit.isEmpty) { + while (waitingForVisit.nonEmpty) { visit(waitingForVisit.pop()) } parents.toList } - // Find ancestor missing shuffle dependencies and register into shuffleToMapStage - private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], jobId: Int) = { + /** Find ancestor missing shuffle dependencies and register into shuffleToMapStage */ + private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], jobId: Int) { val parentsWithNoMapStage = getAncestorShuffleDependencies(shuffleDep.rdd) - while (!parentsWithNoMapStage.isEmpty) { + while (parentsWithNoMapStage.nonEmpty) { val currentShufDep = parentsWithNoMapStage.pop() - val stage = - newOrUsedStage( - currentShufDep.rdd, currentShufDep.rdd.partitions.size, currentShufDep, jobId, - currentShufDep.rdd.creationSite) + val stage = newOrUsedShuffleStage(currentShufDep, jobId) shuffleToMapStage(currentShufDep.shuffleId) = stage } } - // Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet + /** Find ancestor shuffle dependencies 
that are not registered in shuffleToMapStage yet */ private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = { val parents = new Stack[ShuffleDependency[_, _, _]] val visited = new HashSet[RDD[_]] @@ -351,7 +368,7 @@ class DAGScheduler( } waitingForVisit.push(rdd) - while (!waitingForVisit.isEmpty) { + while (waitingForVisit.nonEmpty) { visit(waitingForVisit.pop()) } parents @@ -382,7 +399,7 @@ class DAGScheduler( } } waitingForVisit.push(stage.rdd) - while (!waitingForVisit.isEmpty) { + while (waitingForVisit.nonEmpty) { visit(waitingForVisit.pop()) } missing.toList @@ -392,7 +409,7 @@ class DAGScheduler( * Registers the given jobId among the jobs that need the given stage and * all of that stage's ancestors. */ - private def updateJobIdStageIdMaps(jobId: Int, stage: Stage) { + private def updateJobIdStageIdMaps(jobId: Int, stage: Stage): Unit = { def updateJobIdStageIdMapsList(stages: List[Stage]) { if (stages.nonEmpty) { val s = stages.head @@ -412,7 +429,7 @@ class DAGScheduler( * * @param job The job whose state to cleanup. */ - private def cleanupStateForJobAndIndependentStages(job: ActiveJob) { + private def cleanupStateForJobAndIndependentStages(job: ActiveJob): Unit = { val registeredStages = jobIdToStageIds.get(job.jobId) if (registeredStages.isEmpty || registeredStages.get.isEmpty) { logError("No stages registered for job " + job.jobId) @@ -474,8 +491,7 @@ class DAGScheduler( callSite: CallSite, allowLocal: Boolean, resultHandler: (Int, U) => Unit, - properties: Properties = null): JobWaiter[U] = - { + properties: Properties): JobWaiter[U] = { // Check to make sure we are not launching a task on a partition that does not exist. val maxPartitions = rdd.partitions.length partitions.find(p => p >= maxPartitions || p < 0).foreach { p => @@ -504,15 +520,13 @@ class DAGScheduler( callSite: CallSite, allowLocal: Boolean, resultHandler: (Int, U) => Unit, - properties: Properties = null) - { + properties: Properties): Unit = { val start = System.nanoTime val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties) waiter.awaitResult() match { - case JobSucceeded => { + case JobSucceeded => logInfo("Job %d finished: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) - } case JobFailed(exception: Exception) => logInfo("Job %d failed: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) @@ -526,9 +540,7 @@ class DAGScheduler( evaluator: ApproximateEvaluator[U, R], callSite: CallSite, timeout: Long, - properties: Properties = null) - : PartialResult[R] = - { + properties: Properties): PartialResult[R] = { val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.partitions.size).toArray @@ -541,12 +553,12 @@ class DAGScheduler( /** * Cancel a job that is running or waiting in the queue. */ - def cancelJob(jobId: Int) { + def cancelJob(jobId: Int): Unit = { logInfo("Asked to cancel job " + jobId) eventProcessLoop.post(JobCancelled(jobId)) } - def cancelJobGroup(groupId: String) { + def cancelJobGroup(groupId: String): Unit = { logInfo("Asked to cancel job group " + groupId) eventProcessLoop.post(JobGroupCancelled(groupId)) } @@ -554,7 +566,7 @@ class DAGScheduler( /** * Cancel all jobs that are running or waiting in the queue. 
*/ - def cancelAllJobs() { + def cancelAllJobs(): Unit = { eventProcessLoop.post(AllJobsCancelled) } @@ -633,13 +645,13 @@ class DAGScheduler( val split = rdd.partitions(job.partitions(0)) val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0, attemptNumber = 0, runningLocally = true) - TaskContextHelper.setTaskContext(taskContext) + TaskContext.setTaskContext(taskContext) try { val result = job.func(taskContext, rdd.iterator(split, taskContext)) job.listener.taskSucceeded(0, result) } finally { taskContext.markTaskCompleted() - TaskContextHelper.unset() + TaskContext.unset() } } catch { case e: Exception => @@ -675,7 +687,7 @@ class DAGScheduler( // Cancel all jobs belonging to this job group. // First finds all active jobs with this group id, and then kill stages for them. val activeInGroup = activeJobs.filter(activeJob => - groupId == activeJob.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) + Option(activeJob.properties).exists(_.get(SparkContext.SPARK_JOB_GROUP_ID) == groupId)) val jobIds = activeInGroup.map(_.jobId) jobIds.foreach(handleJobCancellation(_, "part of cancelled job group %s".format(groupId))) submitWaitingStages() @@ -702,9 +714,10 @@ class DAGScheduler( // cancelling the stages because if the DAG scheduler is stopped, the entire application // is in the process of getting stopped. val stageFailedMessage = "Stage cancelled because SparkContext was shut down" - runningStages.foreach { stage => - stage.latestInfo.stageFailed(stageFailedMessage) - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + // The `toArray` here is necessary so that we don't iterate over `runningStages` while + // mutating it. + runningStages.toArray.foreach { stage => + markStageAsFinished(stage, Some(stageFailedMessage)) } listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error))) } @@ -722,13 +735,12 @@ class DAGScheduler( allowLocal: Boolean, callSite: CallSite, listener: JobListener, - properties: Properties = null) - { - var finalStage: Stage = null + properties: Properties) { + var finalStage: ResultStage = null try { // New stage creation may throw an exception if, for example, jobs are run on a // HadoopRDD whose underlying HDFS files have been deleted. - finalStage = newStage(finalRDD, partitions.size, None, jobId, callSite) + finalStage = newResultStage(finalRDD, partitions.size, jobId, callSite) } catch { case e: Exception => logWarning("Creating new stage failed due to exception - job: " + jobId, e) @@ -773,7 +785,7 @@ class DAGScheduler( if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) { val missing = getMissingParentStages(stage).sortBy(_.id) logDebug("missing: " + missing) - if (missing == Nil) { + if (missing.isEmpty) { logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents") submitMissingTasks(stage, jobId.get) } else { @@ -794,22 +806,19 @@ class DAGScheduler( // Get our pending tasks and remember them in our pendingTasks entry stage.pendingTasks.clear() + // First figure out the indexes of partition ids to compute. 
val partitionsToCompute: Seq[Int] = { - if (stage.isShuffleMap) { - (0 until stage.numPartitions).filter(id => stage.outputLocs(id) == Nil) - } else { - val job = stage.resultOfJob.get - (0 until job.numPartitions).filter(id => !job.finished(id)) + stage match { + case stage: ShuffleMapStage => + (0 until stage.numPartitions).filter(id => stage.outputLocs(id).isEmpty) + case stage: ResultStage => + val job = stage.resultOfJob.get + (0 until job.numPartitions).filter(id => !job.finished(id)) } } - val properties = if (jobIdToActiveJob.contains(jobId)) { - jobIdToActiveJob(stage.jobId).properties - } else { - // this stage will be assigned to "default" pool - null - } + val properties = jobIdToActiveJob.get(stage.jobId).map(_.properties).orNull runningStages += stage // SparkListenerStageSubmitted should be posted before testing whether tasks are @@ -830,18 +839,21 @@ class DAGScheduler( try { // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep). // For ResultTask, serialize and broadcast (rdd, func). - val taskBinaryBytes: Array[Byte] = - if (stage.isShuffleMap) { - closureSerializer.serialize((stage.rdd, stage.shuffleDep.get) : AnyRef).array() - } else { - closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func) : AnyRef).array() - } + val taskBinaryBytes: Array[Byte] = stage match { + case stage: ShuffleMapStage => + closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array() + case stage: ResultStage => + closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func): AnyRef).array() + } + taskBinary = sc.broadcast(taskBinaryBytes) } catch { // In the case of a failure during serialization, abort the stage. case e: NotSerializableException => abortStage(stage, "Task not serializable: " + e.toString) runningStages -= stage + + // Abort execution return case NonFatal(e) => abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}") @@ -849,20 +861,22 @@ class DAGScheduler( return } - val tasks: Seq[Task[_]] = if (stage.isShuffleMap) { - partitionsToCompute.map { id => - val locs = getPreferredLocs(stage.rdd, id) - val part = stage.rdd.partitions(id) - new ShuffleMapTask(stage.id, taskBinary, part, locs) - } - } else { - val job = stage.resultOfJob.get - partitionsToCompute.map { id => - val p: Int = job.partitions(id) - val part = stage.rdd.partitions(p) - val locs = getPreferredLocs(stage.rdd, p) - new ResultTask(stage.id, taskBinary, part, locs, id) - } + val tasks: Seq[Task[_]] = stage match { + case stage: ShuffleMapStage => + partitionsToCompute.map { id => + val locs = getPreferredLocs(stage.rdd, id) + val part = stage.rdd.partitions(id) + new ShuffleMapTask(stage.id, taskBinary, part, locs) + } + + case stage: ResultStage => + val job = stage.resultOfJob.get + partitionsToCompute.map { id => + val p: Int = job.partitions(id) + val part = stage.rdd.partitions(p) + val locs = getPreferredLocs(stage.rdd, p) + new ResultTask(stage.id, taskBinary, part, locs, id) + } } if (tasks.size > 0) { @@ -873,13 +887,20 @@ class DAGScheduler( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { - // Because we posted SparkListenerStageSubmitted earlier, we should post - // SparkListenerStageCompleted here in case there are no tasks to run. 
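The job-group cancellation fix a little further up (matching on `Option(activeJob.properties)`) pairs with the public job-group API; a user-level sketch, with the group id made up:

```
// Jobs submitted after setJobGroup carry the group id in their properties; jobs submitted
// without any properties simply no longer match the cancellation filter.
sc.setJobGroup("nightly-etl", "nightly ETL jobs", interruptOnCancel = true)
sc.parallelize(1 to 100).count()

// From another thread:
sc.cancelJobGroup("nightly-etl")
```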
- outputCommitCoordinator.stageEnd(stage.id) - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) - logDebug("Stage " + stage + " is actually done; %b %d %d".format( - stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) - runningStages -= stage + // Because we posted SparkListenerStageSubmitted earlier, we should mark + // the stage as completed here in case there are no tasks to run + markStageAsFinished(stage, None) + + val debugString = stage match { + case stage: ShuffleMapStage => + s"Stage ${stage} is actually done; " + + s"(available: ${stage.isAvailable}," + + s"available outputs: ${stage.numAvailableOutputs}," + + s"partitions: ${stage.numPartitions})" + case stage : ResultStage => + s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})" + } + logDebug(debugString) } } @@ -945,22 +966,6 @@ class DAGScheduler( } val stage = stageIdToStage(task.stageId) - - def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = { - val serviceTime = stage.latestInfo.submissionTime match { - case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) - case _ => "Unknown" - } - if (errorMessage.isEmpty) { - logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) - stage.latestInfo.completionTime = Some(clock.getTimeMillis()) - } else { - stage.latestInfo.stageFailed(errorMessage.get) - logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime)) - } - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) - runningStages -= stage - } event.reason match { case Success => listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, @@ -968,7 +973,10 @@ class DAGScheduler( stage.pendingTasks -= task task match { case rt: ResultTask[_, _] => - stage.resultOfJob match { + // Cast to ResultStage here because it's part of the ResultTask + // TODO Refactor this out to a function that accepts a ResultStage + val resultStage = stage.asInstanceOf[ResultStage] + resultStage.resultOfJob match { case Some(job) => if (!job.finished(rt.outputId)) { updateAccumulators(event) @@ -976,7 +984,7 @@ class DAGScheduler( job.numFinished += 1 // If the whole job has finished, remove it if (job.numFinished == job.numPartitions) { - markStageAsFinished(stage) + markStageAsFinished(resultStage) cleanupStateForJobAndIndependentStages(job) listenerBus.post( SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded)) @@ -988,7 +996,7 @@ class DAGScheduler( job.listener.taskSucceeded(rt.outputId, event.result) } catch { case e: Exception => - // TODO: Perhaps we want to mark the stage as failed? + // TODO: Perhaps we want to mark the resultStage as failed? 
job.listener.jobFailed(new SparkDriverExecutionException(e)) } } @@ -997,6 +1005,7 @@ class DAGScheduler( } case smt: ShuffleMapTask => + val shuffleStage = stage.asInstanceOf[ShuffleMapStage] updateAccumulators(event) val status = event.result.asInstanceOf[MapStatus] val execId = status.location.executorId @@ -1004,50 +1013,54 @@ class DAGScheduler( if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) { logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId) } else { - stage.addOutputLoc(smt.partitionId, status) + shuffleStage.addOutputLoc(smt.partitionId, status) } - if (runningStages.contains(stage) && stage.pendingTasks.isEmpty) { - markStageAsFinished(stage) + if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) { + markStageAsFinished(shuffleStage) logInfo("looking for newly runnable stages") logInfo("running: " + runningStages) logInfo("waiting: " + waitingStages) logInfo("failed: " + failedStages) - if (stage.shuffleDep.isDefined) { - // We supply true to increment the epoch number here in case this is a - // recomputation of the map outputs. In that case, some nodes may have cached - // locations with holes (from when we detected the error) and will need the - // epoch incremented to refetch them. - // TODO: Only increment the epoch number if this is not the first time - // we registered these map outputs. - mapOutputTracker.registerMapOutputs( - stage.shuffleDep.get.shuffleId, - stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, - changeEpoch = true) - } + + // We supply true to increment the epoch number here in case this is a + // recomputation of the map outputs. In that case, some nodes may have cached + // locations with holes (from when we detected the error) and will need the + // epoch incremented to refetch them. + // TODO: Only increment the epoch number if this is not the first time + // we registered these map outputs. 
+ mapOutputTracker.registerMapOutputs( + shuffleStage.shuffleDep.shuffleId, + shuffleStage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, + changeEpoch = true) + clearCacheLocs() - if (stage.outputLocs.exists(_ == Nil)) { - // Some tasks had failed; let's resubmit this stage + if (shuffleStage.outputLocs.contains(Nil)) { + // Some tasks had failed; let's resubmit this shuffleStage // TODO: Lower-level scheduler should also deal with this - logInfo("Resubmitting " + stage + " (" + stage.name + + logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name + ") because some of its tasks had failed: " + - stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", ")) - submitStage(stage) + shuffleStage.outputLocs.zipWithIndex.filter(_._1.isEmpty) + .map(_._2).mkString(", ")) + submitStage(shuffleStage) } else { val newlyRunnable = new ArrayBuffer[Stage] - for (stage <- waitingStages) { - logInfo("Missing parents for " + stage + ": " + getMissingParentStages(stage)) + for (shuffleStage <- waitingStages) { + logInfo("Missing parents for " + shuffleStage + ": " + + getMissingParentStages(shuffleStage)) } - for (stage <- waitingStages if getMissingParentStages(stage) == Nil) { - newlyRunnable += stage + for (shuffleStage <- waitingStages if getMissingParentStages(shuffleStage).isEmpty) + { + newlyRunnable += shuffleStage } waitingStages --= newlyRunnable runningStages ++= newlyRunnable for { - stage <- newlyRunnable.sortBy(_.id) - jobId <- activeJobForStage(stage) + shuffleStage <- newlyRunnable.sortBy(_.id) + jobId <- activeJobForStage(shuffleStage) } { - logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable") - submitMissingTasks(stage, jobId) + logInfo("Submitting " + shuffleStage + " (" + + shuffleStage.rdd + "), which is now runnable") + submitMissingTasks(shuffleStage, jobId) } } } @@ -1068,7 +1081,6 @@ class DAGScheduler( logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + s"due to a fetch failure from $mapStage (${mapStage.name})") markStageAsFinished(failedStage, Some(failureMessage)) - runningStages -= failedStage } if (disallowStageRetryForTest) { @@ -1184,6 +1196,26 @@ class DAGScheduler( submitWaitingStages() } + /** + * Marks a stage as finished and removes it from the list of running stages. + */ + private def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = { + val serviceTime = stage.latestInfo.submissionTime match { + case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) + case _ => "Unknown" + } + if (errorMessage.isEmpty) { + logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) + stage.latestInfo.completionTime = Some(clock.getTimeMillis()) + } else { + stage.latestInfo.stageFailed(errorMessage.get) + logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime)) + } + outputCommitCoordinator.stageEnd(stage.id) + listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + runningStages -= stage + } + /** * Aborts all jobs depending on a particular Stage. This is called in response to a task set * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. @@ -1204,9 +1236,7 @@ class DAGScheduler( } } - /** - * Fails a job and all stages that are only used by that job, and cleans up relevant state. - */ + /** Fails a job and all stages that are only used by that job, and cleans up relevant state. 
*/ private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) { val error = new SparkException(failureReason) var ableToCancelStages = true @@ -1235,8 +1265,7 @@ class DAGScheduler( if (runningStages.contains(stage)) { try { // cancelTasks will fail if a SchedulerBackend does not implement killTask taskScheduler.cancelTasks(stageId, shouldInterruptThread) - stage.latestInfo.stageFailed(failureReason) - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + markStageAsFinished(stage, Some(failureReason)) } catch { case e: UnsupportedOperationException => logInfo(s"Could not cancel tasks for stage $stageId", e) @@ -1254,9 +1283,7 @@ class DAGScheduler( } } - /** - * Return true if one of stage's ancestors is target. - */ + /** Return true if one of stage's ancestors is target. */ private def stageDependsOn(stage: Stage, target: Stage): Boolean = { if (stage == target) { return true @@ -1282,7 +1309,7 @@ class DAGScheduler( } } waitingForVisit.push(stage.rdd) - while (!waitingForVisit.isEmpty) { + while (waitingForVisit.nonEmpty) { visit(waitingForVisit.pop()) } visitedRdds.contains(target.rdd) @@ -1312,9 +1339,7 @@ class DAGScheduler( private def getPreferredLocsInternal( rdd: RDD[_], partition: Int, - visited: HashSet[(RDD[_],Int)]) - : Seq[TaskLocation] = - { + visited: HashSet[(RDD[_],Int)]): Seq[TaskLocation] = { // If the partition has already been visited, no need to re-visit. // This avoids exponential path exploration. SPARK-695 if (!visited.add((rdd,partition))) { @@ -1323,12 +1348,12 @@ class DAGScheduler( } // If the partition is cached, return the cache locations val cached = getCacheLocs(rdd)(partition) - if (!cached.isEmpty) { + if (cached.nonEmpty) { return cached } // If the RDD has some placement preferences (as is the case for input RDDs), get those val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList - if (!rddPrefs.isEmpty) { + if (rddPrefs.nonEmpty) { return rddPrefs.map(TaskLocation(_)) } // If the RDD has narrow dependencies, pick the first partition of the first narrow dep @@ -1412,7 +1437,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler dagScheduler.sc.stop() } - override def onStop() { + override def onStop(): Unit = { // Cancel any active jobs in postStop hook dagScheduler.cleanUpAfterSchedulerStop() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index c0d889360ae99..08e7727db2fde 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -47,21 +47,21 @@ import org.apache.spark.util.{JsonProtocol, Utils} */ private[spark] class EventLoggingListener( appId: String, - logBaseDir: String, + logBaseDir: URI, sparkConf: SparkConf, hadoopConf: Configuration) extends SparkListener with Logging { import EventLoggingListener._ - def this(appId: String, logBaseDir: String, sparkConf: SparkConf) = + def this(appId: String, logBaseDir: URI, sparkConf: SparkConf) = this(appId, logBaseDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf)) private val shouldCompress = sparkConf.getBoolean("spark.eventLog.compress", false) private val shouldOverwrite = sparkConf.getBoolean("spark.eventLog.overwrite", false) private val testing = sparkConf.getBoolean("spark.eventLog.testing", false) private val outputBufferSize = sparkConf.getInt("spark.eventLog.buffer.kb", 
100) * 1024 - private val fileSystem = Utils.getHadoopFileSystem(new URI(logBaseDir), hadoopConf) + private val fileSystem = Utils.getHadoopFileSystem(logBaseDir, hadoopConf) private val compressionCodec = if (shouldCompress) { Some(CompressionCodec.createCodec(sparkConf)) @@ -259,13 +259,13 @@ private[spark] object EventLoggingListener extends Logging { * @return A path which consists of file-system-safe characters. */ def getLogPath( - logBaseDir: String, + logBaseDir: URI, appId: String, compressionCodecName: Option[String] = None): String = { val sanitizedAppId = appId.replaceAll("[ :/]", "-").replaceAll("[.${}'\"]", "_").toLowerCase // e.g. app_123, app_123.lzf val logName = sanitizedAppId + compressionCodecName.map { "." + _ }.getOrElse("") - Utils.resolveURI(logBaseDir).toString.stripSuffix("/") + "/" + logName + logBaseDir.toString.stripSuffix("/") + "/" + logName } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index a3caa9f000c89..7c184b1dcb308 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -19,10 +19,8 @@ package org.apache.spark.scheduler import scala.collection.mutable -import akka.actor.{ActorRef, Actor} - import org.apache.spark._ -import org.apache.spark.util.{AkkaUtils, ActorLogReceive} +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, RpcEndpoint} private sealed trait OutputCommitCoordinationMessage extends Serializable @@ -34,8 +32,8 @@ private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttem * policy. * * OutputCommitCoordinator is instantiated in both the drivers and executors. On executors, it is - * configured with a reference to the driver's OutputCommitCoordinatorActor, so requests to commit - * output will be forwarded to the driver's OutputCommitCoordinator. + * configured with a reference to the driver's OutputCommitCoordinatorEndpoint, so requests to + * commit output will be forwarded to the driver's OutputCommitCoordinator. * * This class was introduced in SPARK-4879; see that JIRA issue (and the associated pull requests) * for an extensive design discussion. @@ -43,10 +41,7 @@ private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttem private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { // Initialized by SparkEnv - var coordinatorActor: Option[ActorRef] = None - private val timeout = AkkaUtils.askTimeout(conf) - private val maxAttempts = AkkaUtils.numRetries(conf) - private val retryInterval = AkkaUtils.retryWaitMs(conf) + var coordinatorRef: Option[RpcEndpointRef] = None private type StageId = Int private type PartitionId = Long @@ -64,6 +59,13 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map() private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptId]] + /** + * Returns whether the OutputCommitCoordinator's internal data structures are all empty. + */ + def isEmpty: Boolean = { + authorizedCommittersByStage.isEmpty + } + /** * Called by tasks to ask whether they can commit their output to HDFS. 
* @@ -81,9 +83,9 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { partition: PartitionId, attempt: TaskAttemptId): Boolean = { val msg = AskPermissionToCommitOutput(stage, partition, attempt) - coordinatorActor match { - case Some(actor) => - AkkaUtils.askWithReply[Boolean](msg, actor, maxAttempts, retryInterval, timeout) + coordinatorRef match { + case Some(endpointRef) => + endpointRef.askWithReply[Boolean](msg) case None => logError( "canCommit called after coordinator was stopped (is SparkEnv shutdown in progress)?") @@ -118,15 +120,17 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { logInfo( s"Task was denied committing, stage: $stage, partition: $partition, attempt: $attempt") case otherReason => - logDebug(s"Authorized committer $attempt (stage=$stage, partition=$partition) failed;" + - s" clearing lock") - authorizedCommitters.remove(partition) + if (authorizedCommitters.get(partition).exists(_ == attempt)) { + logDebug(s"Authorized committer $attempt (stage=$stage, partition=$partition) failed;" + + s" clearing lock") + authorizedCommitters.remove(partition) + } } } def stop(): Unit = synchronized { - coordinatorActor.foreach(_ ! StopCoordinator) - coordinatorActor = None + coordinatorRef.foreach(_ send StopCoordinator) + coordinatorRef = None authorizedCommittersByStage.clear() } @@ -157,16 +161,20 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { private[spark] object OutputCommitCoordinator { // This actor is used only for RPC - class OutputCommitCoordinatorActor(outputCommitCoordinator: OutputCommitCoordinator) - extends Actor with ActorLogReceive with Logging { + private[spark] class OutputCommitCoordinatorEndpoint( + override val rpcEnv: RpcEnv, outputCommitCoordinator: OutputCommitCoordinator) + extends RpcEndpoint with Logging { - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case AskPermissionToCommitOutput(stage, partition, taskAttempt) => - sender ! outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt) + override def receive: PartialFunction[Any, Unit] = { case StopCoordinator => logInfo("OutputCommitCoordinator stopped!") - context.stop(self) - sender ! true + stop() + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case AskPermissionToCommitOutput(stage, partition, taskAttempt) => + context.reply( + outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt)) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala new file mode 100644 index 0000000000000..c0f3d5a13d623 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.rdd.RDD +import org.apache.spark.util.CallSite + +/** + * The ResultStage represents the final stage in a job. + */ +private[spark] class ResultStage( + id: Int, + rdd: RDD[_], + numTasks: Int, + parents: List[Stage], + jobId: Int, + callSite: CallSite) + extends Stage(id, rdd, numTasks, parents, jobId, callSite) { + + // The active job for this result stage. Will be empty if the job has already finished + // (e.g., because the job was cancelled). + var resultOfJob: Option[ActiveJob] = None + + override def toString: String = "ResultStage " + id +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala new file mode 100644 index 0000000000000..d02210743484c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.ShuffleDependency +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.CallSite + +/** + * The ShuffleMapStage represents the intermediate stages in a job. + */ +private[spark] class ShuffleMapStage( + id: Int, + rdd: RDD[_], + numTasks: Int, + parents: List[Stage], + jobId: Int, + callSite: CallSite, + val shuffleDep: ShuffleDependency[_, _, _]) + extends Stage(id, rdd, numTasks, parents, jobId, callSite) { + + override def toString: String = "ShuffleMapStage " + id + + var numAvailableOutputs: Long = 0 + + def isAvailable: Boolean = numAvailableOutputs == numPartitions + + val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) + + def addOutputLoc(partition: Int, status: MapStatus): Unit = { + val prevList = outputLocs(partition) + outputLocs(partition) = status :: prevList + if (prevList == Nil) { + numAvailableOutputs += 1 + } + } + + def removeOutputLoc(partition: Int, bmAddress: BlockManagerId): Unit = { + val prevList = outputLocs(partition) + val newList = prevList.filterNot(_.location == bmAddress) + outputLocs(partition) = newList + if (prevList != Nil && newList == Nil) { + numAvailableOutputs -= 1 + } + } + + /** + * Removes all shuffle outputs associated with this executor. Note that this will also remove + * outputs which are served by an external shuffle server (if one exists), as they are still + * registered with this execId. 
+ */ + def removeOutputsOnExecutor(execId: String): Unit = { + var becameUnavailable = false + for (partition <- 0 until numPartitions) { + val prevList = outputLocs(partition) + val newList = prevList.filterNot(_.location.executorId == execId) + outputLocs(partition) = newList + if (prevList != Nil && newList == Nil) { + becameUnavailable = true + numAvailableOutputs -= 1 + } + } + if (becameUnavailable) { + logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( + this, execId, numAvailableOutputs, numPartitions, isAvailable)) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index fd0d484b45460..6c7d00069acb2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -33,7 +33,7 @@ import org.apache.spark.shuffle.ShuffleWriter * See [[org.apache.spark.scheduler.Task]] for more information. * * @param stageId id of the stage this task belongs to - * @param taskBinary broadcast version of of the RDD and the ShuffleDependency. Once deserialized, + * @param taskBinary broadcast version of the RDD and the ShuffleDependency. Once deserialized, * the type should be (RDD[_], ShuffleDependency[_, _, _]). * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 4cbc6e84a6bdd..5d0ddb8377c33 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -21,7 +21,6 @@ import scala.collection.mutable.HashSet import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.CallSite /** @@ -47,29 +46,23 @@ import org.apache.spark.util.CallSite * be updated for each attempt. * */ -private[spark] class Stage( +private[spark] abstract class Stage( val id: Int, val rdd: RDD[_], val numTasks: Int, - val shuffleDep: Option[ShuffleDependency[_, _, _]], // Output shuffle if stage is a map stage val parents: List[Stage], val jobId: Int, val callSite: CallSite) extends Logging { - val isShuffleMap = shuffleDep.isDefined val numPartitions = rdd.partitions.size - val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) - var numAvailableOutputs = 0 /** Set of jobs that this stage belongs to. */ val jobIds = new HashSet[Int] - /** For stages that are the final (consists of only ResultTasks), link to the ActiveJob. */ - var resultOfJob: Option[ActiveJob] = None var pendingTasks = new HashSet[Task[_]] - private var nextAttemptId = 0 + private var nextAttemptId: Int = 0 val name = callSite.shortForm val details = callSite.longForm @@ -77,53 +70,6 @@ private[spark] class Stage( /** Pointer to the latest [StageInfo] object, set by DAGScheduler. 
*/ var latestInfo: StageInfo = StageInfo.fromStage(this) - def isAvailable: Boolean = { - if (!isShuffleMap) { - true - } else { - numAvailableOutputs == numPartitions - } - } - - def addOutputLoc(partition: Int, status: MapStatus) { - val prevList = outputLocs(partition) - outputLocs(partition) = status :: prevList - if (prevList == Nil) { - numAvailableOutputs += 1 - } - } - - def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location == bmAddress) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - numAvailableOutputs -= 1 - } - } - - /** - * Removes all shuffle outputs associated with this executor. Note that this will also remove - * outputs which are served by an external shuffle server (if one exists), as they are still - * registered with this execId. - */ - def removeOutputsOnExecutor(execId: String) { - var becameUnavailable = false - for (partition <- 0 until numPartitions) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location.executorId == execId) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - becameUnavailable = true - numAvailableOutputs -= 1 - } - } - if (becameUnavailable) { - logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( - this, execId, numAvailableOutputs, numPartitions, isAvailable)) - } - } - /** Return a new attempt id, starting with 0. */ def newAttemptId(): Int = { val id = nextAttemptId @@ -133,11 +79,8 @@ private[spark] class Stage( def attemptId: Int = nextAttemptId - override def toString: String = "Stage " + id - - override def hashCode(): Int = id - - override def equals(other: Any): Boolean = other match { + override final def hashCode(): Int = id + override final def equals(other: Any): Boolean = other match { case stage: Stage => stage != null && stage.id == id case _ => false } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 4d9f940813b8e..8b592867ee31d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap -import org.apache.spark.{TaskContextHelper, TaskContextImpl, TaskContext} +import org.apache.spark.{TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.ByteBufferInputStream @@ -54,7 +54,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex final def run(taskAttemptId: Long, attemptNumber: Int): T = { context = new TaskContextImpl(stageId = stageId, partitionId = partitionId, taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false) - TaskContextHelper.setTaskContext(context) + TaskContext.setTaskContext(context) context.taskMetrics.setHostname(Utils.localHostName()) taskThread = Thread.currentThread() if (_killed) { @@ -64,7 +64,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex runTask(context) } finally { context.markTaskCompleted() - TaskContextHelper.unset() + TaskContext.unset() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 076b36e86c0ce..13a52d836f32f 100644 --- 
a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -62,10 +62,10 @@ private[spark] class TaskSchedulerImpl( val conf = sc.conf // How often to check for speculative tasks - val SPECULATION_INTERVAL = conf.getLong("spark.speculation.interval", 100) + val SPECULATION_INTERVAL_MS = conf.getTimeAsMs("spark.speculation.interval", "100ms") // Threshold above which we warn user initial TaskSet may be starved - val STARVATION_TIMEOUT = conf.getLong("spark.starvation.timeout", 15000) + val STARVATION_TIMEOUT_MS = conf.getTimeAsMs("spark.starvation.timeout", "15s") // CPUs to request per task val CPUS_PER_TASK = conf.getInt("spark.task.cpus", 1) @@ -142,11 +142,10 @@ private[spark] class TaskSchedulerImpl( if (!isLocal && conf.getBoolean("spark.speculation", false)) { logInfo("Starting speculative execution thread") - import sc.env.actorSystem.dispatcher - sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds, - SPECULATION_INTERVAL milliseconds) { + sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL_MS milliseconds, + SPECULATION_INTERVAL_MS milliseconds) { Utils.tryOrStopSparkContext(sc) { checkSpeculatableTasks() } - } + }(sc.env.actorSystem.dispatcher) } } @@ -173,7 +172,7 @@ private[spark] class TaskSchedulerImpl( this.cancel() } } - }, STARVATION_TIMEOUT, STARVATION_TIMEOUT) + }, STARVATION_TIMEOUT_MS, STARVATION_TIMEOUT_MS) } hasReceivedTask = true } @@ -394,7 +393,7 @@ private[spark] class TaskSchedulerImpl( def error(message: String) { synchronized { - if (activeTaskSets.size > 0) { + if (activeTaskSets.nonEmpty) { // Have each task set throw a SparkException with the error for ((taskSetId, manager) <- activeTaskSets) { try { @@ -407,8 +406,7 @@ private[spark] class TaskSchedulerImpl( // No task sets are active but we still got an error. Just exit since this // must mean the error is during registration. // It might be good to do something smarter here in the future. 
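The hunks in this and the following files switch several scheduler intervals from raw millisecond `Long` configs to `conf.getTimeAsMs`, which accepts duration strings such as "100ms" or "15s" and returns milliseconds. A small usage sketch under that assumption follows; the concrete values set below are examples chosen for illustration, not defaults prescribed by the patch.

```
import org.apache.spark.SparkConf

object TimeConfDemo extends App {
  val conf = new SparkConf()
    .set("spark.speculation.interval", "200ms") // example values, not defaults
    .set("spark.starvation.timeout", "30s")

  // Same pattern as the patched TaskSchedulerImpl: a string with units is
  // parsed into milliseconds; the second argument is the fallback.
  val speculationIntervalMs = conf.getTimeAsMs("spark.speculation.interval", "100ms")
  val starvationTimeoutMs = conf.getTimeAsMs("spark.starvation.timeout", "15s")

  println(s"speculation interval: $speculationIntervalMs ms") // 200
  println(s"starvation timeout: $starvationTimeoutMs ms")     // 30000
}
```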
- logError("Exiting due to error from cluster scheduler: " + message) - System.exit(1) + throw new SparkException(s"Exiting due to error from cluster scheduler: $message") } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index d509881c74fef..7dc325283d961 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -848,15 +848,18 @@ private[spark] class TaskSetManager( } private def getLocalityWait(level: TaskLocality.TaskLocality): Long = { - val defaultWait = conf.get("spark.locality.wait", "3000") - level match { - case TaskLocality.PROCESS_LOCAL => - conf.get("spark.locality.wait.process", defaultWait).toLong - case TaskLocality.NODE_LOCAL => - conf.get("spark.locality.wait.node", defaultWait).toLong - case TaskLocality.RACK_LOCAL => - conf.get("spark.locality.wait.rack", defaultWait).toLong - case _ => 0L + val defaultWait = conf.get("spark.locality.wait", "3s") + val localityWaitKey = level match { + case TaskLocality.PROCESS_LOCAL => "spark.locality.wait.process" + case TaskLocality.NODE_LOCAL => "spark.locality.wait.node" + case TaskLocality.RACK_LOCAL => "spark.locality.wait.rack" + case _ => null + } + + if (localityWaitKey != null) { + conf.getTimeAsMs(localityWaitKey, defaultWait) + } else { + 0L } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 9bf74f4be198d..70364cea62a80 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler.cluster import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.{SerializableBuffer, Utils} private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable @@ -41,6 +42,7 @@ private[spark] object CoarseGrainedClusterMessages { // Executors to driver case class RegisterExecutor( executorId: String, + executorRef: RpcEndpointRef, hostPort: String, cores: Int, logUrls: Map[String, String]) @@ -70,6 +72,8 @@ private[spark] object CoarseGrainedClusterMessages { case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage + case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage + // Exchanged between the driver and the AM in Yarn client mode case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase: String) extends CoarseGrainedClusterMessage @@ -77,7 +81,7 @@ private[spark] object CoarseGrainedClusterMessages { // Messages exchanged between the driver and the cluster manager for executor allocation // In Yarn mode, these are exchanged between the driver and the AM - case object RegisterClusterManager extends CoarseGrainedClusterMessage + case class RegisterClusterManager(am: RpcEndpointRef) extends CoarseGrainedClusterMessage // Request executors by specifying the new total number of executors desired // This includes executors already pending or running diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala 
b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 5d258d9da4d1a..63987dfb32695 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -17,20 +17,16 @@ package org.apache.spark.scheduler.cluster +import java.util.concurrent.{TimeUnit, Executors} import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.Await -import scala.concurrent.duration._ - -import akka.actor._ -import akka.pattern.ask -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import org.apache.spark.rpc._ import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils} +import org.apache.spark.util.{SerializableBuffer, AkkaUtils, Utils} /** * A scheduler backend that waits for coarse grained executors to connect to it through Akka. @@ -41,7 +37,7 @@ import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Ut * (spark.deploy.*). */ private[spark] -class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSystem: ActorSystem) +class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: RpcEnv) extends ExecutorAllocationClient with SchedulerBackend with Logging { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed @@ -49,7 +45,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Total number of executors that are currently registered var totalRegisteredExecutors = new AtomicInteger(0) val conf = scheduler.sc.conf - private val timeout = AkkaUtils.askTimeout(conf) private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) // Submit tasks only after (registered resources / total expected resources) // is equal to at least this value, that is double between 0 and 1. 
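The remaining hunks in this file replace the Akka-based `DriverActor` with a `DriverEndpoint` built on the new RPC abstraction: one-way messages flow through `receive`/`send`, while request-reply traffic flows through `receiveAndReply`/`askWithReply` with an explicit reply context. Because the real `RpcEndpoint` trait is `private[spark]`, the sketch below uses hypothetical stand-in types purely to illustrate that split; it is not the patched implementation.

```
// Hypothetical stand-ins for the private[spark] RPC API, for illustration only.
trait ToyCallContext { def reply(response: Any): Unit }

trait ToyEndpoint {
  // One-way messages (e.g. StatusUpdate, ReviveOffers in the patch).
  def receive: PartialFunction[Any, Unit]
  // Messages that expect an answer (e.g. RegisterExecutor, StopDriver).
  def receiveAndReply(context: ToyCallContext): PartialFunction[Any, Unit]
}

case object Ping
case class Register(id: String)

class ToyDriverEndpoint extends ToyEndpoint {
  private var registered = Set.empty[String]

  override def receive: PartialFunction[Any, Unit] = {
    case Ping => println("pinged")
  }

  override def receiveAndReply(context: ToyCallContext): PartialFunction[Any, Unit] = {
    case Register(id) =>
      registered += id
      context.reply(true) // mirrors the patched context.reply(...) calls
  }
}

object ToyRpcDemo extends App {
  val endpoint = new ToyDriverEndpoint
  endpoint.receive(Ping)
  endpoint.receiveAndReply(new ToyCallContext {
    def reply(response: Any): Unit = println(s"reply: $response")
  })(Register("exec-1"))
}
```

The design point the patch is making is that the reply path is explicit (`context.reply` / `context.sendFailure`) rather than implicit via Akka's `sender`, which is what allows the duplicate-executor and stop cases above to answer the caller directly.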
@@ -57,8 +52,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste math.min(1, conf.getDouble("spark.scheduler.minRegisteredResourcesRatio", 0)) // Submit tasks after maxRegisteredWaitingTime milliseconds // if minRegisteredRatio has not yet been reached - val maxRegisteredWaitingTime = - conf.getInt("spark.scheduler.maxRegisteredResourcesWaitingTime", 30000) + val maxRegisteredWaitingTimeMs = + conf.getTimeAsMs("spark.scheduler.maxRegisteredResourcesWaitingTime", "30s") val createTime = System.currentTimeMillis() private val executorDataMap = new HashMap[String, ExecutorData] @@ -71,48 +66,27 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Executors we have requested the cluster manager to kill that have not died yet private val executorsPendingToRemove = new HashSet[String] - class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive { + class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) + extends ThreadSafeRpcEndpoint with Logging { override protected def log = CoarseGrainedSchedulerBackend.this.log - private val addressToExecutorId = new HashMap[Address, String] - - override def preStart() { - // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - // Periodically revive offers to allow delay scheduling to work - val reviveInterval = conf.getLong("spark.scheduler.revive.interval", 1000) - import context.dispatcher - context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) - } + private val addressToExecutorId = new HashMap[RpcAddress, String] - def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisterExecutor(executorId, hostPort, cores, logUrls) => - Utils.checkHostPort(hostPort, "Host port expected " + hostPort) - if (executorDataMap.contains(executorId)) { - sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) - } else { - logInfo("Registered executor: " + sender + " with ID " + executorId) - sender ! 
RegisteredExecutor + private val reviveThread = + Executors.newSingleThreadScheduledExecutor(Utils.namedThreadFactory("driver-revive-thread")) - addressToExecutorId(sender.path.address) = executorId - totalCoreCount.addAndGet(cores) - totalRegisteredExecutors.addAndGet(1) - val (host, _) = Utils.parseHostPort(hostPort) - val data = new ExecutorData(sender, sender.path.address, host, cores, cores, logUrls) - // This must be synchronized because variables mutated - // in this block are read when requesting executors - CoarseGrainedSchedulerBackend.this.synchronized { - executorDataMap.put(executorId, data) - if (numPendingExecutors > 0) { - numPendingExecutors -= 1 - logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") - } - } - listenerBus.post( - SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data)) - makeOffers() + override def onStart() { + // Periodically revive offers to allow delay scheduling to work + val reviveIntervalMs = conf.getTimeAsMs("spark.scheduler.revive.interval", "1s") + + reviveThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + Option(self).foreach(_.send(ReviveOffers)) } + }, 0, reviveIntervalMs, TimeUnit.MILLISECONDS) + } + override def receive: PartialFunction[Any, Unit] = { case StatusUpdate(executorId, taskId, state, data) => scheduler.statusUpdate(taskId, state, data.value) if (TaskState.isFinished(state)) { @@ -133,33 +107,58 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste case KillTask(taskId, executorId, interruptThread) => executorDataMap.get(executorId) match { case Some(executorInfo) => - executorInfo.executorActor ! KillTask(taskId, executorId, interruptThread) + executorInfo.executorEndpoint.send(KillTask(taskId, executorId, interruptThread)) case None => // Ignoring the task kill since the executor is not registered. logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.") } + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterExecutor(executorId, executorRef, hostPort, cores, logUrls) => + Utils.checkHostPort(hostPort, "Host port expected " + hostPort) + if (executorDataMap.contains(executorId)) { + context.reply(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) + } else { + logInfo("Registered executor: " + executorRef + " with ID " + executorId) + context.reply(RegisteredExecutor) + + addressToExecutorId(executorRef.address) = executorId + totalCoreCount.addAndGet(cores) + totalRegisteredExecutors.addAndGet(1) + val (host, _) = Utils.parseHostPort(hostPort) + val data = new ExecutorData(executorRef, executorRef.address, host, cores, cores, logUrls) + // This must be synchronized because variables mutated + // in this block are read when requesting executors + CoarseGrainedSchedulerBackend.this.synchronized { + executorDataMap.put(executorId, data) + if (numPendingExecutors > 0) { + numPendingExecutors -= 1 + logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") + } + } + listenerBus.post( + SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data)) + makeOffers() + } case StopDriver => - sender ! true - context.stop(self) + context.reply(true) + stop() case StopExecutors => logInfo("Asking each executor to shut down") for ((_, executorData) <- executorDataMap) { - executorData.executorActor ! StopExecutor + executorData.executorEndpoint.send(StopExecutor) } - sender ! 
true + context.reply(true) case RemoveExecutor(executorId, reason) => removeExecutor(executorId, reason) - sender ! true - - case DisassociatedEvent(_, address, _) => - addressToExecutorId.get(address).foreach(removeExecutor(_, - "remote Akka client disassociated")) + context.reply(true) case RetrieveSparkProps => - sender ! sparkProperties + context.reply(sparkProperties) } // Make fake resource offers on all executors @@ -169,6 +168,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste }.toSeq)) } + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + addressToExecutorId.get(remoteAddress).foreach(removeExecutor(_, + "remote Rpc client disassociated")) + } + // Make fake resource offers on just one executor def makeOffers(executorId: String) { val executorData = executorDataMap(executorId) @@ -199,7 +203,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste else { val executorData = executorDataMap(task.executorId) executorData.freeCores -= scheduler.CPUS_PER_TASK - executorData.executorActor ! LaunchTask(new SerializableBuffer(serializedTask)) + executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask))) } } } @@ -223,9 +227,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste case None => logError(s"Asked to remove non-existent executor $executorId") } } + + override def onStop() { + reviveThread.shutdownNow() + } } - var driverActor: ActorRef = null + var driverEndpoint: RpcEndpointRef = null val taskIdsOnSlave = new HashMap[String, HashSet[String]] override def start() { @@ -236,16 +244,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } } // TODO (prashant) send conf instead of properties - driverActor = actorSystem.actorOf( - Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME) + driverEndpoint = rpcEnv.setupEndpoint( + CoarseGrainedSchedulerBackend.ENDPOINT_NAME, new DriverEndpoint(rpcEnv, properties)) } def stopExecutors() { try { - if (driverActor != null) { + if (driverEndpoint != null) { logInfo("Shutting down all executors") - val future = driverActor.ask(StopExecutors)(timeout) - Await.ready(future, timeout) + driverEndpoint.askWithReply[Boolean](StopExecutors) } } catch { case e: Exception => @@ -256,22 +263,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste override def stop() { stopExecutors() try { - if (driverActor != null) { - val future = driverActor.ask(StopDriver)(timeout) - Await.ready(future, timeout) + if (driverEndpoint != null) { + driverEndpoint.askWithReply[Boolean](StopDriver) } } catch { case e: Exception => - throw new SparkException("Error stopping standalone scheduler's driver actor", e) + throw new SparkException("Error stopping standalone scheduler's driver endpoint", e) } } override def reviveOffers() { - driverActor ! ReviveOffers + driverEndpoint.send(ReviveOffers) } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - driverActor ! 
KillTask(taskId, executorId, interruptThread) + driverEndpoint.send(KillTask(taskId, executorId, interruptThread)) } override def defaultParallelism(): Int = { @@ -281,11 +287,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Called by subclasses when notified of a lost worker def removeExecutor(executorId: String, reason: String) { try { - val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout) - Await.ready(future, timeout) + driverEndpoint.askWithReply[Boolean](RemoveExecutor(executorId, reason)) } catch { case e: Exception => - throw new SparkException("Error notifying standalone scheduler's driver actor", e) + throw new SparkException("Error notifying standalone scheduler's driver endpoint", e) } } @@ -297,9 +302,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste s"reached minRegisteredResourcesRatio: $minRegisteredRatio") return true } - if ((System.currentTimeMillis() - createTime) >= maxRegisteredWaitingTime) { + if ((System.currentTimeMillis() - createTime) >= maxRegisteredWaitingTimeMs) { logInfo("SchedulerBackend is ready for scheduling beginning after waiting " + - s"maxRegisteredResourcesWaitingTime: $maxRegisteredWaitingTime(ms)") + s"maxRegisteredResourcesWaitingTime: $maxRegisteredWaitingTimeMs(ms)") return true } false @@ -391,5 +396,5 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } private[spark] object CoarseGrainedSchedulerBackend { - val ACTOR_NAME = "CoarseGrainedScheduler" + val ENDPOINT_NAME = "CoarseGrainedScheduler" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index 5e571efe76720..26e72c0bff38d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -17,20 +17,20 @@ package org.apache.spark.scheduler.cluster -import akka.actor.{Address, ActorRef} +import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress} /** * Grouping of data for an executor used by CoarseGrainedSchedulerBackend. 
* - * @param executorActor The ActorRef representing this executor + * @param executorEndpoint The ActorRef representing this executor * @param executorAddress The network address of this executor * @param executorHost The hostname that this executor is running on * @param freeCores The current number of cores available for work on the executor * @param totalCores The total number of cores available to the executor */ private[cluster] class ExecutorData( - val executorActor: ActorRef, - val executorAddress: Address, + val executorEndpoint: RpcEndpointRef, + val executorAddress: RpcAddress, override val executorHost: String, var freeCores: Int, override val totalCores: Int, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index 06786a59524e7..0324c9dab910b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -19,16 +19,16 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.spark.rpc.RpcAddress import org.apache.spark.{Logging, SparkContext, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.util.AkkaUtils private[spark] class SimrSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, driverFilePath: String) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with Logging { val tmpPath = new Path(driverFilePath + "_tmp") @@ -39,12 +39,9 @@ private[spark] class SimrSchedulerBackend( override def start() { super.start() - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(actorSystem), - SparkEnv.driverActorSystemName, - sc.conf.get("spark.driver.host"), - sc.conf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, + RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val conf = SparkHadoopUtil.get.newConfiguration(sc.conf) val fs = FileSystem.get(conf) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index ffd4825705755..ccf1dc5af6120 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -19,17 +19,18 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.Semaphore +import org.apache.spark.rpc.RpcAddress import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.Utils private[spark] class SparkDeploySchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, masters: Array[String]) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, 
sc.env.rpcEnv) with AppClientListener with Logging { @@ -48,12 +49,9 @@ private[spark] class SparkDeploySchedulerBackend( super.start() // The endpoint for executors to talk to us - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(actorSystem), - SparkEnv.driverActorSystemName, - conf.get("spark.driver.host"), - conf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, + RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val args = Seq( "--driver-url", driverUrl, "--executor-id", "{{EXECUTOR_ID}}", @@ -84,12 +82,11 @@ private[spark] class SparkDeploySchedulerBackend( val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts) val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") - val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, - appUIAddress, sc.eventLogDir, sc.eventLogCodec) - + val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt) + val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, + command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor) client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) client.start() - waitForRegistration() } @@ -121,9 +118,12 @@ private[spark] class SparkDeploySchedulerBackend( notifyContext() if (!stopping) { logError("Application has been killed. Reason: " + reason) - scheduler.error(reason) - // Ensure the application terminates, as we can no longer run jobs. - sc.stop() + try { + scheduler.error(reason) + } finally { + // Ensure the application terminates, as we can no longer run jobs. 
+ sc.stop() + } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 5a38ad9f2b12c..f72566c370a6f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -19,10 +19,8 @@ package org.apache.spark.scheduler.cluster import scala.concurrent.{Future, ExecutionContext} -import akka.actor.{Actor, ActorRef, Props} -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} - -import org.apache.spark.SparkContext +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.rpc._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.ui.JettyUtils @@ -37,7 +35,7 @@ import scala.util.control.NonFatal private[spark] abstract class YarnSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) { + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) { if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { minRegisteredRatio = 0.8 @@ -45,10 +43,8 @@ private[spark] abstract class YarnSchedulerBackend( protected var totalExpectedExecutors = 0 - private val yarnSchedulerActor: ActorRef = - actorSystem.actorOf( - Props(new YarnSchedulerActor), - name = YarnSchedulerBackend.ACTOR_NAME) + private val yarnSchedulerEndpoint = rpcEnv.setupEndpoint( + YarnSchedulerBackend.ENDPOINT_NAME, new YarnSchedulerEndpoint(rpcEnv)) private implicit val askTimeout = AkkaUtils.askTimeout(sc.conf) @@ -57,16 +53,14 @@ private[spark] abstract class YarnSchedulerBackend( * This includes executors already pending or running. */ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - AkkaUtils.askWithReply[Boolean]( - RequestExecutors(requestedTotal), yarnSchedulerActor, askTimeout) + yarnSchedulerEndpoint.askWithReply[Boolean](RequestExecutors(requestedTotal)) } /** * Request that the ApplicationMaster kill the specified executors. */ override def doKillExecutors(executorIds: Seq[String]): Boolean = { - AkkaUtils.askWithReply[Boolean]( - KillExecutors(executorIds), yarnSchedulerActor, askTimeout) + yarnSchedulerEndpoint.askWithReply[Boolean](KillExecutors(executorIds)) } override def sufficientResourcesRegistered(): Boolean = { @@ -96,64 +90,71 @@ private[spark] abstract class YarnSchedulerBackend( } /** - * An actor that communicates with the ApplicationMaster. + * An [[RpcEndpoint]] that communicates with the ApplicationMaster. 
*/ - private class YarnSchedulerActor extends Actor { - private var amActor: Option[ActorRef] = None - - implicit val askAmActorExecutor = ExecutionContext.fromExecutor( - Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-executor")) + private class YarnSchedulerEndpoint(override val rpcEnv: RpcEnv) + extends ThreadSafeRpcEndpoint with Logging { + private var amEndpoint: Option[RpcEndpointRef] = None - override def preStart(): Unit = { - // Listen for disassociation events - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - } + private val askAmThreadPool = + Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool") + implicit val askAmExecutor = ExecutionContext.fromExecutor(askAmThreadPool) override def receive: PartialFunction[Any, Unit] = { - case RegisterClusterManager => - logInfo(s"ApplicationMaster registered as $sender") - amActor = Some(sender) + case RegisterClusterManager(am) => + logInfo(s"ApplicationMaster registered as $am") + amEndpoint = Some(am) + + case AddWebUIFilter(filterName, filterParams, proxyBase) => + addWebUIFilter(filterName, filterParams, proxyBase) + + } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case r: RequestExecutors => - amActor match { - case Some(actor) => - val driverActor = sender + amEndpoint match { + case Some(am) => Future { - driverActor ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout) + context.reply(am.askWithReply[Boolean](r)) } onFailure { - case NonFatal(e) => logError(s"Sending $r to AM was unsuccessful", e) + case NonFatal(e) => + logError(s"Sending $r to AM was unsuccessful", e) + context.sendFailure(e) } case None => logWarning("Attempted to request executors before the AM has registered!") - sender ! false + context.reply(false) } case k: KillExecutors => - amActor match { - case Some(actor) => - val driverActor = sender + amEndpoint match { + case Some(am) => Future { - driverActor ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout) + context.reply(am.askWithReply[Boolean](k)) } onFailure { - case NonFatal(e) => logError(s"Sending $k to AM was unsuccessful", e) + case NonFatal(e) => + logError(s"Sending $k to AM was unsuccessful", e) + context.sendFailure(e) } case None => logWarning("Attempted to kill executors before the AM has registered!") - sender ! false + context.reply(false) } - case AddWebUIFilter(filterName, filterParams, proxyBase) => - addWebUIFilter(filterName, filterParams, proxyBase) - sender ! 
true + } - case d: DisassociatedEvent => - if (amActor.isDefined && sender == amActor.get) { - logWarning(s"ApplicationMaster has disassociated: $d") - } + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (amEndpoint.exists(_.address == remoteAddress)) { + logWarning(s"ApplicationMaster has disassociated: $remoteAddress") + } + } + + override def onStop(): Unit ={ + askAmThreadPool.shutdownNow() } } } private[spark] object YarnSchedulerBackend { - val ACTOR_NAME = "YarnScheduler" + val ENDPOINT_NAME = "YarnScheduler" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index e13de0f46ef89..b037a4966ced0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -47,7 +47,7 @@ private[spark] class CoarseMesosSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, master: String) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with MScheduler with Logging { @@ -148,7 +148,7 @@ private[spark] class CoarseMesosSchedulerBackend( SparkEnv.driverActorSystemName, conf.get("spark.driver.host"), conf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val uri = conf.get("spark.executor.uri", null) if (uri == null) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index eb3f999b5b375..50ba0b9d5a612 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -18,17 +18,14 @@ package org.apache.spark.scheduler.local import java.nio.ByteBuffer +import java.util.concurrent.{Executors, TimeUnit} -import scala.concurrent.duration._ -import scala.language.postfixOps - -import akka.actor.{Actor, ActorRef, Props} - -import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} +import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.util.ActorLogReceive +import org.apache.spark.util.Utils private case class ReviveOffers() @@ -39,17 +36,19 @@ private case class KillTask(taskId: Long, interruptThread: Boolean) private case class StopExecutor() /** - * Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on - * LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend + * Calls to LocalBackend are all serialized through LocalEndpoint. Using an RpcEndpoint makes the + * calls on LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend * and the TaskSchedulerImpl. 
*/ -private[spark] class LocalActor( +private[spark] class LocalEndpoint( + override val rpcEnv: RpcEnv, scheduler: TaskSchedulerImpl, executorBackend: LocalBackend, private val totalCores: Int) - extends Actor with ActorLogReceive with Logging { + extends ThreadSafeRpcEndpoint with Logging { - import context.dispatcher // to use Akka's scheduler.scheduleOnce() + private val reviveThread = Executors.newSingleThreadScheduledExecutor( + Utils.namedThreadFactory("local-revive-thread")) private var freeCores = totalCores @@ -59,7 +58,7 @@ private[spark] class LocalActor( private val executor = new Executor( localExecutorId, localExecutorHostname, SparkEnv.get, isLocal = true) - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receive: PartialFunction[Any, Unit] = { case ReviveOffers => reviveOffers() @@ -72,11 +71,15 @@ private[spark] class LocalActor( case KillTask(taskId, interruptThread) => executor.killTask(taskId, interruptThread) + } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case StopExecutor => executor.stop() + context.reply(true) } + def reviveOffers() { val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) val tasks = scheduler.resourceOffers(offers).flatten @@ -87,9 +90,17 @@ private[spark] class LocalActor( } if (tasks.isEmpty && scheduler.activeTaskSets.nonEmpty) { // Try to reviveOffer after 1 second, because scheduler may wait for locality timeout - context.system.scheduler.scheduleOnce(1000 millis, self, ReviveOffers) + reviveThread.schedule(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + Option(self).foreach(_.send(ReviveOffers)) + } + }, 1000, TimeUnit.MILLISECONDS) } } + + override def onStop(): Unit = { + reviveThread.shutdownNow() + } } /** @@ -97,35 +108,37 @@ private[spark] class LocalActor( * master all run in the same JVM. It sits behind a TaskSchedulerImpl and handles launching tasks * on a single Executor (created by the LocalBackend) running locally. */ -private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: Int) - extends SchedulerBackend with ExecutorBackend { +private[spark] class LocalBackend( + conf: SparkConf, + scheduler: TaskSchedulerImpl, + val totalCores: Int) + extends SchedulerBackend with ExecutorBackend with Logging { private val appId = "local-" + System.currentTimeMillis - var localActor: ActorRef = null + var localEndpoint: RpcEndpointRef = null override def start() { - localActor = SparkEnv.get.actorSystem.actorOf( - Props(new LocalActor(scheduler, this, totalCores)), - "LocalBackendActor") + localEndpoint = SparkEnv.get.rpcEnv.setupEndpoint( + "LocalBackendEndpoint", new LocalEndpoint(SparkEnv.get.rpcEnv, scheduler, this, totalCores)) } override def stop() { - localActor ! StopExecutor + localEndpoint.sendWithReply(StopExecutor) } override def reviveOffers() { - localActor ! ReviveOffers + localEndpoint.send(ReviveOffers) } override def defaultParallelism(): Int = scheduler.conf.getInt("spark.default.parallelism", totalCores) override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - localActor ! KillTask(taskId, interruptThread) + localEndpoint.send(KillTask(taskId, interruptThread)) } override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { - localActor ! 
StatusUpdate(taskId, state, serializedData) + localEndpoint.send(StatusUpdate(taskId, state, serializedData)) } override def applicationId(): String = appId diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index f83bcaa5cc09e..579fb6624e692 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -49,10 +49,20 @@ class KryoSerializer(conf: SparkConf) with Logging with Serializable { - private val bufferSize = - (conf.getDouble("spark.kryoserializer.buffer.mb", 0.064) * 1024 * 1024).toInt + private val bufferSizeMb = conf.getDouble("spark.kryoserializer.buffer.mb", 0.064) + if (bufferSizeMb >= 2048) { + throw new IllegalArgumentException("spark.kryoserializer.buffer.mb must be less than " + + s"2048 mb, got: + $bufferSizeMb mb.") + } + private val bufferSize = (bufferSizeMb * 1024 * 1024).toInt + + val maxBufferSizeMb = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) + if (maxBufferSizeMb >= 2048) { + throw new IllegalArgumentException("spark.kryoserializer.buffer.max.mb must be less than " + + s"2048 mb, got: + $maxBufferSizeMb mb.") + } + private val maxBufferSize = maxBufferSizeMb * 1024 * 1024 - private val maxBufferSize = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) * 1024 * 1024 private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false) private val userRegistrator = conf.getOption("spark.kryo.registrator") diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index d0178dfde6935..5be3ed771e534 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -67,7 +67,7 @@ private[spark] trait ShuffleWriterGroup { // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getHashBasedShuffleBlockData(). private[spark] class FileShuffleBlockManager(conf: SparkConf) - extends ShuffleBlockManager with Logging { + extends ShuffleBlockResolver with Logging { private val transportConf = SparkTransportConf.fromSparkConf(conf) @@ -175,11 +175,6 @@ class FileShuffleBlockManager(conf: SparkConf) } } - override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = { - val segment = getBlockData(blockId) - Some(segment.nioByteBuffer()) - } - override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { if (consolidateShuffleFiles) { // Search all file groups associated with this shuffle. 
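The shuffle changes that follow narrow the old `ShuffleBlockManager` interface into `ShuffleBlockResolver`: the `Option`-returning `getBytes` is removed and `getBlockData` becomes the single entry point, with callers obtaining a `ByteBuffer` via `ManagedBuffer.nioByteBuffer()` when they need one. A toy sketch of that consolidation is below; the `Toy*` types are stand-ins for Spark's `ManagedBuffer` and `ShuffleBlockId`, not the actual resolver implementations in this patch.

```
import java.nio.ByteBuffer

// Stand-in for Spark's ManagedBuffer, illustration only.
trait ToyManagedBuffer { def nioByteBuffer(): ByteBuffer }

case class ToyShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int)

// Shape of the consolidated interface: one way to get block data, which
// throws if the block is unavailable (no Option-returning getBytes).
trait ToyShuffleBlockResolver {
  def getBlockData(blockId: ToyShuffleBlockId): ToyManagedBuffer
  def stop(): Unit
}

class InMemoryResolver(blocks: Map[ToyShuffleBlockId, Array[Byte]])
  extends ToyShuffleBlockResolver {

  override def getBlockData(blockId: ToyShuffleBlockId): ToyManagedBuffer = {
    val bytes = blocks.getOrElse(blockId,
      throw new IllegalStateException(s"Missing shuffle block $blockId"))
    new ToyManagedBuffer { def nioByteBuffer(): ByteBuffer = ByteBuffer.wrap(bytes) }
  }

  override def stop(): Unit = ()
}

object ResolverDemo extends App {
  val id = ToyShuffleBlockId(shuffleId = 0, mapId = 3, reduceId = 0)
  val resolver = new InMemoryResolver(Map(id -> Array[Byte](1, 2, 3)))
  // Where callers previously used getBytes, they now go through getBlockData:
  val buf = resolver.getBlockData(id).nioByteBuffer()
  println(buf.remaining()) // 3
}
```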
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala index 87fd161e06c85..a1741e2875c16 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala @@ -26,6 +26,9 @@ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.storage._ +import org.apache.spark.util.Utils + +import IndexShuffleBlockManager.NOOP_REDUCE_ID /** * Create and maintain the shuffle blocks' mapping between logic block and physical file location. @@ -39,25 +42,18 @@ import org.apache.spark.storage._ // Note: Changes to the format in this file should be kept in sync with // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getSortBasedShuffleBlockData(). private[spark] -class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { +class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockResolver { private lazy val blockManager = SparkEnv.get.blockManager private val transportConf = SparkTransportConf.fromSparkConf(conf) - /** - * Mapping to a single shuffleBlockId with reduce ID 0. - * */ - def consolidateId(shuffleId: Int, mapId: Int): ShuffleBlockId = { - ShuffleBlockId(shuffleId, mapId, 0) - } - def getDataFile(shuffleId: Int, mapId: Int): File = { - blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, 0)) + blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) } private def getIndexFile(shuffleId: Int, mapId: Int): File = { - blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, 0)) + blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) } /** @@ -83,24 +79,19 @@ class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = { val indexFile = getIndexFile(shuffleId, mapId) val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) - try { + Utils.tryWithSafeFinally { // We take in lengths of each block, need to convert it to offsets. var offset = 0L out.writeLong(offset) - for (length <- lengths) { offset += length out.writeLong(offset) } - } finally { + } { out.close() } } - override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = { - Some(getBlockData(blockId).nioByteBuffer()) - } - override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { // The block is actually going to be a range of a single map output file for this map, so // find out the consolidated file, then the offset within that from our index @@ -123,3 +114,11 @@ class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { override def stop(): Unit = {} } + +private[spark] object IndexShuffleBlockManager { + // No-op reduce ID used in interactions with disk store and BlockObjectWriter. + // The disk store currently expects puts to relate to a (map, reduce) pair, but in the sort + // shuffle outputs for several reduces are glommed into a single file. + // TODO: Avoid this entirely by having the DiskBlockObjectWriter not require a BlockId. 
+ val NOOP_REDUCE_ID = 0 +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala similarity index 68% rename from core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala rename to core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala index b521f0c7fc77e..4342b0d598b16 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala @@ -22,15 +22,19 @@ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.storage.ShuffleBlockId private[spark] -trait ShuffleBlockManager { +/** + * Implementers of this trait understand how to retrieve block data for a logical shuffle block + * identifier (i.e. map, reduce, and shuffle). Implementations may use files or file segments to + * encapsulate shuffle data. This is used by the BlockStore to abstract over different shuffle + * implementations when shuffle data is retrieved. + */ +trait ShuffleBlockResolver { type ShuffleId = Int /** - * Get shuffle block data managed by the local ShuffleBlockManager. - * @return Some(ByteBuffer) if block found, otherwise None. + * Retrieve the data for the specified block. If the data for that block is not available, + * throws an unspecified exception. */ - def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] - def getBlockData(blockId: ShuffleBlockId): ManagedBuffer def stop(): Unit diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index a44a8e1249256..978366d1a1d1b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -55,7 +55,10 @@ private[spark] trait ShuffleManager { */ def unregisterShuffle(shuffleId: Int): Boolean - def shuffleBlockManager: ShuffleBlockManager + /** + * Return a resolver capable of retrieving shuffle block data based on block coordinates. + */ + def shuffleBlockResolver: ShuffleBlockResolver /** Shut down this ShuffleManager. */ def stop(): Unit diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala index b934480cfb9be..f6e6fe5defe09 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala @@ -23,7 +23,7 @@ import org.apache.spark.scheduler.MapStatus * Obtained inside a map task to write out records to the shuffle system. 
*/ private[spark] trait ShuffleWriter[K, V] { - /** Write a bunch of records to this task's output */ + /** Write a sequence of records to this task's output */ def write(records: Iterator[_ <: Product2[K, V]]): Unit /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index 62e0629b34400..2a7df8dd5bd83 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -53,20 +53,20 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext) : ShuffleWriter[K, V] = { new HashShuffleWriter( - shuffleBlockManager, handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) + shuffleBlockResolver, handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) } /** Remove a shuffle's metadata from the ShuffleManager. */ override def unregisterShuffle(shuffleId: Int): Boolean = { - shuffleBlockManager.removeShuffle(shuffleId) + shuffleBlockResolver.removeShuffle(shuffleId) } - override def shuffleBlockManager: FileShuffleBlockManager = { + override def shuffleBlockResolver: FileShuffleBlockManager = { fileShuffleBlockManager } /** Shut down this ShuffleManager. */ override def stop(): Unit = { - shuffleBlockManager.stop() + shuffleBlockResolver.stop() } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index bda30a56d808e..0497036192154 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -58,7 +58,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, V, _]] shuffleMapNumber.putIfAbsent(baseShuffleHandle.shuffleId, baseShuffleHandle.numMaps) new SortShuffleWriter( - shuffleBlockManager, baseShuffleHandle, mapId, context) + shuffleBlockResolver, baseShuffleHandle, mapId, context) } /** Remove a shuffle's metadata from the ShuffleManager. */ @@ -66,18 +66,19 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager if (shuffleMapNumber.containsKey(shuffleId)) { val numMaps = shuffleMapNumber.remove(shuffleId) (0 until numMaps).map{ mapId => - shuffleBlockManager.removeDataByMap(shuffleId, mapId) + shuffleBlockResolver.removeDataByMap(shuffleId, mapId) } } true } - override def shuffleBlockManager: IndexShuffleBlockManager = { + override def shuffleBlockResolver: IndexShuffleBlockManager = { indexShuffleBlockManager } /** Shut down this ShuffleManager. 
*/ override def stop(): Unit = { - shuffleBlockManager.stop() + shuffleBlockResolver.stop() } } + diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 55ea0f17b156a..a066435df6fb0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -58,8 +58,7 @@ private[spark] class SortShuffleWriter[K, V, C]( // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't // care whether the keys get sorted in each partition; that will be done on the reduce side // if the operation being run is sortByKey. - sorter = new ExternalSorter[K, V, V]( - None, Some(dep.partitioner), None, dep.serializer) + sorter = new ExternalSorter[K, V, V](None, Some(dep.partitioner), None, dep.serializer) sorter.insertAll(records) } @@ -67,7 +66,7 @@ private[spark] class SortShuffleWriter[K, V, C]( // because it just opens a single file, so is typically too fast to measure accurately // (see SPARK-3570). val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId) - val blockId = shuffleBlockManager.consolidateId(dep.shuffleId, mapId) + val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID) val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile) shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths) @@ -100,3 +99,4 @@ private[spark] class SortShuffleWriter[K, V, C]( } } } + diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 1dff09a75d038..1aa0ef18de118 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -26,7 +26,6 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.util.Random -import akka.actor.{ActorSystem, Props} import sun.nio.ch.DirectBuffer import org.apache.spark._ @@ -37,6 +36,7 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo +import org.apache.spark.rpc.RpcEnv import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.shuffle.hash.HashShuffleManager @@ -64,7 +64,7 @@ private[spark] class BlockResult( */ private[spark] class BlockManager( executorId: String, - actorSystem: ActorSystem, + rpcEnv: RpcEnv, val master: BlockManagerMaster, defaultSerializer: Serializer, maxMemory: Long, @@ -136,9 +136,9 @@ private[spark] class BlockManager( // Whether to compress shuffle output temporarily spilled to disk private val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) - private val slaveActor = actorSystem.actorOf( - Props(new BlockManagerSlaveActor(this, mapOutputTracker)), - name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) + private val slaveEndpoint = rpcEnv.setupEndpoint( + "BlockManagerEndpoint" + BlockManager.ID_GENERATOR.next, + new BlockManagerSlaveEndpoint(rpcEnv, this, mapOutputTracker)) // Pending re-registration action being executed asynchronously or null if none is pending. 
// Accesses should synchronize on asyncReregisterLock. @@ -167,7 +167,7 @@ private[spark] class BlockManager( */ def this( execId: String, - actorSystem: ActorSystem, + rpcEnv: RpcEnv, master: BlockManagerMaster, serializer: Serializer, conf: SparkConf, @@ -176,7 +176,7 @@ private[spark] class BlockManager( blockTransferService: BlockTransferService, securityManager: SecurityManager, numUsableCores: Int) = { - this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), + this(execId, rpcEnv, master, serializer, BlockManager.getMaxMemory(conf), conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) } @@ -186,7 +186,7 @@ private[spark] class BlockManager( * where it is only learned after registration with the TaskScheduler). * * This method initializes the BlockTransferService and ShuffleClient, registers with the - * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle + * BlockManagerMaster, starts the BlockManagerWorker endpoint, and registers with a local shuffle * service if configured. */ def initialize(appId: String): Unit = { @@ -202,7 +202,7 @@ private[spark] class BlockManager( blockManagerId } - master.registerBlockManager(blockManagerId, maxMemory, slaveActor) + master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) // Register Executors' configuration with the local shuffle service, if one should exist. if (externalShuffleServiceEnabled && !blockManagerId.isDriver) { @@ -265,7 +265,7 @@ private[spark] class BlockManager( def reregister(): Unit = { // TODO: We might need to rate limit re-registering. logInfo("BlockManager re-registering with master") - master.registerBlockManager(blockManagerId, maxMemory, slaveActor) + master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) reportAllBlocks() } @@ -301,7 +301,7 @@ private[spark] class BlockManager( */ override def getBlockData(blockId: BlockId): ManagedBuffer = { if (blockId.isShuffle) { - shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) + shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) } else { val blockBytesOpt = doGetLocal(blockId, asBlockResult = false) .asInstanceOf[Option[ByteBuffer]] @@ -439,14 +439,10 @@ private[spark] class BlockManager( // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work if (blockId.isShuffle) { - val shuffleBlockManager = shuffleManager.shuffleBlockManager - shuffleBlockManager.getBytes(blockId.asInstanceOf[ShuffleBlockId]) match { - case Some(bytes) => - Some(bytes) - case None => - throw new BlockException( - blockId, s"Block $blockId not found on disk, though it should be") - } + val shuffleBlockManager = shuffleManager.shuffleBlockResolver + // TODO: This should gracefully handle case where local block is not available. Currently + // downstream code will throw an exception. 
+ Option(shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer()) } else { doGetLocal(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] } @@ -1219,7 +1215,7 @@ private[spark] class BlockManager( shuffleClient.close() } diskBlockManager.stop() - actorSystem.stop(slaveActor) + rpcEnv.stop(slaveEndpoint) blockInfo.clear() memoryStore.clear() diskStore.clear() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index a6f1ebf325a7c..69ac37511e730 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -60,7 +60,10 @@ class BlockManagerId private ( def port: Int = port_ - def isDriver: Boolean = { executorId == SparkContext.DRIVER_IDENTIFIER } + def isDriver: Boolean = { + executorId == SparkContext.DRIVER_IDENTIFIER || + executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER + } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { out.writeUTF(executorId_) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 061964826f08b..ceacf043029f3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -20,35 +20,31 @@ package org.apache.spark.storage import scala.concurrent.{Await, Future} import scala.concurrent.ExecutionContext.Implicits.global -import akka.actor._ - +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.AkkaUtils private[spark] class BlockManagerMaster( - var driverActor: ActorRef, + var driverEndpoint: RpcEndpointRef, conf: SparkConf, isDriver: Boolean) extends Logging { - private val AKKA_RETRY_ATTEMPTS: Int = AkkaUtils.numRetries(conf) - private val AKKA_RETRY_INTERVAL_MS: Int = AkkaUtils.retryWaitMs(conf) - - val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster" val timeout = AkkaUtils.askTimeout(conf) - /** Remove a dead executor from the driver actor. This is only called on the driver side. */ + /** Remove a dead executor from the driver endpoint. This is only called on the driver side. */ def removeExecutor(execId: String) { tell(RemoveExecutor(execId)) logInfo("Removed " + execId + " successfully in removeExecutor") } /** Register the BlockManager's id with the driver. 
*/ - def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + def registerBlockManager( + blockManagerId: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef): Unit = { logInfo("Trying to register BlockManager") - tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor)) + tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint)) logInfo("Registered BlockManager") } @@ -59,7 +55,7 @@ class BlockManagerMaster( memSize: Long, diskSize: Long, tachyonSize: Long): Boolean = { - val res = askDriverWithReply[Boolean]( + val res = driverEndpoint.askWithReply[Boolean]( UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize, tachyonSize)) logDebug(s"Updated info of block $blockId") res @@ -67,12 +63,12 @@ class BlockManagerMaster( /** Get locations of the blockId from the driver */ def getLocations(blockId: BlockId): Seq[BlockManagerId] = { - askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId)) + driverEndpoint.askWithReply[Seq[BlockManagerId]](GetLocations(blockId)) } /** Get locations of multiple blockIds from the driver */ def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { - askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) + driverEndpoint.askWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } /** @@ -85,11 +81,11 @@ class BlockManagerMaster( /** Get ids of other nodes in the cluster from the driver */ def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = { - askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) + driverEndpoint.askWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) } - def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { - askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId)) + def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { + driverEndpoint.askWithReply[Option[(String, Int)]](GetRpcHostPortForExecutor(executorId)) } /** @@ -97,12 +93,12 @@ class BlockManagerMaster( * blocks that the driver knows about. */ def removeBlock(blockId: BlockId) { - askDriverWithReply(RemoveBlock(blockId)) + driverEndpoint.askWithReply[Boolean](RemoveBlock(blockId)) } /** Remove all blocks belonging to the given RDD. */ def removeRdd(rddId: Int, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) + val future = driverEndpoint.askWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) future.onFailure { case e: Exception => logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}") @@ -114,7 +110,7 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given shuffle. */ def removeShuffle(shuffleId: Int, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) + val future = driverEndpoint.askWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) future.onFailure { case e: Exception => logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}") @@ -126,7 +122,7 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given broadcast. 
*/ def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Int]]]( + val future = driverEndpoint.askWithReply[Future[Seq[Int]]]( RemoveBroadcast(broadcastId, removeFromMaster)) future.onFailure { case e: Exception => @@ -145,11 +141,11 @@ class BlockManagerMaster( * amount of remaining memory. */ def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { - askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) + driverEndpoint.askWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) } def getStorageStatus: Array[StorageStatus] = { - askDriverWithReply[Array[StorageStatus]](GetStorageStatus) + driverEndpoint.askWithReply[Array[StorageStatus]](GetStorageStatus) } /** @@ -165,11 +161,12 @@ class BlockManagerMaster( askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = { val msg = GetBlockStatus(blockId, askSlaves) /* - * To avoid potential deadlocks, the use of Futures is necessary, because the master actor + * To avoid potential deadlocks, the use of Futures is necessary, because the master endpoint * should not block on waiting for a block manager, which can in turn be waiting for the - * master actor for a response to a prior message. + * master endpoint for a response to a prior message. */ - val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) + val response = driverEndpoint. + askWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) val (blockManagerIds, futures) = response.unzip val result = Await.result(Future.sequence(futures), timeout) if (result == null) { @@ -193,33 +190,28 @@ class BlockManagerMaster( filter: BlockId => Boolean, askSlaves: Boolean): Seq[BlockId] = { val msg = GetMatchingBlockIds(filter, askSlaves) - val future = askDriverWithReply[Future[Seq[BlockId]]](msg) + val future = driverEndpoint.askWithReply[Future[Seq[BlockId]]](msg) Await.result(future, timeout) } - /** Stop the driver actor, called only on the Spark driver node */ + /** Stop the driver endpoint, called only on the Spark driver node */ def stop() { - if (driverActor != null && isDriver) { + if (driverEndpoint != null && isDriver) { tell(StopBlockManagerMaster) - driverActor = null + driverEndpoint = null logInfo("BlockManagerMaster stopped") } } - /** Send a one-way message to the master actor, to which we expect it to reply with true. */ + /** Send a one-way message to the master endpoint, to which we expect it to reply with true. */ private def tell(message: Any) { - if (!askDriverWithReply[Boolean](message)) { - throw new SparkException("BlockManagerMasterActor returned false, expected true.") + if (!driverEndpoint.askWithReply[Boolean](message)) { + throw new SparkException("BlockManagerMasterEndpoint returned false, expected true.") } } - /** - * Send a message to the driver actor and get its result within a default timeout, or - * throw a SparkException if this fails. 
- */ - private def askDriverWithReply[T](message: Any): T = { - AkkaUtils.askWithReply(message, driverActor, AKKA_RETRY_ATTEMPTS, AKKA_RETRY_INTERVAL_MS, - timeout) - } +} +private[spark] object BlockManagerMaster { + val DRIVER_ENDPOINT_NAME = "BlockManagerMaster" } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala similarity index 83% rename from core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala rename to core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 5b5328016124e..28c73a7d543ff 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -21,25 +21,26 @@ import java.util.{HashMap => JHashMap} import scala.collection.mutable import scala.collection.JavaConversions._ -import scala.concurrent.Future -import scala.concurrent.duration._ +import scala.concurrent.{ExecutionContext, Future} -import akka.actor.{Actor, ActorRef} -import akka.pattern.ask - -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, ThreadSafeRpcEndpoint} +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} +import org.apache.spark.util.Utils /** - * BlockManagerMasterActor is an actor on the master node to track statuses of - * all slaves' block managers. + * BlockManagerMasterEndpoint is an [[ThreadSafeRpcEndpoint]] on the master node to track statuses + * of all slaves' block managers. */ private[spark] -class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus: LiveListenerBus) - extends Actor with ActorLogReceive with Logging { +class BlockManagerMasterEndpoint( + override val rpcEnv: RpcEnv, + val isLocal: Boolean, + conf: SparkConf, + listenerBus: LiveListenerBus) + extends ThreadSafeRpcEndpoint with Logging { // Mapping from block manager id to the block manager's information. private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo] @@ -50,68 +51,67 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Mapping from block id to the set of block managers that have the block. private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] - private val akkaTimeout = AkkaUtils.askTimeout(conf) + private val askThreadPool = Utils.newDaemonCachedThreadPool("block-manager-ask-thread-pool") + private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool) - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => - register(blockManagerId, maxMemSize, slaveActor) - sender ! true + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) => + register(blockManagerId, maxMemSize, slaveEndpoint) + context.reply(true) case UpdateBlockInfo( blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) => - sender ! 
updateBlockInfo( - blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) + context.reply(updateBlockInfo( + blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize)) case GetLocations(blockId) => - sender ! getLocations(blockId) + context.reply(getLocations(blockId)) case GetLocationsMultipleBlockIds(blockIds) => - sender ! getLocationsMultipleBlockIds(blockIds) + context.reply(getLocationsMultipleBlockIds(blockIds)) case GetPeers(blockManagerId) => - sender ! getPeers(blockManagerId) + context.reply(getPeers(blockManagerId)) - case GetActorSystemHostPortForExecutor(executorId) => - sender ! getActorSystemHostPortForExecutor(executorId) + case GetRpcHostPortForExecutor(executorId) => + context.reply(getRpcHostPortForExecutor(executorId)) case GetMemoryStatus => - sender ! memoryStatus + context.reply(memoryStatus) case GetStorageStatus => - sender ! storageStatus + context.reply(storageStatus) case GetBlockStatus(blockId, askSlaves) => - sender ! blockStatus(blockId, askSlaves) + context.reply(blockStatus(blockId, askSlaves)) case GetMatchingBlockIds(filter, askSlaves) => - sender ! getMatchingBlockIds(filter, askSlaves) + context.reply(getMatchingBlockIds(filter, askSlaves)) case RemoveRdd(rddId) => - sender ! removeRdd(rddId) + context.reply(removeRdd(rddId)) case RemoveShuffle(shuffleId) => - sender ! removeShuffle(shuffleId) + context.reply(removeShuffle(shuffleId)) case RemoveBroadcast(broadcastId, removeFromDriver) => - sender ! removeBroadcast(broadcastId, removeFromDriver) + context.reply(removeBroadcast(broadcastId, removeFromDriver)) case RemoveBlock(blockId) => removeBlockFromWorkers(blockId) - sender ! true + context.reply(true) case RemoveExecutor(execId) => removeExecutor(execId) - sender ! true + context.reply(true) case StopBlockManagerMaster => - sender ! true - context.stop(self) + context.reply(true) + stop() case BlockManagerHeartbeat(blockManagerId) => - sender ! heartbeatReceived(blockManagerId) + context.reply(heartbeatReceived(blockManagerId)) - case other => - logWarning("Got unknown message: " + other) } private def removeRdd(rddId: Int): Future[Seq[Int]] = { @@ -129,22 +129,20 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Ask the slaves to remove the RDD, and put the result in a sequence of Futures. // The dispatcher is used as an implicit argument into the Future sequence construction. - import context.dispatcher val removeMsg = RemoveRdd(rddId) Future.sequence( blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] + bm.slaveEndpoint.sendWithReply[Int](removeMsg) }.toSeq ) } private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = { - // Nothing to do in the BlockManagerMasterActor data structures - import context.dispatcher + // Nothing to do in the BlockManagerMasterEndpoint data structures val removeMsg = RemoveShuffle(shuffleId) Future.sequence( blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Boolean] + bm.slaveEndpoint.sendWithReply[Boolean](removeMsg) }.toSeq ) } @@ -155,14 +153,13 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus * from the executors, but not from the driver. 
*/ private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = { - import context.dispatcher val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver) val requiredBlockManagers = blockManagerInfo.values.filter { info => removeFromDriver || !info.blockManagerId.isDriver } Future.sequence( requiredBlockManagers.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] + bm.slaveEndpoint.sendWithReply[Int](removeMsg) }.toSeq ) } @@ -217,7 +214,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Remove the block from the slave's BlockManager. // Doesn't actually wait for a confirmation and the message might get lost. // If message loss becomes frequent, we should add retry logic here. - blockManager.get.slaveActor.ask(RemoveBlock(blockId))(akkaTimeout) + blockManager.get.slaveEndpoint.sendWithReply[Boolean](RemoveBlock(blockId)) } } } @@ -247,17 +244,16 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private def blockStatus( blockId: BlockId, askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = { - import context.dispatcher val getBlockStatus = GetBlockStatus(blockId) /* - * Rather than blocking on the block status query, master actor should simply return + * Rather than blocking on the block status query, master endpoint should simply return * Futures to avoid potential deadlocks. This can arise if there exists a block manager - * that is also waiting for this master actor's response to a previous message. + * that is also waiting for this master endpoint's response to a previous message. */ blockManagerInfo.values.map { info => val blockStatusFuture = if (askSlaves) { - info.slaveActor.ask(getBlockStatus)(akkaTimeout).mapTo[Option[BlockStatus]] + info.slaveEndpoint.sendWithReply[Option[BlockStatus]](getBlockStatus) } else { Future { info.getStatus(blockId) } } @@ -276,13 +272,12 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private def getMatchingBlockIds( filter: BlockId => Boolean, askSlaves: Boolean): Future[Seq[BlockId]] = { - import context.dispatcher val getMatchingBlockIds = GetMatchingBlockIds(filter) Future.sequence( blockManagerInfo.values.map { info => val future = if (askSlaves) { - info.slaveActor.ask(getMatchingBlockIds)(akkaTimeout).mapTo[Seq[BlockId]] + info.slaveEndpoint.sendWithReply[Seq[BlockId]](getMatchingBlockIds) } else { Future { info.blocks.keys.filter(filter).toSeq } } @@ -291,7 +286,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus ).map(_.flatten.toSeq) } - private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + private def register(id: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef) { val time = System.currentTimeMillis() if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { @@ -308,7 +303,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus blockManagerIdByExecutor(id.executorId) = id blockManagerInfo(id) = new BlockManagerInfo( - id, System.currentTimeMillis(), maxMemSize, slaveActor) + id, System.currentTimeMillis(), maxMemSize, slaveEndpoint) } listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize)) } @@ -379,19 +374,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus } /** - * Returns the hostname and port of an executor's actor system, based on the Akka address of its - * 
BlockManagerSlaveActor. + * Returns the hostname and port of an executor, based on the [[RpcEnv]] address of its + * [[BlockManagerSlaveEndpoint]]. */ - private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { + private def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { for ( blockManagerId <- blockManagerIdByExecutor.get(executorId); - info <- blockManagerInfo.get(blockManagerId); - host <- info.slaveActor.path.address.host; - port <- info.slaveActor.path.address.port + info <- blockManagerInfo.get(blockManagerId) ) yield { - (host, port) + (info.slaveEndpoint.address.host, info.slaveEndpoint.address.port) } } + + override def onStop(): Unit = { + askThreadPool.shutdownNow() + } } @DeveloperApi @@ -412,7 +409,7 @@ private[spark] class BlockManagerInfo( val blockManagerId: BlockManagerId, timeMs: Long, val maxMem: Long, - val slaveActor: ActorRef) + val slaveEndpoint: RpcEndpointRef) extends Logging { private var _lastSeenMs: Long = timeMs diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 48247453edef0..f89d8d7493f7c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -19,8 +19,7 @@ package org.apache.spark.storage import java.io.{Externalizable, ObjectInput, ObjectOutput} -import akka.actor.ActorRef - +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] object BlockManagerMessages { @@ -52,7 +51,7 @@ private[spark] object BlockManagerMessages { case class RegisterBlockManager( blockManagerId: BlockManagerId, maxMemSize: Long, - sender: ActorRef) + sender: RpcEndpointRef) extends ToBlockManagerMaster case class UpdateBlockInfo( @@ -92,7 +91,7 @@ private[spark] object BlockManagerMessages { case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster + case class GetRpcHostPortForExecutor(executorId: String) extends ToBlockManagerMaster case class RemoveExecutor(execId: String) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala similarity index 61% rename from core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala rename to core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index 52fb896c4e21f..8980fa8eb70e2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -17,41 +17,43 @@ package org.apache.spark.storage -import scala.concurrent.Future - -import akka.actor.{ActorRef, Actor} +import scala.concurrent.{ExecutionContext, Future} +import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} +import org.apache.spark.util.Utils import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.ActorLogReceive /** - * An actor to take commands from the master to execute options. For example, + * An RpcEndpoint to take commands from the master to execute options. For example, * this is used to remove blocks from the slave's BlockManager. 
*/ private[storage] -class BlockManagerSlaveActor( +class BlockManagerSlaveEndpoint( + override val rpcEnv: RpcEnv, blockManager: BlockManager, mapOutputTracker: MapOutputTracker) - extends Actor with ActorLogReceive with Logging { + extends RpcEndpoint with Logging { - import context.dispatcher + private val asyncThreadPool = + Utils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool") + private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool) // Operations that involve removing blocks may be slow and should be done asynchronously - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RemoveBlock(blockId) => - doAsync[Boolean]("removing block " + blockId, sender) { + doAsync[Boolean]("removing block " + blockId, context) { blockManager.removeBlock(blockId) true } case RemoveRdd(rddId) => - doAsync[Int]("removing RDD " + rddId, sender) { + doAsync[Int]("removing RDD " + rddId, context) { blockManager.removeRdd(rddId) } case RemoveShuffle(shuffleId) => - doAsync[Boolean]("removing shuffle " + shuffleId, sender) { + doAsync[Boolean]("removing shuffle " + shuffleId, context) { if (mapOutputTracker != null) { mapOutputTracker.unregisterShuffle(shuffleId) } @@ -59,30 +61,34 @@ class BlockManagerSlaveActor( } case RemoveBroadcast(broadcastId, _) => - doAsync[Int]("removing broadcast " + broadcastId, sender) { + doAsync[Int]("removing broadcast " + broadcastId, context) { blockManager.removeBroadcast(broadcastId, tellMaster = true) } case GetBlockStatus(blockId, _) => - sender ! blockManager.getStatus(blockId) + context.reply(blockManager.getStatus(blockId)) case GetMatchingBlockIds(filter, _) => - sender ! blockManager.getMatchingBlockIds(filter) + context.reply(blockManager.getMatchingBlockIds(filter)) } - private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) { + private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T) { val future = Future { logDebug(actionMessage) body } future.onSuccess { case response => logDebug("Done " + actionMessage + ", response is " + response) - responseActor ! response - logDebug("Sent response: " + response + " to " + responseActor) + context.reply(response) + logDebug("Sent response: " + response + " to " + context.sender) } future.onFailure { case t: Throwable => logError("Error in " + actionMessage, t) - responseActor ! null.asInstanceOf[T] + context.sendFailure(t) } } + + override def onStop(): Unit = { + asyncThreadPool.shutdownNow() + } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index f703e50b6b0ac..0dfc91dfaff85 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -23,6 +23,7 @@ import java.nio.channels.FileChannel import org.apache.spark.Logging import org.apache.spark.serializer.{SerializationStream, Serializer} import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.util.Utils /** * An interface for writing JVM objects to some underlying storage. 
This interface allows @@ -140,14 +141,17 @@ private[spark] class DiskBlockObjectWriter( override def close() { if (initialized) { - if (syncWrites) { - // Force outstanding writes to disk and track how long it takes - objOut.flush() - callWithTiming { - fos.getFD.sync() + Utils.tryWithSafeFinally { + if (syncWrites) { + // Force outstanding writes to disk and track how long it takes + objOut.flush() + callWithTiming { + fos.getFD.sync() + } } + } { + objOut.close() } - objOut.close() channel = null bs = null diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 61ef5ff168791..4b232ae7d3180 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -46,10 +46,13 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc val startTime = System.currentTimeMillis val file = diskManager.getFile(blockId) val channel = new FileOutputStream(file).getChannel - while (bytes.remaining > 0) { - channel.write(bytes) + Utils.tryWithSafeFinally { + while (bytes.remaining > 0) { + channel.write(bytes) + } + } { + channel.close() } - channel.close() val finishTime = System.currentTimeMillis logDebug("Block %s stored as %s file on disk in %d ms".format( file.getName, Utils.bytesToString(bytes.limit), finishTime - startTime)) @@ -75,9 +78,9 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc val file = diskManager.getFile(blockId) val outputStream = new FileOutputStream(file) try { - try { + Utils.tryWithSafeFinally { blockManager.dataSerializeStream(blockId, outputStream, values) - } finally { + } { // Close outputStream here because it should be closed before file is deleted. 
outputStream.close() } @@ -106,8 +109,7 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc private def getBytes(file: File, offset: Long, length: Long): Option[ByteBuffer] = { val channel = new RandomAccessFile(file, "r").getChannel - - try { + Utils.tryWithSafeFinally { // For small files, directly read rather than memory map if (length < minMemoryMapBytes) { val buf = ByteBuffer.allocate(length.toInt) @@ -123,7 +125,7 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc } else { Some(channel.map(MapMode.READ_ONLY, offset, length)) } - } finally { + } { channel.close() } } diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 0186eb30a1905..034525b56f59c 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -52,6 +52,6 @@ class RDDInfo( private[spark] object RDDInfo { def fromRdd(rdd: RDD[_]): RDDInfo = { val rddName = Option(rdd.name).getOrElse(rdd.id.toString) - new RDDInfo(rdd.id, rddName, rdd.partitions.size, rdd.getStorageLevel) + new RDDInfo(rdd.id, rddName, rdd.partitions.length, rdd.getStorageLevel) } } diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala index 67f572e79314d..77c0bc8b5360a 100644 --- a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala +++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala @@ -65,7 +65,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { val stageIds = sc.statusTracker.getActiveStageIds() val stages = stageIds.map(sc.statusTracker.getStageInfo).flatten.filter(_.numTasks() > 1) .filter(now - _.submissionTime() > FIRST_DELAY).sortBy(_.stageId()) - if (stages.size > 0) { + if (stages.length > 0) { show(now, stages.take(3)) // display at most 3 stages in same time } } @@ -81,7 +81,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { val total = s.numTasks() val header = s"[Stage ${s.stageId()}:" val tailer = s"(${s.numCompletedTasks()} + ${s.numActiveTasks()}) / $total]" - val w = width - header.size - tailer.size + val w = width - header.length - tailer.length val bar = if (w > 0) { val percent = w * s.numCompletedTasks() / total (0 until w).map { i => diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index ea548f23120d9..f9860d1a5ce76 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -48,7 +48,7 @@ private[spark] abstract class WebUI( protected val handlers = ArrayBuffer[ServletContextHandler]() protected val pageToHandlers = new HashMap[WebUIPage, ArrayBuffer[ServletContextHandler]] protected var serverInfo: Option[ServerInfo] = None - protected val localHostName = Utils.localHostName() + protected val localHostName = Utils.localHostNameForURI() protected val publicHostName = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(localHostName) private val className = Utils.getFormattedClassName(this) diff --git a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala index 332d0cbb2dc0c..81a7cbde01ce5 100644 --- a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala +++ 
b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala @@ -43,7 +43,13 @@ private[spark] trait ActorLogReceive { private val _receiveWithLogging = receiveWithLogging - override def isDefinedAt(o: Any): Boolean = _receiveWithLogging.isDefinedAt(o) + override def isDefinedAt(o: Any): Boolean = { + val handled = _receiveWithLogging.isDefinedAt(o) + if (!handled) { + log.debug(s"Received unexpected actor system event: $o") + } + handled + } override def apply(o: Any): Unit = { if (log.isDebugEnabled) { diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 48a6ede05e17b..8e8cc7cc6389e 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -20,7 +20,6 @@ package org.apache.spark.util import scala.collection.JavaConversions.mapAsJavaMap import scala.concurrent.Await import scala.concurrent.duration.{Duration, FiniteDuration} -import scala.util.Try import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem} import akka.pattern.ask @@ -66,7 +65,8 @@ private[spark] object AkkaUtils extends Logging { val akkaThreads = conf.getInt("spark.akka.threads", 4) val akkaBatchSize = conf.getInt("spark.akka.batchSize", 15) - val akkaTimeout = conf.getInt("spark.akka.timeout", conf.getInt("spark.network.timeout", 120)) + val akkaTimeoutS = conf.getTimeAsSeconds("spark.akka.timeout", + conf.get("spark.network.timeout", "120s")) val akkaFrameSize = maxFrameSizeBytes(conf) val akkaLogLifecycleEvents = conf.getBoolean("spark.akka.logLifecycleEvents", false) val lifecycleEvents = if (akkaLogLifecycleEvents) "on" else "off" @@ -78,8 +78,8 @@ private[spark] object AkkaUtils extends Logging { val logAkkaConfig = if (conf.getBoolean("spark.akka.logAkkaConfig", false)) "on" else "off" - val akkaHeartBeatPauses = conf.getInt("spark.akka.heartbeat.pauses", 6000) - val akkaHeartBeatInterval = conf.getInt("spark.akka.heartbeat.interval", 1000) + val akkaHeartBeatPausesS = conf.getTimeAsSeconds("spark.akka.heartbeat.pauses", "6000s") + val akkaHeartBeatIntervalS = conf.getTimeAsSeconds("spark.akka.heartbeat.interval", "1000s") val secretKey = securityManager.getSecretKey() val isAuthOn = securityManager.isAuthenticationEnabled() @@ -102,14 +102,14 @@ private[spark] object AkkaUtils extends Logging { |akka.jvm-exit-on-fatal-error = off |akka.remote.require-cookie = "$requireCookie" |akka.remote.secure-cookie = "$secureCookie" - |akka.remote.transport-failure-detector.heartbeat-interval = $akkaHeartBeatInterval s - |akka.remote.transport-failure-detector.acceptable-heartbeat-pause = $akkaHeartBeatPauses s + |akka.remote.transport-failure-detector.heartbeat-interval = $akkaHeartBeatIntervalS s + |akka.remote.transport-failure-detector.acceptable-heartbeat-pause = $akkaHeartBeatPausesS s |akka.actor.provider = "akka.remote.RemoteActorRefProvider" |akka.remote.netty.tcp.transport-class = "akka.remote.transport.netty.NettyTransport" |akka.remote.netty.tcp.hostname = "$host" |akka.remote.netty.tcp.port = $port |akka.remote.netty.tcp.tcp-nodelay = on - |akka.remote.netty.tcp.connection-timeout = $akkaTimeout s + |akka.remote.netty.tcp.connection-timeout = $akkaTimeoutS s |akka.remote.netty.tcp.maximum-frame-size = ${akkaFrameSize}B |akka.remote.netty.tcp.execution-pool-size = $akkaThreads |akka.actor.default-dispatcher.throughput = $akkaBatchSize @@ -179,7 +179,7 @@ private[spark] object AkkaUtils extends Logging { message: Any, actor: ActorRef, 
maxAttempts: Int, - retryInterval: Int, + retryInterval: Long, timeout: FiniteDuration): T = { // TODO: Consider removing multiple attempts if (actor == null) { diff --git a/core/src/main/scala/org/apache/spark/util/EventLoop.scala b/core/src/main/scala/org/apache/spark/util/EventLoop.scala index b0ed908b84424..e9b2b8d24b476 100644 --- a/core/src/main/scala/org/apache/spark/util/EventLoop.scala +++ b/core/src/main/scala/org/apache/spark/util/EventLoop.scala @@ -76,9 +76,21 @@ private[spark] abstract class EventLoop[E](name: String) extends Logging { def stop(): Unit = { if (stopped.compareAndSet(false, true)) { eventThread.interrupt() - eventThread.join() - // Call onStop after the event thread exits to make sure onReceive happens before onStop - onStop() + var onStopCalled = false + try { + eventThread.join() + // Call onStop after the event thread exits to make sure onReceive happens before onStop + onStopCalled = true + onStop() + } catch { + case ie: InterruptedException => + Thread.currentThread().interrupt() + if (!onStopCalled) { + // ie is thrown from `eventThread.join()`. Otherwise, we should not call `onStop` since + // it's already called. + onStop() + } + } } else { // Keep quiet to allow calling `stop` multiple times. } diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index d60b8b9a31a9b..a725767d08cc2 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -19,9 +19,12 @@ package org.apache.spark.util import java.util.concurrent.CopyOnWriteArrayList +import scala.collection.JavaConversions._ +import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark.Logging +import org.apache.spark.scheduler.SparkListener /** * An event bus which posts events to its listeners. @@ -64,4 +67,9 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { */ def onPostEvent(listener: L, event: E): Unit + private[spark] def findListenersByClass[T <: L : ClassTag](): Seq[T] = { + val c = implicitly[ClassTag[T]].runtimeClass + listeners.filter(_.getClass == c).map(_.asInstanceOf[T]).toSeq + } + } diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 375ed430bde45..2bbfc988a99a8 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -76,7 +76,7 @@ private[spark] object MetadataCleanerType extends Enumeration { // initialization of StreamingContext. It's okay for users trying to configure stuff themselves. private[spark] object MetadataCleaner { def getDelaySeconds(conf: SparkConf): Int = { - conf.getInt("spark.cleaner.ttl", -1) + conf.getTimeAsSeconds("spark.cleaner.ttl", "-1").toInt } def getDelaySeconds( diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala new file mode 100644 index 0000000000000..6665b17c3d5df --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import org.apache.spark.{SparkEnv, SparkConf} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv} + +object RpcUtils { + + /** + * Retrieve a [[RpcEndpointRef]] which is located in the driver via its name. + */ + def makeDriverRef(name: String, conf: SparkConf, rpcEnv: RpcEnv): RpcEndpointRef = { + val driverActorSystemName = SparkEnv.driverActorSystemName + val driverHost: String = conf.get("spark.driver.host", "localhost") + val driverPort: Int = conf.getInt("spark.driver.port", 7077) + Utils.checkHost(driverHost, "Expected hostname") + rpcEnv.setupEndpointRef(driverActorSystemName, RpcAddress(driverHost, driverPort), name) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 0b5a914e7dbbf..1029b0f9fce1e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -22,7 +22,7 @@ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer import java.util.{Properties, Locale, Random, UUID} -import java.util.concurrent.{ThreadFactory, ConcurrentHashMap, Executors, ThreadPoolExecutor} +import java.util.concurrent._ import javax.net.ssl.HttpsURLConnection import scala.collection.JavaConversions._ @@ -34,6 +34,7 @@ import scala.util.Try import scala.util.control.{ControlThrowable, NonFatal} import com.google.common.io.{ByteStreams, Files} +import com.google.common.net.InetAddresses import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration @@ -46,6 +47,7 @@ import tachyon.client.{TachyonFS, TachyonFile} import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} /** CallSite represents a place in user code. It can have a short and a long form. */ @@ -313,7 +315,7 @@ private[spark] object Utils extends Logging { transferToEnabled: Boolean = false): Long = { var count = 0L - try { + tryWithSafeFinally { if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream] && transferToEnabled) { // When both streams are File stream, use transferTo to improve copy performance. 
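The `copyStream` hunk above swaps a plain `try`/`finally` for the new `Utils.tryWithSafeFinally` helper (added further down in this patch), so a failure while closing the streams can no longer mask the original exception thrown during the copy. The sketch below restates that contract with a simplified standalone helper of the same shape; the object name and the `ByteArrayOutputStream` stand-in are assumptions for the example, not part of the patch.

```
// Minimal sketch of the "don't let close() suppress the real error" pattern.
object SafeFinallySketch {
  def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = {
    var original: Throwable = null
    try {
      block
    } catch {
      case t: Throwable =>
        original = t
        throw t
    } finally {
      try {
        finallyBlock
      } catch {
        // If the body already failed, keep its exception and only log the close failure.
        case t: Throwable if original != null =>
          System.err.println(s"Suppressing exception in finally: ${t.getMessage}")
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val out = new java.io.ByteArrayOutputStream()  // stands in for the file streams in copyStream
    val copied = tryWithSafeFinally {
      out.write(Array[Byte](1, 2, 3))
      3L
    } {
      out.close()
    }
    println(s"copied $copied bytes")  // prints: copied 3 bytes
  }
}
```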
@@ -353,7 +355,7 @@ private[spark] object Utils extends Logging { } } count - } finally { + } { if (closeStreams) { try { in.close() @@ -611,9 +613,10 @@ private[spark] object Utils extends Logging { } Utils.setupSecureURLConnection(uc, securityMgr) - val timeout = conf.getInt("spark.files.fetchTimeout", 60) * 1000 - uc.setConnectTimeout(timeout) - uc.setReadTimeout(timeout) + val timeoutMs = + conf.getTimeAsSeconds("spark.files.fetchTimeout", "60s").toInt * 1000 + uc.setConnectTimeout(timeoutMs) + uc.setReadTimeout(timeoutMs) uc.connect() val in = uc.getInputStream() downloadFile(url, in, targetFile, fileOverwrite) @@ -789,13 +792,12 @@ private[spark] object Utils extends Logging { * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4). * Note, this is typically not used from within core spark. */ - lazy val localIpAddress: String = findLocalIpAddress() - lazy val localIpAddressHostname: String = getAddressHostName(localIpAddress) + private lazy val localIpAddress: InetAddress = findLocalInetAddress() - private def findLocalIpAddress(): String = { + private def findLocalInetAddress(): InetAddress = { val defaultIpOverride = System.getenv("SPARK_LOCAL_IP") if (defaultIpOverride != null) { - defaultIpOverride + InetAddress.getByName(defaultIpOverride) } else { val address = InetAddress.getLocalHost if (address.isLoopbackAddress) { @@ -806,15 +808,20 @@ private[spark] object Utils extends Logging { // It's more proper to pick ip address following system output order. val activeNetworkIFs = NetworkInterface.getNetworkInterfaces.toList val reOrderedNetworkIFs = if (isWindows) activeNetworkIFs else activeNetworkIFs.reverse + for (ni <- reOrderedNetworkIFs) { - for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && - !addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) { + val addresses = ni.getInetAddresses.toList + .filterNot(addr => addr.isLinkLocalAddress || addr.isLoopbackAddress) + if (addresses.nonEmpty) { + val addr = addresses.find(_.isInstanceOf[Inet4Address]).getOrElse(addresses.head) + // because of Inet6Address.toHostName may add interface at the end if it knows about it + val strippedAddress = InetAddress.getByAddress(addr.getAddress) // We've found an address that looks reasonable! logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + - " a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress + - " instead (on interface " + ni.getName + ")") + " a loopback address: " + address.getHostAddress + "; using " + + strippedAddress.getHostAddress + " instead (on interface " + ni.getName + ")") logWarning("Set SPARK_LOCAL_IP if you need to bind to another address") - return addr.getHostAddress + return strippedAddress } } logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + @@ -822,7 +829,7 @@ private[spark] object Utils extends Logging { " external IP address!") logWarning("Set SPARK_LOCAL_IP if you need to bind to another address") } - address.getHostAddress + address } } @@ -842,11 +849,14 @@ private[spark] object Utils extends Logging { * Get the local machine's hostname. */ def localHostName(): String = { - customHostname.getOrElse(localIpAddressHostname) + customHostname.getOrElse(localIpAddress.getHostAddress) } - def getAddressHostName(address: String): String = { - InetAddress.getByName(address).getHostName + /** + * Get the local machine's URI. 
+ */ + def localHostNameForURI(): String = { + customHostname.getOrElse(InetAddresses.toUriString(localIpAddress)) } def checkHost(host: String, message: String = "") { @@ -1010,6 +1020,22 @@ private[spark] object Utils extends Logging { ) } + /** + * Convert a time parameter such as (50s, 100ms, or 250us) to milliseconds for internal use. If + * no suffix is provided, the passed number is assumed to be in ms. + */ + def timeStringAsMs(str: String): Long = { + JavaUtils.timeStringAsMs(str) + } + + /** + * Convert a time parameter such as (50s, 100ms, or 250us) to seconds for internal use. If + * no suffix is provided, the passed number is assumed to be in seconds. + */ + def timeStringAsSeconds(str: String): Long = { + JavaUtils.timeStringAsSec(str) + } + /** * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. */ @@ -1214,6 +1240,54 @@ private[spark] object Utils extends Logging { } } + /** Executes the given block. Log non-fatal errors if any, and only throw fatal errors */ + def tryLogNonFatalError(block: => Unit) { + try { + block + } catch { + case NonFatal(t) => + logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", t) + } + } + + /** + * Execute a block of code, then a finally block, but if exceptions happen in + * the finally block, do not suppress the original exception. + * + * This is primarily an issue with `finally { out.close() }` blocks, where + * close needs to be called to clean up `out`, but if an exception happened + * in `out.write`, it's likely `out` may be corrupted and `out.close` will + * fail as well. This would then suppress the original/likely more meaningful + * exception from the original `out.write` call. + */ + def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = { + // It would be nice to find a method on Try that did this + var originalThrowable: Throwable = null + try { + block + } catch { + case t: Throwable => + // Purposefully not using NonFatal, because even fatal exceptions + // we don't want to have our finallyBlock suppress + originalThrowable = t + throw originalThrowable + } finally { + try { + finallyBlock + } catch { + case t: Throwable => + if (originalThrowable != null) { + // We could do originalThrowable.addSuppressed(t), but it's + // not available in JDK 1.6. + logWarning(s"Suppressing exception in finally: " + t.getMessage, t) + throw originalThrowable + } else { + throw t + } + } + } + } + /** Default filtering function for finding call sites using `getCallSite`. */ private def coreExclusionFunction(className: String): Boolean = { // A regular expression to match classes of the "core" Spark API that we want to skip when @@ -2055,7 +2129,7 @@ private[spark] object Utils extends Logging { */ def getCurrentUserName(): String = { Option(System.getenv("SPARK_USER")) - .getOrElse(UserGroupInformation.getCurrentUser().getUserName()) + .getOrElse(UserGroupInformation.getCurrentUser().getShortUserName()) } } @@ -2074,7 +2148,7 @@ private[spark] class RedirectThread( override def run() { scala.util.control.Exception.ignoring(classOf[IOException]) { // FIXME: We copy the stream on the level of bytes to avoid encoding problems.
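`tryWithSafeFinally` is the pattern the remainder of this patch migrates stream-handling code to, and the new time helpers back the suffixed config values ("60s", "1s") used below. A usage sketch, assuming Spark-internal visibility (`Utils` is `private[spark]`); the package, object, and `copyFile` names are illustrative and not part of the patch.

```scala
package org.apache.spark.example

import java.io.{FileInputStream, FileOutputStream}

import org.apache.spark.util.Utils

object UtilsAdditionsExample {

  // Copy src to dst. If copyStream throws, a secondary failure in close() is
  // logged by tryWithSafeFinally and the original exception is rethrown
  // instead of being suppressed.
  def copyFile(src: String, dst: String): Long = {
    val in = new FileInputStream(src)
    val out = new FileOutputStream(dst)
    Utils.tryWithSafeFinally {
      Utils.copyStream(in, out, closeStreams = false)
    } {
      in.close()
      out.close()
    }
  }

  def main(args: Array[String]): Unit = {
    // Suffixed values are converted; bare numbers fall back to the documented default unit.
    assert(Utils.timeStringAsMs("5s") == 5000L)
    assert(Utils.timeStringAsMs("100") == 100L)       // no suffix: milliseconds
    assert(Utils.timeStringAsSeconds("2m") == 120L)   // "m" is minutes
    assert(Utils.timeStringAsSeconds("60") == 60L)    // no suffix: seconds
  }
}
```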
- try { + Utils.tryWithSafeFinally { val buf = new Array[Byte](1024) var len = in.read(buf) while (len != -1) { @@ -2082,7 +2156,7 @@ private[spark] class RedirectThread( out.flush() len = in.read(buf) } - } finally { + } { if (propagateEof) { out.close() } diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index f79e8e0491ea1..41cb8cfe2afa3 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -39,7 +39,7 @@ class BitSet(numBits: Int) extends Serializable { val wordIndex = bitIndex >> 6 // divide by 64 var i = 0 while(i < wordIndex) { words(i) = -1; i += 1 } - if(wordIndex < words.size) { + if(wordIndex < words.length) { // Set the remaining bits (note that the mask could still be zero) val mask = ~(-1L << (bitIndex & 0x3f)) words(wordIndex) |= mask diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index b962c101c91da..035f3767ff554 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -664,6 +664,8 @@ private[spark] class ExternalSorter[K, V, C]( } /** + * Exposed for testing purposes. + * * Return an iterator over all the data written to this object, grouped by partition and * aggregated by the requested aggregator. For each partition we then have an iterator over its * contents, and these are expected to be accessed in order (you can't "skip ahead" to one @@ -673,7 +675,7 @@ private[spark] class ExternalSorter[K, V, C]( * For now, we just merge all the spilled files in once pass, but this can be modified to * support hierarchical merging. */ - def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { + def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { val usingMap = aggregator.isDefined val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer if (spills.isEmpty && partitionWriters == null) { @@ -726,25 +728,19 @@ private[spark] class ExternalSorter[K, V, C]( // this simple we spill out the current in-memory collection so that everything is in files. 
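The `BitSet` tweak in this hunk (`words.size` to `words.length`) is a micro-cleanup: on a raw JVM array, `length` is a direct field access, while `size` goes through the implicit `ArrayOps` wrapper. A trivial sketch of the equivalence (names are illustrative):

```scala
object ArrayLengthExample {
  def main(args: Array[String]): Unit = {
    val words = new Array[Long](4)
    // Same result either way; `length` skips the implicit Array-to-ArrayOps conversion.
    assert(words.length == words.size)
  }
}
```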
spillToPartitionFiles(if (aggregator.isDefined) map else buffer) partitionWriters.foreach(_.commitAndClose()) - var out: FileOutputStream = null - var in: FileInputStream = null + val out = new FileOutputStream(outputFile, true) val writeStartTime = System.nanoTime - try { - out = new FileOutputStream(outputFile, true) + util.Utils.tryWithSafeFinally { for (i <- 0 until numPartitions) { - in = new FileInputStream(partitionWriters(i).fileSegment().file) - val size = org.apache.spark.util.Utils.copyStream(in, out, false, transferToEnabled) - in.close() - in = null - lengths(i) = size - } - } finally { - if (out != null) { - out.close() - } - if (in != null) { - in.close() + val in = new FileInputStream(partitionWriters(i).fileSegment().file) + util.Utils.tryWithSafeFinally { + lengths(i) = org.apache.spark.util.Utils.copyStream(in, out, false, transferToEnabled) + } { + in.close() + } } + } { + out.close() context.taskMetrics.shuffleWriteMetrics.foreach( _.incShuffleWriteTime(System.nanoTime - writeStartTime)) } @@ -781,7 +777,7 @@ private[spark] class ExternalSorter[K, V, C]( /** * Read a partition file back as an iterator (used in our iterator method) */ - def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = { + private def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = { if (writer.isOpen) { writer.commitAndClose() } diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties index 287c8e3563503..eb3b1999eb996 100644 --- a/core/src/test/resources/log4j.properties +++ b/core/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index bd0f8bdefa171..75399461f2a5f 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -27,19 +27,20 @@ import org.scalatest.Matchers class AccumulatorSuite extends FunSuite with Matchers with LocalSparkContext { - implicit def setAccum[A] = new AccumulableParam[mutable.Set[A], A] { - def addInPlace(t1: mutable.Set[A], t2: mutable.Set[A]) : mutable.Set[A] = { - t1 ++= t2 - t1 - } - def addAccumulator(t1: mutable.Set[A], t2: A) : mutable.Set[A] = { - t1 += t2 - t1 - } - def zero(t: mutable.Set[A]) : mutable.Set[A] = { - new mutable.HashSet[A]() + implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] = + new AccumulableParam[mutable.Set[A], A] { + def addInPlace(t1: mutable.Set[A], t2: mutable.Set[A]) : mutable.Set[A] = { + t1 ++= t2 + t1 + } + def addAccumulator(t1: mutable.Set[A], t2: A) : mutable.Set[A] = { + t1 += t2 + t1 + } + def zero(t: mutable.Set[A]) : mutable.Set[A] = { + new mutable.HashSet[A]() + } } - } test ("basic accumulation"){ sc = new SparkContext("local", "test") @@ -49,11 +50,10 @@ class AccumulatorSuite extends FunSuite with Matchers with LocalSparkContext { d.foreach{x => acc += x} acc.value should be (210) - - val longAcc = sc.accumulator(0l) + val longAcc = sc.accumulator(0L) val maxInt = Integer.MAX_VALUE.toLong d.foreach{x => longAcc += maxInt + x} - longAcc.value 
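For context on what `AccumulatorSuite`'s `setAccum` implicit (now annotated with its `AccumulableParam` type) exercises, here is a standalone sketch of a set-valued accumulable built on the same shape. Object and variable names are illustrative, not from the patch.

```scala
import scala.collection.mutable

import org.apache.spark.{AccumulableParam, SparkConf, SparkContext}

object SetAccumulatorExample {
  // Same shape as the suite's implicit: merge sets, add single elements, empty zero.
  implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] =
    new AccumulableParam[mutable.Set[A], A] {
      def addInPlace(t1: mutable.Set[A], t2: mutable.Set[A]): mutable.Set[A] = { t1 ++= t2; t1 }
      def addAccumulator(t1: mutable.Set[A], t2: A): mutable.Set[A] = { t1 += t2; t1 }
      def zero(t: mutable.Set[A]): mutable.Set[A] = new mutable.HashSet[A]()
    }

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("acc-example"))
    val acc = sc.accumulable(mutable.Set[Int]())   // picks up the implicit param above
    sc.parallelize(1 to 5).foreach(x => acc += x)
    println(acc.value)                             // Set(1, 2, 3, 4, 5)
    sc.stop()
  }
}
```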
should be (210l + maxInt * 20) + longAcc.value should be (210L + maxInt * 20) } test ("value not assignable from tasks") { diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 4b25c200a695a..70529d9216591 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -45,16 +45,17 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf rdd = new RDD[Int](sc, Nil) { override def getPartitions: Array[Partition] = Array(split) override val getDependencies = List[Dependency[_]]() - override def compute(split: Partition, context: TaskContext) = Array(1, 2, 3, 4).iterator + override def compute(split: Partition, context: TaskContext): Iterator[Int] = + Array(1, 2, 3, 4).iterator } rdd2 = new RDD[Int](sc, List(new OneToOneDependency(rdd))) { override def getPartitions: Array[Partition] = firstParent[Int].partitions - override def compute(split: Partition, context: TaskContext) = + override def compute(split: Partition, context: TaskContext): Iterator[Int] = firstParent[Int].iterator(split, context) }.cache() rdd3 = new RDD[Int](sc, List(new OneToOneDependency(rdd2))) { override def getPartitions: Array[Partition] = firstParent[Int].partitions - override def compute(split: Partition, context: TaskContext) = + override def compute(split: Partition, context: TaskContext): Iterator[Int] = firstParent[Int].iterator(split, context) }.cache() } diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 32abc65385267..e1faddeabec79 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -75,7 +75,8 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) assert(parCollection.partitions.length === numPartitions) - assert(parCollection.partitions.toList === parCollection.checkpointData.get.getPartitions.toList) + assert(parCollection.partitions.toList === + parCollection.checkpointData.get.getPartitions.toList) assert(parCollection.collect() === result) } @@ -102,13 +103,13 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { } test("UnionRDD") { - def otherRDD = sc.makeRDD(1 to 10, 1) + def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1) testRDD(_.union(otherRDD)) testRDDPartitions(_.union(otherRDD)) } test("CartesianRDD") { - def otherRDD = sc.makeRDD(1 to 10, 1) + def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1) testRDD(new CartesianRDD(sc, _, otherRDD)) testRDDPartitions(new CartesianRDD(sc, _, otherRDD)) @@ -223,7 +224,8 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { val partitionAfterCheckpoint = serializeDeserialize( unionRDD.partitions.head.asInstanceOf[PartitionerAwareUnionRDDPartition]) assert( - partitionBeforeCheckpoint.parents.head.getClass != partitionAfterCheckpoint.parents.head.getClass, + partitionBeforeCheckpoint.parents.head.getClass != + partitionAfterCheckpoint.parents.head.getClass, "PartitionerAwareUnionRDDPartition.parents not updated after parent RDD is checkpointed" ) } @@ -358,7 +360,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { * Generate an pair RDD (with partitioner) such that 
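The suites above now spell out return types on `compute` and helper methods, in line with the stricter Scala style applied across this patch. A minimal custom RDD written in that style; the class name and behaviour are illustrative only.

```scala
import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

// Explicit return types on every override, matching the updated test style.
class ConstantRDD(sc: SparkContext, values: Seq[Int]) extends RDD[Int](sc, Nil) {
  override def getPartitions: Array[Partition] =
    Array(new Partition { override def index: Int = 0 })

  override def compute(split: Partition, context: TaskContext): Iterator[Int] =
    values.iterator
}
```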
both the RDD and its partitions * have large size. */ - def generateFatPairRDD() = { + def generateFatPairRDD(): RDD[(Int, Int)] = { new FatPairRDD(sc.makeRDD(1 to 100, 4), partitioner).mapValues(x => x) } @@ -445,7 +447,8 @@ class FatPairRDD(parent: RDD[Int], _partitioner: Partitioner) extends RDD[(Int, object CheckpointSuite { // This is a custom cogroup function that does not use mapValues like // the PairRDDFunctions.cogroup() - def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = { + def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) + : RDD[(K, Array[Iterable[V]])] = { new CoGroupedRDD[K]( Seq(first.asInstanceOf[RDD[(K, _)]], second.asInstanceOf[RDD[(K, _)]]), part diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index cdfaacee7da40..097e7076e5391 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -28,7 +28,8 @@ import org.scalatest.concurrent.{PatienceConfiguration, Eventually} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.{RDDCheckpointData, RDD} import org.apache.spark.storage._ import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.shuffle.sort.SortShuffleManager @@ -64,7 +65,7 @@ abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[Ha } } - //------ Helper functions ------ + // ------ Helper functions ------ protected def newRDD() = sc.makeRDD(1 to 10) protected def newPairRDD() = newRDD().map(_ -> 1) @@ -205,6 +206,52 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { postGCTester.assertCleanup() } + test("automatically cleanup checkpoint") { + val checkpointDir = java.io.File.createTempFile("temp", "") + checkpointDir.deleteOnExit() + checkpointDir.delete() + var rdd = newPairRDD + sc.setCheckpointDir(checkpointDir.toString) + rdd.checkpoint() + rdd.cache() + rdd.collect() + var rddId = rdd.id + + // Confirm the checkpoint directory exists + assert(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).isDefined) + val path = RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get + val fs = path.getFileSystem(sc.hadoopConfiguration) + assert(fs.exists(path)) + + // the checkpoint is not cleaned by default (without the configuration set) + var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil) + rdd = null // Make RDD out of scope + runGC() + postGCTester.assertCleanup() + assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get)) + + sc.stop() + val conf = new SparkConf().setMaster("local[2]").setAppName("cleanupCheckpoint"). 
+ set("spark.cleaner.referenceTracking.cleanCheckpoints", "true") + sc = new SparkContext(conf) + rdd = newPairRDD + sc.setCheckpointDir(checkpointDir.toString) + rdd.checkpoint() + rdd.cache() + rdd.collect() + rddId = rdd.id + + // Confirm the checkpoint directory exists + assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get)) + + // Test that GC causes checkpoint data cleanup after dereferencing the RDD + postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil) + rdd = null // Make RDD out of scope + runGC() + postGCTester.assertCleanup() + assert(!fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get)) + } + test("automatically cleanup RDD + shuffle + broadcast") { val numRdds = 100 val numBroadcasts = 4 // Broadcasts are more costly @@ -370,7 +417,7 @@ class CleanerTester( val cleanerListener = new CleanerListener { def rddCleaned(rddId: Int): Unit = { toBeCleanedRDDIds -= rddId - logInfo("RDD "+ rddId + " cleaned") + logInfo("RDD " + rddId + " cleaned") } def shuffleCleaned(shuffleId: Int): Unit = { diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index abfcee75728dc..22acc270b983e 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import scala.collection.mutable -import org.scalatest.{FunSuite, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -28,10 +28,20 @@ import org.apache.spark.util.ManualClock /** * Test add and remove behavior of ExecutorAllocationManager. 
*/ -class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { +class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAfter { import ExecutorAllocationManager._ import ExecutorAllocationManagerSuite._ + private val contexts = new mutable.ListBuffer[SparkContext]() + + before { + contexts.clear() + } + + after { + contexts.foreach(_.stop()) + } + test("verify min/max executors") { val conf = new SparkConf() .setMaster("local") @@ -39,25 +49,20 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { .set("spark.dynamicAllocation.enabled", "true") .set("spark.dynamicAllocation.testing", "true") val sc0 = new SparkContext(conf) + contexts += sc0 assert(sc0.executorAllocationManager.isDefined) sc0.stop() // Min < 0 val conf1 = conf.clone().set("spark.dynamicAllocation.minExecutors", "-1") - intercept[SparkException] { new SparkContext(conf1) } - SparkEnv.get.stop() - SparkContext.clearActiveContext() + intercept[SparkException] { contexts += new SparkContext(conf1) } // Max < 0 val conf2 = conf.clone().set("spark.dynamicAllocation.maxExecutors", "-1") - intercept[SparkException] { new SparkContext(conf2) } - SparkEnv.get.stop() - SparkContext.clearActiveContext() + intercept[SparkException] { contexts += new SparkContext(conf2) } // Both min and max, but min > max intercept[SparkException] { createSparkContext(2, 1) } - SparkEnv.get.stop() - SparkContext.clearActiveContext() // Both min and max, and min == max val sc1 = createSparkContext(1, 1) @@ -665,16 +670,6 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeTimes(manager).contains("executor-2")) assert(!removeTimes(manager).contains("executor-1")) } -} - -/** - * Helper methods for testing ExecutorAllocationManager. - * This includes methods to access private methods and fields in ExecutorAllocationManager. - */ -private object ExecutorAllocationManagerSuite extends PrivateMethodTester { - private val schedulerBacklogTimeout = 1L - private val sustainedSchedulerBacklogTimeout = 2L - private val executorIdleTimeout = 3L private def createSparkContext(minExecutors: Int = 1, maxExecutors: Int = 5): SparkContext = { val conf = new SparkConf() @@ -683,14 +678,28 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { .set("spark.dynamicAllocation.enabled", "true") .set("spark.dynamicAllocation.minExecutors", minExecutors.toString) .set("spark.dynamicAllocation.maxExecutors", maxExecutors.toString) - .set("spark.dynamicAllocation.schedulerBacklogTimeout", schedulerBacklogTimeout.toString) + .set("spark.dynamicAllocation.schedulerBacklogTimeout", + s"${schedulerBacklogTimeout.toString}s") .set("spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", - sustainedSchedulerBacklogTimeout.toString) - .set("spark.dynamicAllocation.executorIdleTimeout", executorIdleTimeout.toString) + s"${sustainedSchedulerBacklogTimeout.toString}s") + .set("spark.dynamicAllocation.executorIdleTimeout", s"${executorIdleTimeout.toString}s") .set("spark.dynamicAllocation.testing", "true") - new SparkContext(conf) + val sc = new SparkContext(conf) + contexts += sc + sc } +} + +/** + * Helper methods for testing ExecutorAllocationManager. + * This includes methods to access private methods and fields in ExecutorAllocationManager. 
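Note that `createSparkContext` now writes the dynamic-allocation timeouts with an explicit `s` suffix, matching the new time-string parsing. A hedged configuration sketch with the same keys and illustrative values:

```scala
import org.apache.spark.SparkConf

object DynamicAllocationConfExample {
  // Timeout settings accept unit suffixes ("1s", "500ms", ...); bare numbers
  // are still read with the key's documented default unit.
  def conf: SparkConf = new SparkConf()
    .set("spark.dynamicAllocation.enabled", "true")
    .set("spark.dynamicAllocation.minExecutors", "1")
    .set("spark.dynamicAllocation.maxExecutors", "5")
    .set("spark.dynamicAllocation.schedulerBacklogTimeout", "1s")
    .set("spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", "2s")
    .set("spark.dynamicAllocation.executorIdleTimeout", "60s")
}
```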
+ */ +private object ExecutorAllocationManagerSuite extends PrivateMethodTester { + private val schedulerBacklogTimeout = 1L + private val sustainedSchedulerBacklogTimeout = 2L + private val executorIdleTimeout = 3L + private def createStageInfo(stageId: Int, numTasks: Int): StageInfo = { new StageInfo(stageId, 0, "name", numTasks, Seq.empty, "no details") } diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 5fdf6bc2777e3..a69e9b761f9a7 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark import java.io._ import java.net.URI import java.util.jar.{JarEntry, JarOutputStream} -import javax.net.ssl.SSLHandshakeException +import javax.net.ssl.SSLException import com.google.common.io.ByteStreams import org.apache.commons.io.{FileUtils, IOUtils} @@ -228,7 +228,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { try { server.initialize() - intercept[SSLHandshakeException] { + intercept[SSLException] { fileTransferTest(server) } } finally { diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 7acd27c735727..c8f08eed47c76 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -222,7 +222,7 @@ class FileSuite extends FunSuite with LocalSparkContext { val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x))) nums.saveAsSequenceFile(outputDir) val output = - sc.newAPIHadoopFile[IntWritable, Text, SequenceFileInputFormat[IntWritable, Text]](outputDir) + sc.newAPIHadoopFile[IntWritable, Text, SequenceFileInputFormat[IntWritable, Text]](outputDir) assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) } @@ -451,7 +451,8 @@ class FileSuite extends FunSuite with LocalSparkContext { test ("prevent user from overwriting the empty directory (new Hadoop API)") { sc = new SparkContext("local", "test") - val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + val randomRDD = sc.parallelize( + Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) intercept[FileAlreadyExistsException] { randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempDir.getPath) } @@ -459,8 +460,10 @@ class FileSuite extends FunSuite with LocalSparkContext { test ("prevent user from overwriting the non-empty directory (new Hadoop API)") { sc = new SparkContext("local", "test") - val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) - randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempDir.getPath + "/output") + val randomRDD = sc.parallelize( + Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]]( + tempDir.getPath + "/output") assert(new File(tempDir.getPath + "/output/part-r-00000").exists() === true) intercept[FileAlreadyExistsException] { randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempDir.getPath) @@ -471,16 +474,20 @@ class FileSuite extends FunSuite with LocalSparkContext { val sf = new SparkConf() sf.setAppName("test").setMaster("local").set("spark.hadoop.validateOutputSpecs", "false") sc = new SparkContext(sf) - val randomRDD = 
sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) - randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempDir.getPath + "/output") + val randomRDD = sc.parallelize( + Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]]( + tempDir.getPath + "/output") assert(new File(tempDir.getPath + "/output/part-r-00000").exists() === true) - randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempDir.getPath + "/output") + randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]]( + tempDir.getPath + "/output") assert(new File(tempDir.getPath + "/output/part-r-00000").exists() === true) } test ("save Hadoop Dataset through old Hadoop API") { sc = new SparkContext("local", "test") - val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + val randomRDD = sc.parallelize( + Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) val job = new JobConf() job.setOutputKeyClass(classOf[String]) job.setOutputValueClass(classOf[String]) @@ -492,7 +499,8 @@ class FileSuite extends FunSuite with LocalSparkContext { test ("save Hadoop Dataset through new Hadoop API") { sc = new SparkContext("local", "test") - val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + val randomRDD = sc.parallelize( + Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) val job = new Job(sc.hadoopConfiguration) job.setOutputKeyClass(classOf[String]) job.setOutputValueClass(classOf[String]) diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala new file mode 100644 index 0000000000000..0fd570e5297d9 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.BlockManagerId +import org.scalatest.FunSuite +import org.mockito.Mockito.{mock, spy, verify, when} +import org.mockito.Matchers +import org.mockito.Matchers._ + +import org.apache.spark.scheduler.TaskScheduler +import org.apache.spark.util.RpcUtils +import org.scalatest.concurrent.Eventually._ + +class HeartbeatReceiverSuite extends FunSuite with LocalSparkContext { + + test("HeartbeatReceiver") { + sc = spy(new SparkContext("local[2]", "test")) + val scheduler = mock(classOf[TaskScheduler]) + when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true) + when(sc.taskScheduler).thenReturn(scheduler) + + val heartbeatReceiver = new HeartbeatReceiver(sc) + sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(heartbeatReceiver.scheduler != null) + } + val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv) + + val metrics = new TaskMetrics + val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) + val response = receiverRef.askWithReply[HeartbeatResponse]( + Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + + verify(scheduler).executorHeartbeatReceived( + Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + assert(false === response.reregisterBlockManager) + } + + test("HeartbeatReceiver re-register") { + sc = spy(new SparkContext("local[2]", "test")) + val scheduler = mock(classOf[TaskScheduler]) + when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(false) + when(sc.taskScheduler).thenReturn(scheduler) + + val heartbeatReceiver = new HeartbeatReceiver(sc) + sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(heartbeatReceiver.scheduler != null) + } + val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv) + + val metrics = new TaskMetrics + val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) + val response = receiverRef.askWithReply[HeartbeatResponse]( + Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + + verify(scheduler).executorHeartbeatReceived( + Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + assert(true === response.reregisterBlockManager) + } +} diff --git a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala index d895230ecf330..51348c039b5c9 100644 --- a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala @@ -51,7 +51,7 @@ private object ImplicitOrderingSuite { override def compare(o: OrderedClass): Int = ??? 
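The heartbeat tests reach the receiver through `setupEndpoint`, `RpcUtils.makeDriverRef`, and a blocking `askWithReply`. A self-contained sketch of that request/reply shape, using only the RPC calls that appear in this patch; it assumes a Spark-internal package (the RPC classes are `private[spark]`), and the endpoint and environment names are made up.

```scala
package org.apache.spark.example

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

object EchoRpcExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    val rpcEnv = RpcEnv.create("echo-env", "localhost", 0, conf, new SecurityManager(conf))

    // An endpoint that answers every ask with an echo of the message.
    class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case msg: String => context.reply(s"echo: $msg")
      }
    }

    val ref = rpcEnv.setupEndpoint("echo", new EchoEndpoint(rpcEnv))
    val answer = ref.askWithReply[String]("ping") // blocks until the endpoint replies
    assert(answer == "echo: ping")
    rpcEnv.shutdown()
  }
}
```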
} - def basicMapExpectations(rdd: RDD[Int]) = { + def basicMapExpectations(rdd: RDD[Int]): List[(Boolean, String)] = { List((rdd.map(x => (x, x)).keyOrdering.isDefined, "rdd.map(x => (x, x)).keyOrdering.isDefined"), (rdd.map(x => (1, x)).keyOrdering.isDefined, @@ -68,7 +68,7 @@ private object ImplicitOrderingSuite { "rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined")) } - def otherRDDMethodExpectations(rdd: RDD[Int]) = { + def otherRDDMethodExpectations(rdd: RDD[Int]): List[(Boolean, String)] = { List((rdd.groupBy(x => x).keyOrdering.isDefined, "rdd.groupBy(x => x).keyOrdering.isDefined"), (rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty, @@ -82,4 +82,4 @@ private object ImplicitOrderingSuite { (rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined, "rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined")) } -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 21487bc24d58a..4d3e09793faff 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -188,7 +188,7 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter val rdd = sc.parallelize(1 to 10, 2).map { i => JobCancellationSuite.twoJobsSharingStageSemaphore.acquire() (i, i) - }.reduceByKey(_+_) + }.reduceByKey(_ + _) val f1 = rdd.collectAsync() val f2 = rdd.countAsync() diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 53e367a61715b..8bf2e55defd02 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -37,7 +37,7 @@ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self super.afterEach() } - def resetSparkContext() = { + def resetSparkContext(): Unit = { LocalSparkContext.stop(sc) sc = null } @@ -54,7 +54,7 @@ object LocalSparkContext { } /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. 
*/ - def withSpark[T](sc: SparkContext)(f: SparkContext => T) = { + def withSpark[T](sc: SparkContext)(f: SparkContext => T): T = { try { f(sc) } finally { diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index ccfe0678cb1c3..6295d34be5ca9 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -17,34 +17,37 @@ package org.apache.spark -import scala.concurrent.Await - -import akka.actor._ -import akka.testkit.TestActorRef +import org.mockito.Mockito._ +import org.mockito.Matchers.{any, isA} import org.scalatest.FunSuite +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.AkkaUtils class MapOutputTrackerSuite extends FunSuite { private val conf = new SparkConf + def createRpcEnv(name: String, host: String = "localhost", port: Int = 0, + securityManager: SecurityManager = new SecurityManager(conf)): RpcEnv = { + RpcEnv.create(name, host, port, conf, securityManager) + } + test("master start and stop") { - val actorSystem = ActorSystem("test") + val rpcEnv = createRpcEnv("test") val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = - actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.stop() - actorSystem.shutdown() + rpcEnv.shutdown() } test("master register shuffle and fetch") { - val actorSystem = ActorSystem("test") + val rpcEnv = createRpcEnv("test") val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = - actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) assert(tracker.containsShuffle(10)) val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) @@ -57,13 +60,14 @@ class MapOutputTrackerSuite extends FunSuite { assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000), (BlockManagerId("b", "hostB", 1000), size10000))) tracker.stop() - actorSystem.shutdown() + rpcEnv.shutdown() } test("master register and unregister shuffle") { - val actorSystem = ActorSystem("test") + val rpcEnv = createRpcEnv("test") val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) @@ -78,14 +82,14 @@ class MapOutputTrackerSuite extends FunSuite { assert(tracker.getServerStatuses(10, 0).isEmpty) tracker.stop() - actorSystem.shutdown() + rpcEnv.shutdown() } test("master register shuffle and unregister map output and fetch") { - val actorSystem = ActorSystem("test") + val rpcEnv = createRpcEnv("test") val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = - actorSystem.actorOf(Props(new 
MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) @@ -104,25 +108,21 @@ class MapOutputTrackerSuite extends FunSuite { intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } tracker.stop() - actorSystem.shutdown() + rpcEnv.shutdown() } test("remote fetch") { val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, - securityManager = new SecurityManager(conf)) + val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf)) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf, - securityManager = new SecurityManager(conf)) + val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf)) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() @@ -147,8 +147,8 @@ class MapOutputTrackerSuite extends FunSuite { masterTracker.stop() slaveTracker.stop() - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch below akka frame size") { @@ -157,19 +157,24 @@ class MapOutputTrackerSuite extends FunSuite { newConf.set("spark.akka.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) - val actorSystem = ActorSystem("test") - val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) - val masterActor = actorRef.underlyingActor + val rpcEnv = createRpcEnv("spark") + val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) + rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) // Frame size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) - masterActor.receive(GetMapOutputStatuses(10)) + val sender = mock(classOf[RpcEndpointRef]) + when(sender.address).thenReturn(RpcAddress("localhost", 12345)) + val rpcCallContext = mock(classOf[RpcCallContext]) + when(rpcCallContext.sender).thenReturn(sender) + masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10)) + verify(rpcCallContext).reply(any()) + verify(rpcCallContext, never()).sendFailure(any()) // masterTracker.stop() // this throws an exception - actorSystem.shutdown() + rpcEnv.shutdown() } test("remote fetch exceeds akka frame size") { @@ -178,12 +183,11 @@ 
class MapOutputTrackerSuite extends FunSuite { newConf.set("spark.akka.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) - val actorSystem = ActorSystem("test") - val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) - val masterActor = actorRef.underlyingActor + val rpcEnv = createRpcEnv("test") + val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) + rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) - // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception. + // Frame size should be ~1.1MB, and MapOutputTrackerMasterEndpoint should throw exception. // Note that the size is hand-selected here because map output statuses are compressed before // being sent. masterTracker.registerShuffle(20, 100) @@ -191,9 +195,15 @@ class MapOutputTrackerSuite extends FunSuite { masterTracker.registerMapOutput(20, i, new CompressedMapStatus( BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) } - intercept[SparkException] { masterActor.receive(GetMapOutputStatuses(20)) } + val sender = mock(classOf[RpcEndpointRef]) + when(sender.address).thenReturn(RpcAddress("localhost", 12345)) + val rpcCallContext = mock(classOf[RpcCallContext]) + when(rpcCallContext.sender).thenReturn(sender) + masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20)) + verify(rpcCallContext, never()).reply(any()) + verify(rpcCallContext).sendFailure(isA(classOf[SparkException])) // masterTracker.stop() // this throws an exception - actorSystem.shutdown() + rpcEnv.shutdown() } } diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index b7532314ada01..47e3bf6e1ac41 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -92,7 +92,7 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet test("RangePartitioner for keys that are not Comparable (but with Ordering)") { // Row does not extend Comparable, but has an implicit Ordering defined. 
implicit object RowOrdering extends Ordering[Row] { - override def compare(x: Row, y: Row) = x.value - y.value + override def compare(x: Row, y: Row): Int = x.value - y.value } val rdd = sc.parallelize(1 to 4500).map(x => (Row(x), Row(x))) @@ -212,20 +212,24 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet val arrPairs: RDD[(Array[Int], Int)] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x)) - assert(intercept[SparkException]{ arrs.distinct() }.getMessage.contains("array")) + def verify(testFun: => Unit): Unit = { + intercept[SparkException](testFun).getMessage.contains("array") + } + + verify(arrs.distinct()) // We can't catch all usages of arrays, since they might occur inside other collections: // assert(fails { arrPairs.distinct() }) - assert(intercept[SparkException]{ arrPairs.partitionBy(new HashPartitioner(2)) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.join(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.leftOuterJoin(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.rightOuterJoin(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.fullOuterJoin(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.groupByKey() }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.countByKey() }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.countByKeyApprox(1) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.cogroup(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array")) + verify(arrPairs.partitionBy(new HashPartitioner(2))) + verify(arrPairs.join(arrPairs)) + verify(arrPairs.leftOuterJoin(arrPairs)) + verify(arrPairs.rightOuterJoin(arrPairs)) + verify(arrPairs.fullOuterJoin(arrPairs)) + verify(arrPairs.groupByKey()) + verify(arrPairs.countByKey()) + verify(arrPairs.countByKeyApprox(1)) + verify(arrPairs.cogroup(arrPairs)) + verify(arrPairs.reduceByKeyLocally(_ + _)) + verify(arrPairs.reduceByKey(_ + _)) } test("zero-length partitions should be correctly handled") { diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 444a33371bd71..93f46ef11c0e2 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -36,7 +36,8 @@ class SSLOptionsSuite extends FunSuite with BeforeAndAfterAll { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") + conf.set("spark.ssl.enabledAlgorithms", + "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") conf.set("spark.ssl.protocol", "SSLv3") val opts = SSLOptions.parse(conf, "spark.ssl") @@ -52,7 +53,8 @@ class SSLOptionsSuite extends FunSuite with BeforeAndAfterAll { assert(opts.keyStorePassword === Some("password")) assert(opts.keyPassword === Some("password")) assert(opts.protocol === Some("SSLv3")) - assert(opts.enabledAlgorithms === Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) + 
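One nit on the refactor above: as written, the new `verify` helper evaluates `getMessage.contains("array")` but discards the Boolean, so these calls only confirm that a `SparkException` is thrown at all. A stricter variant (a sketch, not part of the patch) keeps the original message check by wrapping it in an assertion:

```scala
// Fails unless the thrown SparkException's message actually mentions arrays.
def verify(testFun: => Unit): Unit = {
  assert(intercept[SparkException](testFun).getMessage.contains("array"))
}
```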
assert(opts.enabledAlgorithms === + Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) } test("test resolving property with defaults specified ") { @@ -66,7 +68,8 @@ class SSLOptionsSuite extends FunSuite with BeforeAndAfterAll { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") + conf.set("spark.ssl.enabledAlgorithms", + "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") conf.set("spark.ssl.protocol", "SSLv3") val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) @@ -83,7 +86,8 @@ class SSLOptionsSuite extends FunSuite with BeforeAndAfterAll { assert(opts.keyStorePassword === Some("password")) assert(opts.keyPassword === Some("password")) assert(opts.protocol === Some("SSLv3")) - assert(opts.enabledAlgorithms === Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) + assert(opts.enabledAlgorithms === + Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) } test("test whether defaults can be overridden ") { @@ -99,7 +103,8 @@ class SSLOptionsSuite extends FunSuite with BeforeAndAfterAll { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") + conf.set("spark.ssl.enabledAlgorithms", + "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") conf.set("spark.ui.ssl.enabledAlgorithms", "ABC, DEF") conf.set("spark.ssl.protocol", "SSLv3") diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala index ace8123a8961f..308b9ea17708d 100644 --- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala @@ -21,10 +21,11 @@ import java.io.File object SSLSampleConfigs { val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath - val untrustedKeyStorePath = new File(this.getClass.getResource("/untrusted-keystore").toURI).getAbsolutePath + val untrustedKeyStorePath = new File( + this.getClass.getResource("/untrusted-keystore").toURI).getAbsolutePath val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath - def sparkSSLConfig() = { + def sparkSSLConfig(): SparkConf = { val conf = new SparkConf(loadDefaults = false) conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) @@ -38,7 +39,7 @@ object SSLSampleConfigs { conf } - def sparkSSLConfigUntrusted() = { + def sparkSSLConfigUntrusted(): SparkConf = { val conf = new SparkConf(loadDefaults = false) conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", untrustedKeyStorePath) diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index f57921b768310..d7180516029d5 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -142,7 +142,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("shuffle on mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new 
SparkContext("local-cluster[2,1,512]", "test", conf) - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) val results = new ShuffledRDD[Int, Int, Int](pairs, @@ -155,7 +155,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex // This is not in SortingSuite because of the local cluster setup. // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data = Array(p(1, 11), p(3, 33), p(100, 100), p(2, 22)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) val results = new OrderedRDDFunctions[Int, Int, MutablePair[Int, Int]](pairs) @@ -169,7 +169,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("cogroup using mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"), p(3, "3")) val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2) @@ -196,7 +196,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("subtract mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1), p(3, 33)) val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22")) val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2) @@ -242,14 +242,14 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex shuffleSpillCompress <- Set(true, false); shuffleCompress <- Set(true, false) ) { - val conf = new SparkConf() + val myConf = conf.clone() .setAppName("test") .setMaster("local") .set("spark.shuffle.spill.compress", shuffleSpillCompress.toString) .set("spark.shuffle.compress", shuffleCompress.toString) .set("spark.shuffle.memoryFraction", "0.001") resetSparkContext() - sc = new SparkContext(conf) + sc = new SparkContext(myConf) try { sc.parallelize(0 until 100000).map(i => (i / 4, i)).groupByKey().collect() } catch { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index b07c4d93db4e6..94be1c6d6397c 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io.File +import java.util.concurrent.TimeUnit import com.google.common.base.Charsets._ import com.google.common.io.Files @@ -25,9 +26,11 @@ import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.hadoop.io.BytesWritable - import org.apache.spark.util.Utils +import scala.concurrent.Await +import scala.concurrent.duration.Duration 
+ class SparkContextSuite extends FunSuite with LocalSparkContext { test("Only one SparkContext may be active at a time") { @@ -111,11 +114,13 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { if (length1 != gotten1.length()) { throw new SparkException( - s"file has different length $length1 than added file ${gotten1.length()} : " + absolutePath1) + s"file has different length $length1 than added file ${gotten1.length()} : " + + absolutePath1) } if (length2 != gotten2.length()) { throw new SparkException( - s"file has different length $length2 than added file ${gotten2.length()} : " + absolutePath2) + s"file has different length $length2 than added file ${gotten2.length()} : " + + absolutePath2) } if (absolutePath1 == gotten1.getAbsolutePath) { @@ -173,4 +178,19 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { sc.stop() } } + + test("Cancelling job group should not cause SparkContext to shutdown (SPARK-6414)") { + try { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + val future = sc.parallelize(Seq(0)).foreachAsync(_ => {Thread.sleep(1000L)}) + sc.cancelJobGroup("nonExistGroupId") + Await.ready(future, Duration(2, TimeUnit.SECONDS)) + + // In SPARK-6414, sc.cancelJobGroup will cause NullPointerException and cause + // SparkContext to shutdown, so the following assertion will fail. + assert(sc.parallelize(1 to 10).count() == 10L) + } finally { + sc.stop() + } + } } diff --git a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala index 41d6ea29d5b06..084eb237d70d1 100644 --- a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala @@ -82,7 +82,8 @@ class StatusTrackerSuite extends FunSuite with Matchers with LocalSparkContext { secondJobFuture.jobIds.head } eventually(timeout(10 seconds)) { - sc.statusTracker.getJobIdsForGroup("my-job-group").toSet should be (Set(firstJobId, secondJobId)) + sc.statusTracker.getJobIdsForGroup("my-job-group").toSet should be ( + Set(firstJobId, secondJobId)) } } -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index af3272692d7a1..c8fdfa693912e 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -33,7 +33,7 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { val broadcast = rdd.context.broadcast(list) val bid = broadcast.id - def doSomething() = { + def doSomething(): Set[(Int, Boolean)] = { rdd.map { x => val bm = SparkEnv.get.blockManager // Check if broadcast block was fetched diff --git a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala index 518073dcbb64e..745f9eeee7536 100644 --- a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala @@ -46,5 +46,4 @@ class ClientSuite extends FunSuite with Matchers { // Invalid syntax. 
ClientArguments.isValidJarUrl("hdfs:") should be (false) } - } diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index 68b5776fc6515..b58d62567afe1 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -28,7 +28,7 @@ import org.scalatest.FunSuite import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, RecoveryState, WorkerInfo} import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} class JsonProtocolSuite extends FunSuite { @@ -100,13 +100,13 @@ class JsonProtocolSuite extends FunSuite { appInfo } - def createDriverCommand() = new Command( + def createDriverCommand(): Command = new Command( "org.apache.spark.FakeClass", Seq("some arg --and-some options -g foo"), Map(("K1", "V1"), ("K2", "V2")), Seq("cp1", "cp2"), Seq("lp1", "lp2"), Seq("-Dfoo") ) - def createDriverDesc() = new DriverDescription("hdfs://some-dir/some.jar", 100, 3, - false, createDriverCommand()) + def createDriverDesc(): DriverDescription = + new DriverDescription("hdfs://some-dir/some.jar", 100, 3, false, createDriverCommand()) def createDriverInfo(): DriverInfo = new DriverInfo(3, "driver-3", createDriverDesc(), new Date()) @@ -124,8 +124,9 @@ class JsonProtocolSuite extends FunSuite { } def createDriverRunner(): DriverRunner = { - new DriverRunner(new SparkConf(), "driverId", new File("workDir"), new File("sparkHome"), - createDriverDesc(), null, "akka://worker") + val conf = new SparkConf() + new DriverRunner(conf, "driverId", new File("workDir"), new File("sparkHome"), + createDriverDesc(), null, "akka://worker", new SecurityManager(conf)) } def assertValidJson(json: JValue) { diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index 54dd7c9c45c61..c93d16f8a1586 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy import java.net.URL +import scala.collection.JavaConversions._ import scala.collection.mutable import scala.io.Source @@ -56,7 +57,7 @@ class LogUrlsStandaloneSuite extends FunSuite with LocalSparkContext { test("verify that log urls reflect SPARK_PUBLIC_DNS (SPARK-6175)") { val SPARK_PUBLIC_DNS = "public_dns" class MySparkConf extends SparkConf(false) { - override def getenv(name: String) = { + override def getenv(name: String): String = { if (name == "SPARK_PUBLIC_DNS") SPARK_PUBLIC_DNS else super.getenv(name) } @@ -65,16 +66,17 @@ class LogUrlsStandaloneSuite extends FunSuite with LocalSparkContext { new MySparkConf().setAll(getAll) } } - val conf = new MySparkConf() + val conf = new MySparkConf().set( + "spark.extraListeners", classOf[SaveExecutorInfo].getName) sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - val listener = new SaveExecutorInfo - sc.addSparkListener(listener) - // Trigger a job so that executors get added sc.parallelize(1 to 100, 4).map(_.toString).count() assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + val listeners = sc.listenerBus.findListenersByClass[SaveExecutorInfo] + 
assert(listeners.size === 1) + val listener = listeners(0) listener.addedExecutorInfos.values.foreach { info => assert(info.logUrlMap.nonEmpty) info.logUrlMap.values.foreach { logUrl => @@ -82,12 +84,12 @@ class LogUrlsStandaloneSuite extends FunSuite with LocalSparkContext { } } } +} - private class SaveExecutorInfo extends SparkListener { - val addedExecutorInfos = mutable.Map[String, ExecutorInfo]() +private[spark] class SaveExecutorInfo extends SparkListener { + val addedExecutorInfos = mutable.Map[String, ExecutorInfo]() - override def onExecutorAdded(executor: SparkListenerExecutorAdded) { - addedExecutorInfos(executor.executorId) = executor.executorInfo - } + override def onExecutorAdded(executor: SparkListenerExecutorAdded) { + addedExecutorInfos(executor.executorId) = executor.executorInfo } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index e908ba604ebed..fcae603c7d18e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -50,7 +50,7 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers inProgress: Boolean, codec: Option[String] = None): File = { val ip = if (inProgress) EventLoggingListener.IN_PROGRESS else "" - val logUri = EventLoggingListener.getLogPath(testDir.getAbsolutePath, appId) + val logUri = EventLoggingListener.getLogPath(testDir.toURI, appId) val logPath = new URI(logUri).getPath + ip new File(logPath) } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 3a9963a5ce7b7..20de46fdab909 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -42,10 +42,10 @@ class HistoryServerSuite extends FunSuite with Matchers with MockitoSugar { when(historyServer.getProviderConfig()).thenReturn(Map[String, String]()) val page = new HistoryPage(historyServer) - //when + // when val response = page.render(request) - //then + // then val links = response \\ "a" val justHrefs = for { l <- links diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 2fa90e3bd1c63..8e09976636386 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -508,7 +508,7 @@ private class DummyMaster( exception: Option[Exception] = None) extends Actor { - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case RequestSubmitDriver(driverDesc) => sender ! 
SubmitDriverResponse(success = true, Some(submitId), submitMessage) case RequestKillDriver(driverId) => @@ -531,7 +531,7 @@ private class SmarterMaster extends Actor { private var counter: Int = 0 private val submittedDrivers = new mutable.HashMap[String, DriverState] - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case RequestSubmitDriver(driverDesc) => val driverId = s"driver-$counter" submittedDrivers(driverId) = RUNNING diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index 1d64ec201e647..61071ee17256c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -129,7 +129,8 @@ class SubmitRestProtocolSuite extends FunSuite { assert(newMessage.sparkProperties("spark.files") === "fireball.png") assert(newMessage.sparkProperties("spark.driver.memory") === "512m") assert(newMessage.sparkProperties("spark.driver.cores") === "180") - assert(newMessage.sparkProperties("spark.driver.extraJavaOptions") === " -Dslices=5 -Dcolor=mostly_red") + assert(newMessage.sparkProperties("spark.driver.extraJavaOptions") === + " -Dslices=5 -Dcolor=mostly_red") assert(newMessage.sparkProperties("spark.driver.extraClassPath") === "food-coloring.jar") assert(newMessage.sparkProperties("spark.driver.extraLibraryPath") === "pickle.jar") assert(newMessage.sparkProperties("spark.driver.supervise") === "false") diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala index aa6e4874cecde..2159fd8c16c6f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala @@ -25,7 +25,7 @@ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.FunSuite -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.{Command, DriverDescription} import org.apache.spark.util.Clock @@ -33,8 +33,9 @@ class DriverRunnerTest extends FunSuite { private def createDriverRunner() = { val command = new Command("mainClass", Seq(), Map(), Seq(), Seq(), Seq()) val driverDescription = new DriverDescription("jarUrl", 512, 1, true, command) - new DriverRunner(new SparkConf(), "driverId", new File("workDir"), new File("sparkHome"), - driverDescription, null, "akka://1.2.3.4/worker/") + val conf = new SparkConf() + new DriverRunner(conf, "driverId", new File("workDir"), new File("sparkHome"), + driverDescription, null, "akka://1.2.3.4/worker/", new SecurityManager(conf)) } private def createProcessBuilderAndProcess(): (ProcessBuilderLike, Process) = { diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index 6fca6321e5a1b..a8b9df227c996 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -35,7 +35,8 @@ class ExecutorRunnerTest extends FunSuite { val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", 123, "publicAddr", new File(sparkHome), new File("ooga"), "blah", new SparkConf, Seq("localDir"), 
ExecutorState.RUNNING) - val builder = CommandUtils.buildProcessBuilder(appDesc.command, 512, sparkHome, er.substituteVariables) + val builder = CommandUtils.buildProcessBuilder( + appDesc.command, 512, sparkHome, er.substituteVariables) assert(builder.command().last === appId) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala index 372d7aa453008..7cc2104281464 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala @@ -37,7 +37,7 @@ class WorkerArgumentsTest extends FunSuite { val args = Array("spark://localhost:0000 ") class MySparkConf extends SparkConf(false) { - override def getenv(name: String) = { + override def getenv(name: String): String = { if (name == "SPARK_WORKER_MEMORY") "50000" else super.getenv(name) } @@ -56,7 +56,7 @@ class WorkerArgumentsTest extends FunSuite { val args = Array("spark://localhost:0000 ") class MySparkConf extends SparkConf(false) { - override def getenv(name: String) = { + override def getenv(name: String): String = { if (name == "SPARK_WORKER_MEMORY") "5G" else super.getenv(name) } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index 84e2fd7ad936d..450fba21f4b5c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -24,8 +24,10 @@ import org.scalatest.{Matchers, FunSuite} class WorkerSuite extends FunSuite with Matchers { - def cmd(javaOpts: String*) = Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts:_*)) - def conf(opts: (String, String)*) = new SparkConf(loadDefaults = false).setAll(opts) + def cmd(javaOpts: String*): Command = { + Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts:_*)) + } + def conf(opts: (String, String)*): SparkConf = new SparkConf(loadDefaults = false).setAll(opts) test("test isUseLocalNodeSSLConfig") { Worker.isUseLocalNodeSSLConfig(cmd("-Dasdf=dfgh")) shouldBe false diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index 5e538d6fab2a1..6a6f29dd613cd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -17,32 +17,38 @@ package org.apache.spark.deploy.worker -import akka.actor.{ActorSystem, AddressFromURIString, Props} -import akka.testkit.TestActorRef -import akka.remote.DisassociatedEvent +import akka.actor.AddressFromURIString +import org.apache.spark.SparkConf +import org.apache.spark.SecurityManager +import org.apache.spark.rpc.{RpcAddress, RpcEnv} import org.scalatest.FunSuite class WorkerWatcherSuite extends FunSuite { test("WorkerWatcher shuts down on valid disassociation") { - val actorSystem = ActorSystem("test") - val targetWorkerUrl = "akka://1.2.3.4/user/Worker" + val conf = new SparkConf() + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" val targetWorkerAddress = AddressFromURIString(targetWorkerUrl) - val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], 
targetWorkerUrl))(actorSystem) - val workerWatcher = actorRef.underlyingActor + val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) - actorRef.underlyingActor.receive(new DisassociatedEvent(null, targetWorkerAddress, false)) - assert(actorRef.underlyingActor.isShutDown) + rpcEnv.setupEndpoint("worker-watcher", workerWatcher) + workerWatcher.onDisconnected( + RpcAddress(targetWorkerAddress.host.get, targetWorkerAddress.port.get)) + assert(workerWatcher.isShutDown) + rpcEnv.shutdown() } test("WorkerWatcher stays alive on invalid disassociation") { - val actorSystem = ActorSystem("test") - val targetWorkerUrl = "akka://1.2.3.4/user/Worker" - val otherAkkaURL = "akka://4.3.2.1/user/OtherActor" + val conf = new SparkConf() + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" + val otherAkkaURL = "akka://test@4.3.2.1:1234/user/OtherActor" val otherAkkaAddress = AddressFromURIString(otherAkkaURL) - val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem) - val workerWatcher = actorRef.underlyingActor + val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) - actorRef.underlyingActor.receive(new DisassociatedEvent(null, otherAkkaAddress, false)) - assert(!actorRef.underlyingActor.isShutDown) + rpcEnv.setupEndpoint("worker-watcher", workerWatcher) + workerWatcher.onDisconnected(RpcAddress(otherAkkaAddress.host.get, otherAkkaAddress.port.get)) + assert(!workerWatcher.isShutDown) + rpcEnv.shutdown() } } diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 78fa98a3b9065..190b08d950a02 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -238,7 +238,7 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext sc.textFile(tmpFilePath, 4) .map(key => (key, 1)) - .reduceByKey(_+_) + .reduceByKey(_ + _) .saveAsTextFile("file://" + tmpFile.getAbsolutePath) sc.listenerBus.waitUntilEmpty(500) diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala index 37e528435aa5d..100ac77dec1f7 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala @@ -35,7 +35,8 @@ class MetricsConfigSuite extends FunSuite with BeforeAndAfter { val property = conf.getInstance("random") assert(property.size() === 2) - assert(property.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet") + assert(property.getProperty("sink.servlet.class") === + "org.apache.spark.metrics.sink.MetricsServlet") assert(property.getProperty("sink.servlet.path") === "/metrics/json") } @@ -47,16 +48,20 @@ class MetricsConfigSuite extends FunSuite with BeforeAndAfter { assert(masterProp.size() === 5) assert(masterProp.getProperty("sink.console.period") === "20") assert(masterProp.getProperty("sink.console.unit") === "minutes") - assert(masterProp.getProperty("source.jvm.class") === "org.apache.spark.metrics.source.JvmSource") - assert(masterProp.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet") + 
assert(masterProp.getProperty("source.jvm.class") === + "org.apache.spark.metrics.source.JvmSource") + assert(masterProp.getProperty("sink.servlet.class") === + "org.apache.spark.metrics.sink.MetricsServlet") assert(masterProp.getProperty("sink.servlet.path") === "/metrics/master/json") val workerProp = conf.getInstance("worker") assert(workerProp.size() === 5) assert(workerProp.getProperty("sink.console.period") === "10") assert(workerProp.getProperty("sink.console.unit") === "seconds") - assert(workerProp.getProperty("source.jvm.class") === "org.apache.spark.metrics.source.JvmSource") - assert(workerProp.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet") + assert(workerProp.getProperty("source.jvm.class") === + "org.apache.spark.metrics.source.JvmSource") + assert(workerProp.getProperty("sink.servlet.class") === + "org.apache.spark.metrics.sink.MetricsServlet") assert(workerProp.getProperty("sink.servlet.path") === "/metrics/json") } diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala index 716f875d30b8a..02424c59d6831 100644 --- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala @@ -260,8 +260,8 @@ class ConnectionManagerSuite extends FunSuite { test("sendMessageReliably timeout") { val clientConf = new SparkConf clientConf.set("spark.authenticate", "false") - val ackTimeout = 30 - clientConf.set("spark.core.connection.ack.wait.timeout", s"${ackTimeout}") + val ackTimeoutS = 30 + clientConf.set("spark.core.connection.ack.wait.timeout", s"${ackTimeoutS}s") val clientSecurityManager = new SecurityManager(clientConf) val manager = new ConnectionManager(0, clientConf, clientSecurityManager) @@ -272,7 +272,7 @@ class ConnectionManagerSuite extends FunSuite { val managerServer = new ConnectionManager(0, serverConf, serverSecurityManager) managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { // sleep 60 sec > ack timeout for simulating server slow down or hang up - Thread.sleep(ackTimeout * 3 * 1000) + Thread.sleep(ackTimeoutS * 3 * 1000) None }) @@ -287,7 +287,7 @@ class ConnectionManagerSuite extends FunSuite { // Otherwise TimeoutExcepton is thrown from Await.result. // We expect TimeoutException is not thrown. intercept[IOException] { - Await.result(future, (ackTimeout * 2) second) + Await.result(future, (ackTimeoutS * 2) second) } manager.stop() diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index 97079382c716f..01039b9449daf 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -22,6 +22,12 @@ import org.scalatest.FunSuite import org.apache.spark._ class DoubleRDDSuite extends FunSuite with SharedSparkContext { + test("sum") { + assert(sc.parallelize(Seq.empty[Double]).sum() === 0.0) + assert(sc.parallelize(Seq(1.0)).sum() === 1.0) + assert(sc.parallelize(Seq(1.0, 2.0)).sum() === 3.0) + } + // Verify tests on the histogram functionality. We test with both evenly // and non-evenly spaced buckets as the bucket lookup function changes. 
test("WorksOnEmpty") { diff --git a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala index 0dc59888f7304..be8467354b222 100644 --- a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala @@ -80,7 +80,7 @@ class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { (r: ResultSet) => { r.getInt(1) } ).cache() assert(rdd.count === 100) - assert(rdd.reduce(_+_) === 10100) + assert(rdd.reduce(_ + _) === 10100) } test("large id overflow") { @@ -92,7 +92,7 @@ class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { 1131544775L, 567279358897692673L, 20, (r: ResultSet) => { r.getInt(1) } ).cache() assert(rdd.count === 100) - assert(rdd.reduce(_+_) === 5050) + assert(rdd.reduce(_ + _) === 5050) } after { diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 108f70af43f37..ca0d953d306d8 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -168,13 +168,13 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { test("reduceByKey") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_).collect() + val sums = pairs.reduceByKey(_ + _).collect() assert(sums.toSet === Set((1, 7), (2, 1))) } test("reduceByKey with collectAsMap") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_).collectAsMap() + val sums = pairs.reduceByKey(_ + _).collectAsMap() assert(sums.size === 2) assert(sums(1) === 7) assert(sums(2) === 1) @@ -182,7 +182,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { test("reduceByKey with many output partitons") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_, 10).collect() + val sums = pairs.reduceByKey(_ + _, 10).collect() assert(sums.toSet === Set((1, 7), (2, 1))) } @@ -192,7 +192,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { def getPartition(key: Any) = key.asInstanceOf[Int] } val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p) - val sums = pairs.reduceByKey(_+_) + val sums = pairs.reduceByKey(_ + _) assert(sums.collect().toSet === Set((1, 4), (0, 1))) assert(sums.partitioner === Some(p)) // count the dependencies to make sure there is only 1 ShuffledRDD @@ -208,7 +208,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { } test("countApproxDistinctByKey") { - def error(est: Long, size: Long) = math.abs(est - size) / size.toDouble + def error(est: Long, size: Long): Double = math.abs(est - size) / size.toDouble /* Since HyperLogLog unique counting is approximate, and the relative standard deviation is * only a statistical bound, the tests can fail for large values of relativeSD. 
We will be using @@ -465,7 +465,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { test("foldByKey") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.foldByKey(0)(_+_).collect() + val sums = pairs.foldByKey(0)(_ + _).collect() assert(sums.toSet === Set((1, 7), (2, 1))) } @@ -505,7 +505,8 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { conf.setOutputCommitter(classOf[FakeOutputCommitter]) FakeOutputCommitter.ran = false - pairs.saveAsHadoopFile("ignored", pairs.keyClass, pairs.valueClass, classOf[FakeOutputFormat], conf) + pairs.saveAsHadoopFile( + "ignored", pairs.keyClass, pairs.valueClass, classOf[FakeOutputFormat], conf) assert(FakeOutputCommitter.ran, "OutputCommitter was never called") } @@ -552,7 +553,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { } private object StratifiedAuxiliary { - def stratifier (fractionPositive: Double) = { + def stratifier (fractionPositive: Double): (Int) => String = { (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0" } @@ -572,7 +573,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { def testSampleExact(stratifiedData: RDD[(String, Int)], samplingRate: Double, seed: Long, - n: Long) = { + n: Long): Unit = { testBernoulli(stratifiedData, true, samplingRate, seed, n) testPoisson(stratifiedData, true, samplingRate, seed, n) } @@ -580,7 +581,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { def testSample(stratifiedData: RDD[(String, Int)], samplingRate: Double, seed: Long, - n: Long) = { + n: Long): Unit = { testBernoulli(stratifiedData, false, samplingRate, seed, n) testPoisson(stratifiedData, false, samplingRate, seed, n) } @@ -590,7 +591,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { exact: Boolean, samplingRate: Double, seed: Long, - n: Long) = { + n: Long): Unit = { val expectedSampleSize = stratifiedData.countByKey() .mapValues(count => math.ceil(count * samplingRate).toInt) val fractions = Map("1" -> samplingRate, "0" -> samplingRate) @@ -612,7 +613,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { exact: Boolean, samplingRate: Double, seed: Long, - n: Long) = { + n: Long): Unit = { val expectedSampleSize = stratifiedData.countByKey().mapValues(count => math.ceil(count * samplingRate).toInt) val fractions = Map("1" -> samplingRate, "0" -> samplingRate) @@ -701,27 +702,27 @@ class FakeOutputFormat() extends OutputFormat[Integer, Integer]() { */ class NewFakeWriter extends NewRecordWriter[Integer, Integer] { - def close(p1: NewTaskAttempContext) = () + def close(p1: NewTaskAttempContext): Unit = () - def write(p1: Integer, p2: Integer) = () + def write(p1: Integer, p2: Integer): Unit = () } class NewFakeCommitter extends NewOutputCommitter { - def setupJob(p1: NewJobContext) = () + def setupJob(p1: NewJobContext): Unit = () def needsTaskCommit(p1: NewTaskAttempContext): Boolean = false - def setupTask(p1: NewTaskAttempContext) = () + def setupTask(p1: NewTaskAttempContext): Unit = () - def commitTask(p1: NewTaskAttempContext) = () + def commitTask(p1: NewTaskAttempContext): Unit = () - def abortTask(p1: NewTaskAttempContext) = () + def abortTask(p1: NewTaskAttempContext): Unit = () } class NewFakeFormat() extends NewOutputFormat[Integer, Integer]() { - def checkOutputSpecs(p1: NewJobContext) = () + def checkOutputSpecs(p1: NewJobContext): Unit = () def getRecordWriter(p1: NewTaskAttempContext): 
NewRecordWriter[Integer, Integer] = { new NewFakeWriter() @@ -735,7 +736,7 @@ class NewFakeFormat() extends NewOutputFormat[Integer, Integer]() { class ConfigTestFormat() extends NewFakeFormat() with Configurable { var setConfCalled = false - def setConf(p1: Configuration) = { + def setConf(p1: Configuration): Unit = { setConfCalled = true () } diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala index cd193ae4f5238..1880364581c1a 100644 --- a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala @@ -100,7 +100,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1 until 100 val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 99) + assert(slices.map(_.size).reduceLeft(_ + _) === 99) assert(slices.forall(_.isInstanceOf[Range])) } @@ -108,7 +108,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1 to 100 val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 100) + assert(slices.map(_.size).reduceLeft(_ + _) === 100) assert(slices.forall(_.isInstanceOf[Range])) } @@ -139,7 +139,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices(i).isInstanceOf[Range]) val range = slices(i).asInstanceOf[Range] assert(range.start === i * (N / 40), "slice " + i + " start") - assert(range.end === (i+1) * (N / 40), "slice " + i + " end") + assert(range.end === (i + 1) * (N / 40), "slice " + i + " end") assert(range.step === 1, "slice " + i + " step") } } @@ -156,7 +156,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val slices = ParallelCollectionRDD.slice(d, n) ("n slices" |: slices.size == n) && ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && - ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) + ("equal sizes" |: slices.map(_.size).forall(x => x == d.size / n || x == d.size /n + 1)) } check(prop) } @@ -174,7 +174,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { ("n slices" |: slices.size == n) && ("all ranges" |: slices.forall(_.isInstanceOf[Range])) && ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && - ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) + ("equal sizes" |: slices.map(_.size).forall(x => x == d.size / n || x == d.size / n + 1)) } check(prop) } @@ -192,7 +192,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { ("n slices" |: slices.size == n) && ("all ranges" |: slices.forall(_.isInstanceOf[Range])) && ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && - ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) + ("equal sizes" |: slices.map(_.size).forall(x => x == d.size / n || x == d.size / n + 1)) } check(prop) } @@ -201,7 +201,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1L until 100L val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 99) + assert(slices.map(_.size).reduceLeft(_ + _) === 99) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } @@ -209,7 +209,7 @@ class 
ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1L to 100L val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 100) + assert(slices.map(_.size).reduceLeft(_ + _) === 100) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } @@ -217,7 +217,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1.0 until 100.0 by 1.0 val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 99) + assert(slices.map(_.size).reduceLeft(_ + _) === 99) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } @@ -225,7 +225,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1.0 to 100.0 by 1.0 val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 100) + assert(slices.map(_.size).reduceLeft(_ + _) === 100) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala index 8408d7e785c65..465068c6cbb16 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala @@ -23,7 +23,6 @@ import org.apache.spark.{Partition, SharedSparkContext, TaskContext} class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { - test("Pruned Partitions inherit locality prefs correctly") { val rdd = new RDD[Int](sc, Nil) { @@ -74,8 +73,6 @@ class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { } class TestPartition(i: Int, value: Int) extends Partition with Serializable { - def index = i - - def testValue = this.value - + def index: Int = i + def testValue: Int = this.value } diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala index a0483886f8db3..0d1369c19c69e 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala @@ -35,7 +35,7 @@ class MockSampler extends RandomSampler[Long, Long] { Iterator(s) } - override def clone = new MockSampler + override def clone: MockSampler = new MockSampler } class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext { diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index bede1ffb3e2d0..df42faab64505 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -82,7 +82,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("countApproxDistinct") { - def error(est: Long, size: Long) = math.abs(est - size) / size.toDouble + def error(est: Long, size: Long): Double = math.abs(est - size) / size.toDouble val size = 1000 val uniformDistro = for (i <- 1 to 5000) yield i % size @@ -100,7 +100,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { } test("partitioner aware union") { - def makeRDDWithPartitioner(seq: Seq[Int]) = { + def makeRDDWithPartitioner(seq: Seq[Int]): RDD[Int] = { sc.makeRDD(seq, 1) .map(x => (x, null)) .partitionBy(new HashPartitioner(2)) @@ -159,8 +159,8 @@ class RDDSuite extends FunSuite with 
SharedSparkContext { test("treeAggregate") { val rdd = sc.makeRDD(-1000 until 1000, 10) - def seqOp = (c: Long, x: Int) => c + x - def combOp = (c1: Long, c2: Long) => c1 + c2 + def seqOp: (Long, Int) => Long = (c: Long, x: Int) => c + x + def combOp: (Long, Long) => Long = (c1: Long, c2: Long) => c1 + c2 for (depth <- 1 until 10) { val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth) assert(sum === -1000L) @@ -204,7 +204,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(empty.collect().size === 0) val thrown = intercept[UnsupportedOperationException]{ - empty.reduce(_+_) + empty.reduce(_ + _) } assert(thrown.getMessage.contains("empty")) @@ -321,7 +321,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(list3.sorted === Array("a","b","c"), "Locality preferences are dropped") // RDD with locality preferences spread (non-randomly) over 6 machines, m0 through m5 - val data = sc.makeRDD((1 to 9).map(i => (i, (i to (i+2)).map{ j => "m" + (j%6)}))) + val data = sc.makeRDD((1 to 9).map(i => (i, (i to (i + 2)).map{ j => "m" + (j%6)}))) val coalesced1 = data.coalesce(3) assert(coalesced1.collect().toList.sorted === (1 to 9).toList, "Data got *lost* in coalescing") @@ -921,15 +921,17 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("task serialization exception should not hang scheduler") { class BadSerializable extends Serializable { @throws(classOf[IOException]) - private def writeObject(out: ObjectOutputStream): Unit = throw new KryoException("Bad serialization") + private def writeObject(out: ObjectOutputStream): Unit = + throw new KryoException("Bad serialization") @throws(classOf[IOException]) private def readObject(in: ObjectInputStream): Unit = {} } - // Note that in the original bug, SPARK-4349, that this verifies, the job would only hang if there were - // more threads in the Spark Context than there were number of objects in this sequence. + // Note that in the original bug, SPARK-4349, that this verifies, the job would only hang if + // there were more threads in the Spark Context than there were number of objects in this + // sequence. 
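+ // Both elements throw during task serialization, so the collect() below is expected to fail rather than hang.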
intercept[Throwable] { - sc.parallelize(Seq(new BadSerializable, new BadSerializable)).collect + sc.parallelize(Seq(new BadSerializable, new BadSerializable)).collect() } // Check that the context has not crashed sc.parallelize(1 to 100).map(x => x*2).collect diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala index 4762fc17855ce..fe695d85e29dd 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala @@ -21,11 +21,11 @@ object RDDSuiteUtils { case class Person(first: String, last: String, age: Int) object AgeOrdering extends Ordering[Person] { - def compare(a:Person, b:Person) = a.age compare b.age + def compare(a:Person, b:Person): Int = a.age.compare(b.age) } object NameOrdering extends Ordering[Person] { - def compare(a:Person, b:Person) = + def compare(a:Person, b:Person): Int = implicitly[Ordering[Tuple2[String,String]]].compare((a.last, a.first), (b.last, b.first)) } } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala new file mode 100644 index 0000000000000..ada07ef11cd7a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -0,0 +1,548 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException} + +import scala.collection.mutable +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.{SparkException, SparkConf} + +/** + * Common tests for an RpcEnv implementation. 
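+ * Concrete suites plug in an implementation by overriding `createRpcEnv`; the shared `env` created in `beforeAll` is reused across tests, while the remote-messaging tests create a second env to talk to it.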
+ */ +abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { + + var env: RpcEnv = _ + + override def beforeAll(): Unit = { + val conf = new SparkConf() + env = createRpcEnv(conf, "local", 12345) + } + + override def afterAll(): Unit = { + if(env != null) { + env.shutdown() + } + } + + def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv + + test("send a message locally") { + @volatile var message: String = null + val rpcEndpointRef = env.setupEndpoint("send-locally", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case msg: String => message = msg + } + }) + rpcEndpointRef.send("hello") + eventually(timeout(5 seconds), interval(10 millis)) { + assert("hello" === message) + } + } + + test("send a message remotely") { + @volatile var message: String = null + // Set up a RpcEndpoint using env + env.setupEndpoint("send-remotely", new RpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case msg: String => message = msg + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote" ,13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "send-remotely") + try { + rpcEndpointRef.send("hello") + eventually(timeout(5 seconds), interval(10 millis)) { + assert("hello" === message) + } + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("send a RpcEndpointRef") { + val endpoint = new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext) = { + case "Hello" => context.reply(self) + case "Echo" => context.reply("Echo") + } + } + val rpcEndpointRef = env.setupEndpoint("send-ref", endpoint) + + val newRpcEndpointRef = rpcEndpointRef.askWithReply[RpcEndpointRef]("Hello") + val reply = newRpcEndpointRef.askWithReply[String]("Echo") + assert("Echo" === reply) + } + + test("ask a message locally") { + val rpcEndpointRef = env.setupEndpoint("ask-locally", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => { + context.reply(msg) + } + } + }) + val reply = rpcEndpointRef.askWithReply[String]("hello") + assert("hello" === reply) + } + + test("ask a message remotely") { + env.setupEndpoint("ask-remotely", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => { + context.reply(msg) + } + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely") + try { + val reply = rpcEndpointRef.askWithReply[String]("hello") + assert("hello" === reply) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("ask a message timeout") { + env.setupEndpoint("ask-timeout", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => { + Thread.sleep(100) + context.reply(msg) + } + } + }) + + val conf = new SparkConf() + conf.set("spark.akka.retry.wait", "0") + conf.set("spark.akka.num.retries", "1") + val anotherEnv = createRpcEnv(conf, "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, 
"ask-timeout") + try { + val e = intercept[Exception] { + rpcEndpointRef.askWithReply[String]("hello", 1 millis) + } + assert(e.isInstanceOf[TimeoutException] || e.getCause.isInstanceOf[TimeoutException]) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("onStart and onStop") { + val stopLatch = new CountDownLatch(1) + val calledMethods = mutable.ArrayBuffer[String]() + + val endpoint = new RpcEndpoint { + override val rpcEnv = env + + override def onStart(): Unit = { + calledMethods += "start" + } + + override def receive: PartialFunction[Any, Unit] = { + case msg: String => + } + + override def onStop(): Unit = { + calledMethods += "stop" + stopLatch.countDown() + } + } + val rpcEndpointRef = env.setupEndpoint("start-stop-test", endpoint) + env.stop(rpcEndpointRef) + stopLatch.await(10, TimeUnit.SECONDS) + assert(List("start", "stop") === calledMethods) + } + + test("onError: error in onStart") { + @volatile var e: Throwable = null + env.setupEndpoint("onError-onStart", new RpcEndpoint { + override val rpcEnv = env + + override def onStart(): Unit = { + throw new RuntimeException("Oops!") + } + + override def receive: PartialFunction[Any, Unit] = { + case m => + } + + override def onError(cause: Throwable): Unit = { + e = cause + } + }) + + eventually(timeout(5 seconds), interval(10 millis)) { + assert(e.getMessage === "Oops!") + } + } + + test("onError: error in onStop") { + @volatile var e: Throwable = null + val endpointRef = env.setupEndpoint("onError-onStop", new RpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case m => + } + + override def onError(cause: Throwable): Unit = { + e = cause + } + + override def onStop(): Unit = { + throw new RuntimeException("Oops!") + } + }) + + env.stop(endpointRef) + + eventually(timeout(5 seconds), interval(10 millis)) { + assert(e.getMessage === "Oops!") + } + } + + test("onError: error in receive") { + @volatile var e: Throwable = null + val endpointRef = env.setupEndpoint("onError-receive", new RpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case m => throw new RuntimeException("Oops!") + } + + override def onError(cause: Throwable): Unit = { + e = cause + } + }) + + endpointRef.send("Foo") + + eventually(timeout(5 seconds), interval(10 millis)) { + assert(e.getMessage === "Oops!") + } + } + + test("self: call in onStart") { + @volatile var callSelfSuccessfully = false + + env.setupEndpoint("self-onStart", new RpcEndpoint { + override val rpcEnv = env + + override def onStart(): Unit = { + self + callSelfSuccessfully = true + } + + override def receive: PartialFunction[Any, Unit] = { + case m => + } + }) + + eventually(timeout(5 seconds), interval(10 millis)) { + // Calling `self` in `onStart` is fine + assert(callSelfSuccessfully === true) + } + } + + test("self: call in receive") { + @volatile var callSelfSuccessfully = false + + val endpointRef = env.setupEndpoint("self-receive", new RpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case m => { + self + callSelfSuccessfully = true + } + } + }) + + endpointRef.send("Foo") + + eventually(timeout(5 seconds), interval(10 millis)) { + // Calling `self` in `receive` is fine + assert(callSelfSuccessfully === true) + } + } + + test("self: call in onStop") { + @volatile var selfOption: Option[RpcEndpointRef] = null + + val endpointRef = env.setupEndpoint("self-onStop", new RpcEndpoint { + override val rpcEnv = 
env + + override def receive: PartialFunction[Any, Unit] = { + case m => + } + + override def onStop(): Unit = { + selfOption = Option(self) + } + + override def onError(cause: Throwable): Unit = { + } + }) + + env.stop(endpointRef) + + eventually(timeout(5 seconds), interval(10 millis)) { + // Calling `self` in `onStop` will return null, so selfOption will be None + assert(selfOption == None) + } + } + + test("call receive in sequence") { + // If a RpcEnv implementation breaks the `receive` contract, hope this test can expose it + for(i <- 0 until 100) { + @volatile var result = 0 + val endpointRef = env.setupEndpoint(s"receive-in-sequence-$i", new ThreadSafeRpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case m => result += 1 + } + + }) + + (0 until 10) foreach { _ => + new Thread { + override def run() { + (0 until 100) foreach { _ => + endpointRef.send("Hello") + } + } + }.start() + } + + eventually(timeout(5 seconds), interval(5 millis)) { + assert(result == 1000) + } + + env.stop(endpointRef) + } + } + + test("stop(RpcEndpointRef) reentrant") { + @volatile var onStopCount = 0 + val endpointRef = env.setupEndpoint("stop-reentrant", new RpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case m => + } + + override def onStop(): Unit = { + onStopCount += 1 + } + }) + + env.stop(endpointRef) + env.stop(endpointRef) + + eventually(timeout(5 seconds), interval(5 millis)) { + // Calling stop twice should only trigger onStop once. + assert(onStopCount == 1) + } + } + + test("sendWithReply") { + val endpointRef = env.setupEndpoint("sendWithReply", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case m => context.reply("ack") + } + }) + + val f = endpointRef.sendWithReply[String]("Hi") + val ack = Await.result(f, 5 seconds) + assert("ack" === ack) + + env.stop(endpointRef) + } + + test("sendWithReply: remotely") { + env.setupEndpoint("sendWithReply-remotely", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case m => context.reply("ack") + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely") + try { + val f = rpcEndpointRef.sendWithReply[String]("hello") + val ack = Await.result(f, 5 seconds) + assert("ack" === ack) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("sendWithReply: error") { + val endpointRef = env.setupEndpoint("sendWithReply-error", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case m => context.sendFailure(new SparkException("Oops")) + } + }) + + val f = endpointRef.sendWithReply[String]("Hi") + val e = intercept[SparkException] { + Await.result(f, 5 seconds) + } + assert("Oops" === e.getMessage) + + env.stop(endpointRef) + } + + test("sendWithReply: remotely error") { + env.setupEndpoint("sendWithReply-remotely-error", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => context.sendFailure(new SparkException("Oops")) + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // 
Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef( + "local", env.address, "sendWithReply-remotely-error") + try { + val f = rpcEndpointRef.sendWithReply[String]("hello") + val e = intercept[SparkException] { + Await.result(f, 5 seconds) + } + assert("Oops" === e.getMessage) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("network events") { + val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] + env.setupEndpoint("network-events", new ThreadSafeRpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case "hello" => + case m => events += "receive" -> m + } + + override def onConnected(remoteAddress: RpcAddress): Unit = { + events += "onConnected" -> remoteAddress + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + events += "onDisconnected" -> remoteAddress + } + + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + events += "onNetworkError" -> remoteAddress + } + + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef( + "local", env.address, "network-events") + val remoteAddress = anotherEnv.address + rpcEndpointRef.send("hello") + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events === List(("onConnected", remoteAddress))) + } + + anotherEnv.shutdown() + anotherEnv.awaitTermination() + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events === List( + ("onConnected", remoteAddress), + ("onNetworkError", remoteAddress), + ("onDisconnected", remoteAddress))) + } + } + + test("sendWithReply: unserializable error") { + env.setupEndpoint("sendWithReply-unserializable-error", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => context.sendFailure(new UnserializableException) + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef( + "local", env.address, "sendWithReply-unserializable-error") + try { + val f = rpcEndpointRef.sendWithReply[String]("hello") + intercept[TimeoutException] { + Await.result(f, 1 seconds) + } + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + +} + +class UnserializableClass + +class UnserializableException extends Exception { + private val unserializableField = new UnserializableClass +} diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala new file mode 100644 index 0000000000000..58214c0637235 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.akka + +import org.apache.spark.rpc._ +import org.apache.spark.{SecurityManager, SparkConf} + +class AkkaRpcEnvSuite extends RpcEnvSuite { + + override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = { + new AkkaRpcEnvFactory().create( + RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf))) + } + + test("setupEndpointRef: systemName, address, endpointName") { + val ref = env.setupEndpoint("test_endpoint", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case _ => + } + }) + val conf = new SparkConf() + val newRpcEnv = new AkkaRpcEnvFactory().create( + RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf))) + try { + val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint") + assert("akka.tcp://local@localhost:12345/user/test_endpoint" === + newRef.asInstanceOf[AkkaRpcEndpointRef].actorRef.path.toString) + } finally { + newRpcEnv.shutdown() + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 63360a0f189a3..3c52a8c4460c6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -57,20 +57,18 @@ class MyRDD( locations: Seq[Seq[String]] = Nil) extends RDD[(Int, Int)](sc, dependencies) with Serializable { override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = throw new RuntimeException("should not be reached") - override def getPartitions = (0 until numPartitions).map(i => new Partition { - override def index = i + override def getPartitions: Array[Partition] = (0 until numPartitions).map(i => new Partition { + override def index: Int = i }).toArray override def getPreferredLocations(split: Partition): Seq[String] = - if (locations.isDefinedAt(split.index)) - locations(split.index) - else - Nil + if (locations.isDefinedAt(split.index)) locations(split.index) else Nil override def toString: String = "DAGSchedulerSuiteRDD " + id } class DAGSchedulerSuiteDummyException extends Exception -class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSparkContext with Timeouts { +class DAGSchedulerSuite + extends FunSuiteLike with BeforeAndAfter with LocalSparkContext with Timeouts { val conf = new SparkConf /** Set of TaskSets the DAGScheduler has requested executed. 
*/ @@ -209,7 +207,8 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(taskSet.tasks.size >= results.size) for ((result, i) <- results.zipWithIndex) { if (i < taskSet.tasks.size) { - runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSet.tasks(i), result._1, result._2, null, createFakeTaskInfo(), null)) } } } @@ -269,21 +268,23 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar submit(new MyRDD(sc, 1, Nil), Array(0)) complete(taskSets(0), List((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("local job") { val rdd = new PairOfIntsRDD(sc, Nil) { override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = Array(42 -> 0).iterator - override def getPartitions = Array( new Partition { override def index = 0 } ) - override def getPreferredLocations(split: Partition) = Nil - override def toString = "DAGSchedulerSuite Local RDD" + override def getPartitions: Array[Partition] = + Array( new Partition { override def index: Int = 0 } ) + override def getPreferredLocations(split: Partition): List[String] = Nil + override def toString: String = "DAGSchedulerSuite Local RDD" } val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) + runEvent( + JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("local job oom") { @@ -295,9 +296,10 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar override def toString = "DAGSchedulerSuite Local RDD" } val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) + runEvent( + JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) assert(results.size == 0) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("run trivial job w/ dependency") { @@ -306,7 +308,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar submit(finalRdd, Array(0)) complete(taskSets(0), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("cache location preferences w/ dependency") { @@ -319,7 +321,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assertLocations(taskSet, Seq(Seq("hostA", "hostB"))) complete(taskSet, Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("regression test for getCacheLocs") { @@ -335,7 +337,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar } test("avoid exponential blowup when getting preferred locs list") { - // Build up a complex dependency graph with repeated zip operations, without preferred locations. + // Build up a complex dependency graph with repeated zip operations, without preferred locations var rdd: RDD[_] = new MyRDD(sc, 1, Nil) (1 to 30).foreach(_ => rdd = rdd.zip(rdd)) // getPreferredLocs runs quickly, indicating that exponential graph traversal is avoided. 
@@ -357,7 +359,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.contains(0)) assert(sparkListener.failedStages.size === 1) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("trivial job failure") { @@ -367,7 +369,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.contains(0)) assert(sparkListener.failedStages.size === 1) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("trivial job cancellation") { @@ -378,7 +380,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.contains(0)) assert(sparkListener.failedStages.size === 1) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("job cancellation no-kill backend") { @@ -387,18 +389,20 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar val noKillTaskScheduler = new TaskScheduler() { override def rootPool: Pool = null override def schedulingMode: SchedulingMode = SchedulingMode.NONE - override def start() = {} - override def stop() = {} - override def submitTasks(taskSet: TaskSet) = { + override def start(): Unit = {} + override def stop(): Unit = {} + override def submitTasks(taskSet: TaskSet): Unit = { taskSets += taskSet } override def cancelTasks(stageId: Int, interruptThread: Boolean) { throw new UnsupportedOperationException } - override def setDAGScheduler(dagScheduler: DAGScheduler) = {} - override def defaultParallelism() = 2 - override def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)], - blockManagerId: BlockManagerId): Boolean = true + override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} + override def defaultParallelism(): Int = 2 + override def executorHeartbeatReceived( + execId: String, + taskMetrics: Array[(Long, TaskMetrics)], + blockManagerId: BlockManagerId): Boolean = true override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} } val noKillScheduler = new DAGScheduler( @@ -422,7 +426,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar // When the task set completes normally, state should be correctly updated. 
complete(taskSets(0), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.isEmpty) @@ -442,7 +446,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("run trivial shuffle with fetch failure") { @@ -465,10 +469,11 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar // have the 2nd attempt pass complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) // we can see both result blocks now - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === + Array("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("trivial shuffle with multiple fetch failures") { @@ -521,19 +526,23 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(newEpoch > oldEpoch) val taskSet = taskSets(0) // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) // should work because it's a non-failed host - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, createFakeTaskInfo(), null)) // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) // should work because it's a new epoch taskSet.tasks(1).epoch = newEpoch - runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) complete(taskSets(1), Seq((Success, 42), (Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("run shuffle with map stage failure") { @@ -552,7 +561,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.toSet === Set(0)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } /** @@ -586,7 +595,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar class FailureRecordingJobListener() extends JobListener { var failureMessage: String = _ override def taskSucceeded(index: Int, result: Any) {} - override def jobFailed(exception: Exception) = { failureMessage = exception.getMessage } + override def jobFailed(exception: Exception): Unit = { failureMessage = 
exception.getMessage } } val listener1 = new FailureRecordingJobListener() val listener2 = new FailureRecordingJobListener() @@ -606,7 +615,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(listener1.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage") assert(listener2.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage") - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("run trivial shuffle with out-of-band failure and retry") { @@ -629,7 +638,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) complete(taskSets(2), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("recursive shuffle failures") { @@ -658,7 +667,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1)))) complete(taskSets(5), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("cached post-shuffle") { @@ -690,7 +699,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1)))) complete(taskSets(4), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("misbehaved accumulator should not crash DAGScheduler and SparkContext") { @@ -742,7 +751,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar } test("accumulator not calculated for resubmitted result stage") { - //just for register + // just for register val accum = new Accumulator[Int](0, AccumulatorParam.IntAccumulatorParam) val finalRdd = new MyRDD(sc, 1, Nil) submit(finalRdd, Array(0)) @@ -754,7 +763,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(accVal === 1) - assertDataStructuresEmpty + assertDataStructuresEmpty() } /** @@ -774,7 +783,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar private def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) - private def assertDataStructuresEmpty = { + private def assertDataStructuresEmpty(): Unit = { assert(scheduler.activeJobs.isEmpty) assert(scheduler.failedStages.isEmpty) assert(scheduler.jobIdToActiveJob.isEmpty) @@ -783,6 +792,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(scheduler.runningStages.isEmpty) assert(scheduler.shuffleToMapStage.isEmpty) assert(scheduler.waitingStages.isEmpty) + assert(scheduler.outputCommitCoordinator.isEmpty) } // Nothing in this test should break if the task info's fields are null, but diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 448258a754153..6d25edb7d20dc 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -61,7 +61,7 @@ class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with Bef test("Verify log file exist") { // Verify logging directory exists val conf = getLoggingConf(testDirPath) - val eventLogger = new EventLoggingListener("test", 
testDirPath.toUri().toString(), conf) + val eventLogger = new EventLoggingListener("test", testDirPath.toUri(), conf) eventLogger.start() val logPath = new Path(eventLogger.logPath + EventLoggingListener.IN_PROGRESS) @@ -95,7 +95,7 @@ class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with Bef } test("Log overwriting") { - val logUri = EventLoggingListener.getLogPath(testDir.getAbsolutePath, "test") + val logUri = EventLoggingListener.getLogPath(testDir.toURI, "test") val logPath = new URI(logUri).getPath // Create file before writing the event log new FileOutputStream(new File(logPath)).close() @@ -107,16 +107,19 @@ class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with Bef test("Event log name") { // without compression - assert(s"file:/base-dir/app1" === EventLoggingListener.getLogPath("/base-dir", "app1")) + assert(s"file:/base-dir/app1" === EventLoggingListener.getLogPath( + Utils.resolveURI("/base-dir"), "app1")) // with compression assert(s"file:/base-dir/app1.lzf" === - EventLoggingListener.getLogPath("/base-dir", "app1", Some("lzf"))) + EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), "app1", Some("lzf"))) // illegal characters in app ID assert(s"file:/base-dir/a-fine-mind_dollar_bills__1" === - EventLoggingListener.getLogPath("/base-dir", "a fine:mind$dollar{bills}.1")) + EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), + "a fine:mind$dollar{bills}.1")) // illegal characters in app ID with compression assert(s"file:/base-dir/a-fine-mind_dollar_bills__1.lz4" === - EventLoggingListener.getLogPath("/base-dir", "a fine:mind$dollar{bills}.1", Some("lz4"))) + EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), + "a fine:mind$dollar{bills}.1", Some("lz4"))) } /* ----------------- * @@ -137,7 +140,7 @@ class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with Bef val conf = getLoggingConf(testDirPath, compressionCodec) extraConf.foreach { case (k, v) => conf.set(k, v) } val logName = compressionCodec.map("test-" + _).getOrElse("test") - val eventLogger = new EventLoggingListener(logName, testDirPath.toUri().toString(), conf) + val eventLogger = new EventLoggingListener(logName, testDirPath.toUri(), conf) val listenerBus = new LiveListenerBus val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, 125L, "Mickey") @@ -173,12 +176,15 @@ class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with Bef * This runs a simple Spark job and asserts that the expected events are logged when expected. */ private def testApplicationEventLogging(compressionCodec: Option[String] = None) { + // Set defaultFS to something that would cause an exception, to make sure we don't run + // into SPARK-6688. val conf = getLoggingConf(testDirPath, compressionCodec) + .set("spark.hadoop.fs.defaultFS", "unsupported://example.com") val sc = new SparkContext("local-cluster[2,2,512]", "test", conf) assert(sc.eventLogger.isDefined) val eventLogger = sc.eventLogger.get val eventLogPath = eventLogger.logPath - val expectedLogDir = testDir.toURI().toString() + val expectedLogDir = testDir.toURI() assert(eventLogPath === EventLoggingListener.getLogPath( expectedLogDir, sc.applicationId, compressionCodec.map(CompressionCodec.getShortName))) @@ -262,7 +268,7 @@ class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with Bef object EventLoggingListenerSuite { /** Get a SparkConf with event logging enabled. 
*/ - def getLoggingConf(logDir: Path, compressionCodec: Option[String] = None) = { + def getLoggingConf(logDir: Path, compressionCodec: Option[String] = None): SparkConf = { val conf = new SparkConf conf.set("spark.eventLog.enabled", "true") conf.set("spark.eventLog.testing", "true") @@ -274,5 +280,5 @@ object EventLoggingListenerSuite { conf } - def getUniqueApplicationId = "test-" + System.currentTimeMillis + def getUniqueApplicationId: String = "test-" + System.currentTimeMillis } diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala index 6b75c98839e03..9b92f8de56759 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala @@ -24,7 +24,9 @@ import org.apache.spark.TaskContext /** * A Task implementation that fails to serialize. */ -private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) extends Task[Array[Byte]](stageId, 0) { +private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) + extends Task[Array[Byte]](stageId, 0) { + override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte] override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]() diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index c8c957856247a..cf97707946706 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -161,6 +161,31 @@ class OutputCommitCoordinatorSuite extends FunSuite with BeforeAndAfter { } assert(tempDir.list().size === 0) } + + test("Only authorized committer failures can clear the authorized committer lock (SPARK-6614)") { + val stage: Int = 1 + val partition: Long = 2 + val authorizedCommitter: Long = 3 + val nonAuthorizedCommitter: Long = 100 + outputCommitCoordinator.stageStart(stage) + assert(outputCommitCoordinator.canCommit(stage, partition, attempt = authorizedCommitter)) + assert(!outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter)) + // The non-authorized committer fails + outputCommitCoordinator.taskCompleted( + stage, partition, attempt = nonAuthorizedCommitter, reason = TaskKilled) + // New tasks should still not be able to commit because the authorized committer has not failed + assert( + !outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 1)) + // The authorized committer now fails, clearing the lock + outputCommitCoordinator.taskCompleted( + stage, partition, attempt = authorizedCommitter, reason = TaskKilled) + // A new task should now be allowed to become the authorized committer + assert( + outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 2)) + // There can only be one authorized committer + assert( + !outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 3)) + } } /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 601694f57aad0..6de6d2fec622a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.io.{File, PrintWriter} +import java.net.URI import org.json4s.jackson.JsonMethods._ import org.scalatest.{BeforeAndAfter, FunSuite} @@ -145,7 +146,7 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter { * log the events. */ private class EventMonster(conf: SparkConf) - extends EventLoggingListener("test", "testdir", conf) { + extends EventLoggingListener("test", new URI("testdir"), conf) { override def start() { } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 627c9a4ddfffc..825c616c0c3e0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -85,7 +85,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers val stopperReturned = new Semaphore(0) class BlockingListener extends SparkListener { - override def onJobEnd(jobEnd: SparkListenerJobEnd) = { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { listenerStarted.release() listenerWait.acquire() drained = true @@ -206,8 +206,9 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers sc.addSparkListener(new StatsReportListener) // just to make sure some of the tasks take a noticeable amount of time val w = { i: Int => - if (i == 0) + if (i == 0) { Thread.sleep(100) + } i } @@ -247,12 +248,12 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers */ taskInfoMetrics.foreach { case (taskInfo, taskMetrics) => - taskMetrics.resultSize should be > (0l) + taskMetrics.resultSize should be > (0L) if (stageInfo.rddInfos.exists(info => info.name == d2.name || info.name == d3.name)) { taskMetrics.inputMetrics should not be ('defined) taskMetrics.outputMetrics should not be ('defined) taskMetrics.shuffleWriteMetrics should be ('defined) - taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0l) + taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0L) } if (stageInfo.rddInfos.exists(_.name == d4.name)) { taskMetrics.shuffleReadMetrics should be ('defined) @@ -260,7 +261,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers sm.totalBlocksFetched should be (128) sm.localBlocksFetched should be (128) sm.remoteBlocksFetched should be (0) - sm.remoteBytesRead should be (0l) + sm.remoteBytesRead should be (0L) } } } @@ -406,12 +407,12 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers val startedGettingResultTasks = new mutable.HashSet[Int]() val endedTasks = new mutable.HashSet[Int]() - override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { startedTasks += taskStart.taskInfo.index notify() } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { endedTasks += taskEnd.taskInfo.index notify() } @@ -425,7 +426,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers * A simple listener that throws an exception on job end. 
*/ private class BadListener extends SparkListener { - override def onJobEnd(jobEnd: SparkListenerJobEnd) = { throw new Exception } + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { throw new Exception } } } @@ -438,10 +439,10 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers */ private class BasicJobCounter extends SparkListener { var count = 0 - override def onJobEnd(job: SparkListenerJobEnd) = count += 1 + override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1 } private class ListenerThatAcceptsSparkConf(conf: SparkConf) extends SparkListener { var count = 0 - override def onJobEnd(job: SparkListenerJobEnd) = count += 1 + override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1 } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index add13f5b21765..ffa4381969b68 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.scheduler -import java.util.Properties - import org.scalatest.FunSuite import org.apache.spark._ @@ -27,7 +25,7 @@ class FakeSchedulerBackend extends SchedulerBackend { def start() {} def stop() {} def reviveOffers() {} - def defaultParallelism() = 1 + def defaultParallelism(): Int = 1 } class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Logging { @@ -115,7 +113,8 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin } val numFreeCores = 1 taskScheduler.setDAGScheduler(dagScheduler) - var taskSet = new TaskSet(Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) + val taskSet = new TaskSet( + Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) val multiCoreWorkerOffers = Seq(new WorkerOffer("executor0", "host0", taskCpus), new WorkerOffer("executor1", "host1", numFreeCores)) taskScheduler.submitTasks(taskSet) @@ -123,7 +122,8 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin assert(0 === taskDescriptions.length) // Now check that we can still submit tasks - // Even if one of the tasks has not-serializable tasks, the other task set should still be processed without error + // Even if one of the tasks has not-serializable tasks, the other task set should + // still be processed without error taskScheduler.submitTasks(taskSet) taskScheduler.submitTasks(FakeTask.createTaskSet(1)) taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 12330d8f63c40..6198cea46ddf8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.scheduler -import java.io.{ObjectInputStream, ObjectOutputStream, IOException} import java.util.Random import scala.collection.mutable.ArrayBuffer @@ -27,7 +26,7 @@ import org.scalatest.FunSuite import org.apache.spark._ import org.apache.spark.executor.TaskMetrics -import org.apache.spark.util.ManualClock +import org.apache.spark.util.{ManualClock, Utils} class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) 
extends DAGScheduler(sc) { @@ -67,7 +66,7 @@ object FakeRackUtil { hostToRack(host) = rack } - def getRackForHost(host: String) = { + def getRackForHost(host: String): Option[String] = { hostToRack.get(host) } } @@ -152,7 +151,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { private val conf = new SparkConf - val LOCALITY_WAIT = conf.getLong("spark.locality.wait", 3000) + val LOCALITY_WAIT_MS = conf.getTimeAsMs("spark.locality.wait", "3s") val MAX_TASK_FAILURES = 4 override def beforeEach() { @@ -240,7 +239,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) == None) - clock.advance(LOCALITY_WAIT) + clock.advance(LOCALITY_WAIT_MS) // Offer host1, exec1 again, at NODE_LOCAL level: the node local (task 2) should // get chosen before the noPref task assert(manager.resourceOffer("exec1", "host1", NODE_LOCAL).get.index == 2) @@ -251,7 +250,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // Offer host2, exec3 again, at NODE_LOCAL level: we should get noPref task // after failing to find a node_Local task assert(manager.resourceOffer("exec2", "host2", NODE_LOCAL) == None) - clock.advance(LOCALITY_WAIT) + clock.advance(LOCALITY_WAIT_MS) assert(manager.resourceOffer("exec2", "host2", NO_PREF).get.index == 3) } @@ -292,7 +291,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // Offer host1 again: nothing should get chosen assert(manager.resourceOffer("exec1", "host1", ANY) === None) - clock.advance(LOCALITY_WAIT) + clock.advance(LOCALITY_WAIT_MS) // Offer host1 again: second task (on host2) should get chosen assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 1) @@ -306,7 +305,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // Now that we've launched a local task, we should no longer launch the task for host3 assert(manager.resourceOffer("exec2", "host2", ANY) === None) - clock.advance(LOCALITY_WAIT) + clock.advance(LOCALITY_WAIT_MS) // After another delay, we can go ahead and launch that task non-locally assert(manager.resourceOffer("exec2", "host2", ANY).get.index === 3) @@ -327,8 +326,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // First offer host1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) - // After this, nothing should get chosen, because we have separated tasks with unavailable preference - // from the noPrefPendingTasks + // After this, nothing should get chosen, because we have separated tasks with unavailable + // preference from the noPrefPendingTasks assert(manager.resourceOffer("exec1", "host1", ANY) === None) // Now mark host2 as dead @@ -338,7 +337,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // nothing should be chosen assert(manager.resourceOffer("exec1", "host1", ANY) === None) - clock.advance(LOCALITY_WAIT * 2) + clock.advance(LOCALITY_WAIT_MS * 2) // task 1 and 2 would be scheduled as nonLocal task assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 1) @@ -499,7 +498,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { sched.addExecutor("execC", "host2") manager.executorAdded() // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL and ANY - 
assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY))) + assert(manager.myLocalityLevels.sameElements( + Array(PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY))) // test if the valid locality is recomputed when the executor is lost sched.removeExecutor("execC") manager.executorLost("execC", "host2") @@ -527,7 +527,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY))) // Set allowed locality to ANY - clock.advance(LOCALITY_WAIT * 3) + clock.advance(LOCALITY_WAIT_MS * 3) // Offer host3 // No task is scheduled if we restrict locality to RACK_LOCAL assert(manager.resourceOffer("execC", "host3", RACK_LOCAL) === None) @@ -569,7 +569,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) - val taskSet = new TaskSet(Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) + val taskSet = new TaskSet( + Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) intercept[TaskNotSerializableException] { @@ -582,7 +583,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { val conf = new SparkConf().set("spark.driver.maxResultSize", "2m") sc = new SparkContext("local", "test", conf) - def genBytes(size: Int) = { (x: Int) => + def genBytes(size: Int): (Int) => Array[Byte] = { (x: Int) => val bytes = Array.ofDim[Byte](size) scala.util.Random.nextBytes(bytes) bytes @@ -605,7 +606,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("speculative and noPref task should be scheduled after node-local") { sc = new SparkContext("local", "test") - val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) + val sched = new FakeTaskScheduler( + sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) val taskSet = FakeTask.createTaskSet(4, Seq(TaskLocation("host1", "execA")), Seq(TaskLocation("host2"), TaskLocation("host1")), @@ -619,19 +621,21 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.resourceOffer("execA", "host1", NO_PREF).get.index == 1) manager.speculatableTasks += 1 - clock.advance(LOCALITY_WAIT) + clock.advance(LOCALITY_WAIT_MS) // schedule the nonPref task assert(manager.resourceOffer("execA", "host1", NO_PREF).get.index === 2) // schedule the speculative task assert(manager.resourceOffer("execB", "host2", NO_PREF).get.index === 1) - clock.advance(LOCALITY_WAIT * 3) + clock.advance(LOCALITY_WAIT_MS * 3) // schedule non-local tasks assert(manager.resourceOffer("execB", "host2", ANY).get.index === 3) } - test("node-local tasks should be scheduled right away when there are only node-local and no-preference tasks") { + test("node-local tasks should be scheduled right away " + + "when there are only node-local and no-preference tasks") { sc = new SparkContext("local", "test") - val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) + val sched = new FakeTaskScheduler( + sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) val taskSet = FakeTask.createTaskSet(4, Seq(TaskLocation("host1")), Seq(TaskLocation("host2")), @@ -650,7 +654,8 @@ class 
TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.resourceOffer("execA", "host3", NO_PREF).get.index === 2) } - test("SPARK-4939: node-local tasks should be scheduled right after process-local tasks finished") { + test("SPARK-4939: node-local tasks should be scheduled right after process-local tasks finished") + { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) val taskSet = FakeTask.createTaskSet(4, @@ -710,13 +715,13 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL and ANY assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) assert(manager.resourceOffer("execA", "host1", ANY) !== None) - clock.advance(LOCALITY_WAIT * 4) + clock.advance(LOCALITY_WAIT_MS * 4) assert(manager.resourceOffer("execB.2", "host2", ANY) !== None) sched.removeExecutor("execA") sched.removeExecutor("execB.2") manager.executorLost("execA", "host1") manager.executorLost("execB.2", "host2") - clock.advance(LOCALITY_WAIT * 4) + clock.advance(LOCALITY_WAIT_MS * 4) sched.addExecutor("execC", "host3") manager.executorAdded() // Prior to the fix, this line resulted in an ArrayIndexOutOfBoundsException: diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index f1a4380d349b3..a311512e82c5e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -63,16 +63,18 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Mo // uri is null. val executorInfo = mesosSchedulerBackend.createExecutorInfo("test-id") - assert(executorInfo.getCommand.getValue === s" /mesos-home/bin/spark-class ${classOf[MesosExecutorBackend].getName}") + assert(executorInfo.getCommand.getValue === + s" /mesos-home/bin/spark-class ${classOf[MesosExecutorBackend].getName}") // uri exists. 
conf.set("spark.executor.uri", "hdfs:///test-app-1.0.0.tgz") val executorInfo1 = mesosSchedulerBackend.createExecutorInfo("test-id") - assert(executorInfo1.getCommand.getValue === s"cd test-app-1*; ./bin/spark-class ${classOf[MesosExecutorBackend].getName}") + assert(executorInfo1.getCommand.getValue === + s"cd test-app-1*; ./bin/spark-class ${classOf[MesosExecutorBackend].getName}") } test("mesos resource offers result in launching tasks") { - def createOffer(id: Int, mem: Int, cpu: Int) = { + def createOffer(id: Int, mem: Int, cpu: Int): Offer = { val builder = Offer.newBuilder() builder.addResourcesBuilder() .setName("mem") @@ -82,8 +84,10 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Mo .setName("cpus") .setType(Value.Type.SCALAR) .setScalar(Scalar.newBuilder().setValue(cpu)) - builder.setId(OfferID.newBuilder().setValue(s"o${id.toString}").build()).setFrameworkId(FrameworkID.newBuilder().setValue("f1")) - .setSlaveId(SlaveID.newBuilder().setValue(s"s${id.toString}")).setHostname(s"host${id.toString}").build() + builder.setId(OfferID.newBuilder().setValue(s"o${id.toString}").build()) + .setFrameworkId(FrameworkID.newBuilder().setValue("f1")) + .setSlaveId(SlaveID.newBuilder().setValue(s"s${id.toString}")) + .setHostname(s"host${id.toString}").build() } val driver = mock[SchedulerDriver] diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 6198df84fab3d..b070a54aa989b 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -106,7 +106,9 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { check(mutable.HashMap(1 -> "one", 2 -> "two")) check(mutable.HashMap("one" -> 1, "two" -> 2)) check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) - check(List(mutable.HashMap("one" -> 1, "two" -> 2),mutable.HashMap(1->"one",2->"two",3->"three"))) + check(List( + mutable.HashMap("one" -> 1, "two" -> 2), + mutable.HashMap(1->"one",2->"two",3->"three"))) } test("ranges") { @@ -169,7 +171,10 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { test("kryo with collect") { val control = 1 :: 2 :: Nil - val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)).collect().map(_.x) + val result = sc.parallelize(control, 2) + .map(new ClassWithoutNoArgConstructor(_)) + .collect() + .map(_.x) assert(control === result.toSeq) } @@ -237,7 +242,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { // Set a special, broken ClassLoader and make sure we get an exception on deserialization ser.setDefaultClassLoader(new ClassLoader() { - override def loadClass(name: String) = throw new UnsupportedOperationException + override def loadClass(name: String): Class[_] = throw new UnsupportedOperationException }) intercept[UnsupportedOperationException] { ser.newInstance().deserialize[ClassLoaderTestingObject](bytes) @@ -287,14 +292,14 @@ object KryoTest { class ClassWithNoArgConstructor { var x: Int = 0 - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case c: ClassWithNoArgConstructor => x == c.x case _ => false } } class ClassWithoutNoArgConstructor(val x: Int) { - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case c: 
ClassWithoutNoArgConstructor => x == c.x case _ => false } diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala index d037e2c19a64d..433fd6bb4a11d 100644 --- a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala @@ -24,14 +24,16 @@ import org.apache.spark.rdd.RDD /* A trivial (but unserializable) container for trivial functions */ class UnserializableClass { - def op[T](x: T) = x.toString + def op[T](x: T): String = x.toString - def pred[T](x: T) = x.toString.length % 2 == 0 + def pred[T](x: T): Boolean = x.toString.length % 2 == 0 } class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContext { - def fixture = (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass) + def fixture: (RDD[String], UnserializableClass) = { + (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass) + } test("throws expected serialization exceptions on actions") { val (data, uc) = fixture diff --git a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala index 0ade1bab18d7e..963264cef3a71 100644 --- a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala +++ b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala @@ -27,7 +27,7 @@ import scala.reflect.ClassTag * A serializer implementation that always return a single element in a deserialization stream. */ class TestSerializer extends Serializer { - override def newInstance() = new TestSerializerInstance + override def newInstance(): TestSerializerInstance = new TestSerializerInstance } @@ -36,7 +36,8 @@ class TestSerializerInstance extends SerializerInstance { override def serializeStream(s: OutputStream): SerializationStream = ??? - override def deserializeStream(s: InputStream) = new TestDeserializationStream + override def deserializeStream(s: InputStream): TestDeserializationStream = + new TestDeserializationStream override def deserialize[T: ClassTag](bytes: ByteBuffer): T = ??? diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala index 6790388f96603..7d76435cd75e7 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala @@ -54,7 +54,7 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { sc = new SparkContext("local", "test", conf) val shuffleBlockManager = - SparkEnv.get.shuffleManager.shuffleBlockManager.asInstanceOf[FileShuffleBlockManager] + SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[FileShuffleBlockManager] val shuffle1 = shuffleBlockManager.forMapTask(1, 1, 1, new JavaSerializer(conf), new ShuffleWriteMetrics) @@ -85,8 +85,8 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { // Now comes the test : // Write to shuffle 3; and close it, but before registering it, check if the file lengths for // previous task (forof shuffle1) is the same as 'segments'. 
Earlier, we were inferring length - // of block based on remaining data in file : which could mess things up when there is concurrent read - // and writes happening to the same shuffle group. + // of block based on remaining data in file : which could mess things up when there is + // concurrent read and writes happening to the same shuffle group. val shuffle3 = shuffleBlockManager.forMapTask(1, 3, 1, new JavaSerializer(testConf), new ShuffleWriteMetrics) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index c2903c8597997..ffa5162a31841 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -22,11 +22,11 @@ import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps -import akka.actor.{ActorSystem, Props} import org.mockito.Mockito.{mock, when} -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} import org.scalatest.concurrent.Eventually._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} import org.apache.spark.network.BlockTransferService import org.apache.spark.network.nio.NioBlockTransferService @@ -34,13 +34,12 @@ import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage.StorageLevel._ -import org.apache.spark.util.{AkkaUtils, SizeEstimator} /** Testsuite that tests block replication in BlockManager */ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAndAfter { private val conf = new SparkConf(false) - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null val securityMgr = new SecurityManager(conf) val mapOutputTracker = new MapOutputTrackerMaster(conf) @@ -61,7 +60,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) - val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + val store = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) store.initialize("app-id") allStores += store @@ -69,32 +68,29 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd } before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem + rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) conf.set("spark.authenticate", "false") - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.address.port.toString) conf.set("spark.storage.unrollFraction", "0.4") conf.set("spark.storage.unrollMemoryThreshold", "512") // to make a replication attempt to inactive store fail fast - conf.set("spark.core.connection.ack.wait.timeout", "1") + conf.set("spark.core.connection.ack.wait.timeout", "1s") // to make cached peers refresh frequently conf.set("spark.storage.cachedPeersTtl", "10") - master = new 
BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf, true) + master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) allStores.clear() } after { allStores.foreach { _.stop() } allStores.clear() - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null master = null } @@ -262,7 +258,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd val failableTransfer = mock(classOf[BlockTransferService]) // this wont actually work when(failableTransfer.hostName).thenReturn("some-hostname") when(failableTransfer.port).thenReturn(1000) - val failableStore = new BlockManager("failable-store", actorSystem, master, serializer, + val failableStore = new BlockManager("failable-store", rpcEnv, master, serializer, 10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0) failableStore.initialize("app-id") allStores += failableStore // so that this gets stopped after test diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index ecd1cba5b5abe..545722b050ee8 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -19,24 +19,18 @@ package org.apache.spark.storage import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays -import java.util.concurrent.TimeUnit import scala.collection.mutable.ArrayBuffer -import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps -import akka.actor._ -import akka.pattern.ask -import akka.util.Timeout - import org.mockito.Mockito.{mock, when} - import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} import org.apache.spark.executor.DataReadMethod import org.apache.spark.network.nio.NioBlockTransferService @@ -53,7 +47,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach private val conf = new SparkConf(false) var store: BlockManager = null var store2: BlockManager = null - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null conf.set("spark.authenticate", "false") val securityMgr = new SecurityManager(conf) @@ -66,34 +60,31 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach // Implicitly convert strings to BlockIds for test clarity. 
implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) - def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId) + def rdd(rddId: Int, splitId: Int): RDDBlockId = RDDBlockId(rddId, splitId) private def makeBlockManager( maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) - val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + val manager = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) manager.initialize("app-id") manager } override def beforeEach(): Unit = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem + rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case System.setProperty("os.arch", "amd64") conf.set("os.arch", "amd64") conf.set("spark.test.useCompressedOops", "true") - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.address.port.toString) conf.set("spark.storage.unrollFraction", "0.4") conf.set("spark.storage.unrollMemoryThreshold", "512") - master = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf, true) + master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() @@ -108,16 +99,18 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach store2.stop() store2 = null } - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null master = null } test("StorageLevel object caching") { val level1 = StorageLevel(false, false, false, false, 3) - val level2 = StorageLevel(false, false, false, false, 3) // this should return the same object as level1 - val level3 = StorageLevel(false, false, false, false, 2) // this should return a different object + // this should return the same object as level1 + val level2 = StorageLevel(false, false, false, false, 3) + // this should return a different object + val level3 = StorageLevel(false, false, false, false, 2) assert(level2 === level1, "level2 is not same as level1") assert(level2.eq(level1), "level2 is not the same object as level1") assert(level3 != level1, "level3 is same as level1") @@ -148,6 +141,12 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach assert(id2_.eq(id1), "Deserialized id2 is not the same object as original id1") } + test("BlockManagerId.isDriver() backwards-compatibility with legacy driver ids (SPARK-6716)") { + assert(BlockManagerId(SparkContext.DRIVER_IDENTIFIER, "XXX", 1).isDriver) + assert(BlockManagerId(SparkContext.LEGACY_DRIVER_IDENTIFIER, "XXX", 1).isDriver) + assert(!BlockManagerId("notADriverIdentifier", "XXX", 1).isDriver) + } + test("master + 1 manager interaction") { store = makeBlockManager(20000) val a1 = new Array[Byte](4000) @@ -357,10 +356,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not 
removed from master") - implicit val timeout = Timeout(30, TimeUnit.SECONDS) - val reregister = !Await.result( - master.driverActor ? BlockManagerHeartbeat(store.blockManagerId), - timeout.duration).asInstanceOf[Boolean] + val reregister = !master.driverEndpoint.askWithReply[Boolean]( + BlockManagerHeartbeat(store.blockManagerId)) assert(reregister == true) } @@ -785,7 +782,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach test("block store put failure") { // Use Java serializer so we can create an unserializable error. val transfer = new NioBlockTransferService(conf, securityMgr) - store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, actorSystem, master, + store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) @@ -807,7 +804,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach // Create a non-trivial (not all zeros) byte array var counter = 0.toByte - def incr = {counter = (counter + 1).toByte; counter;} + def incr: Byte = {counter = (counter + 1).toByte; counter;} val bytes = Array.fill[Byte](1000)(incr) val byteBuffer = ByteBuffer.wrap(bytes) @@ -961,8 +958,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach store.putIterator("list3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) // getLocations and getBlockStatus should yield the same locations - assert(store.master.getMatchingBlockIds(_.toString.contains("list"), askSlaves = false).size === 3) - assert(store.master.getMatchingBlockIds(_.toString.contains("list1"), askSlaves = false).size === 1) + assert(store.master.getMatchingBlockIds(_.toString.contains("list"), askSlaves = false).size + === 3) + assert(store.master.getMatchingBlockIds(_.toString.contains("list1"), askSlaves = false).size + === 1) // insert some more blocks store.putIterator("newlist1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) @@ -970,8 +969,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach store.putIterator("newlist3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) // getLocations and getBlockStatus should yield the same locations - assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = false).size === 1) - assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = true).size === 3) + assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = false).size + === 1) + assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = true).size + === 3) val blockIds = Seq(RDDBlockId(1, 0), RDDBlockId(1, 1), RDDBlockId(2, 0)) blockIds.foreach { blockId => @@ -1095,8 +1096,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach val memoryStore = store.memoryStore val smallList = List.fill(40)(new Array[Byte](100)) val bigList = List.fill(40)(new Array[Byte](1000)) - def smallIterator = smallList.iterator.asInstanceOf[Iterator[Any]] - def bigIterator = bigList.iterator.asInstanceOf[Iterator[Any]] + def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] + def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] assert(memoryStore.currentUnrollMemoryForThisThread === 0) // Unroll with plenty of space. This should succeed and cache both blocks. 
@@ -1149,8 +1150,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach val diskStore = store.diskStore val smallList = List.fill(40)(new Array[Byte](100)) val bigList = List.fill(40)(new Array[Byte](1000)) - def smallIterator = smallList.iterator.asInstanceOf[Iterator[Any]] - def bigIterator = bigList.iterator.asInstanceOf[Iterator[Any]] + def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] + def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] assert(memoryStore.currentUnrollMemoryForThisThread === 0) store.putIterator("b1", smallIterator, memAndDisk) @@ -1192,7 +1193,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach val memOnly = StorageLevel.MEMORY_ONLY val memoryStore = store.memoryStore val smallList = List.fill(40)(new Array[Byte](100)) - def smallIterator = smallList.iterator.asInstanceOf[Iterator[Any]] + def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] assert(memoryStore.currentUnrollMemoryForThisThread === 0) // All unroll memory used is released because unrollSafely returned an array diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index 82a82e23eecf2..b47157f8331cc 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -47,7 +47,7 @@ class LocalDirsSuite extends FunSuite with BeforeAndAfter { assert(!new File("/NONEXISTENT_DIR").exists()) // SPARK_LOCAL_DIRS is a valid directory: class MySparkConf extends SparkConf(false) { - override def getenv(name: String) = { + override def getenv(name: String): String = { if (name == "SPARK_LOCAL_DIRS") System.getProperty("java.io.tmpdir") else super.getenv(name) } diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 0d155982a8c54..1cb594633f331 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -137,7 +137,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before new SparkContext(conf) } - def hasKillLink = find(className("kill-link")).isDefined + def hasKillLink: Boolean = find(className("kill-link")).isDefined def runSlowJob(sc: SparkContext) { sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index c0c28cb60e21d..21d8267114133 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -269,7 +269,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val taskType = Utils.getFormattedClassName(new ShuffleMapTask(0)) val execId = "exe-1" - def makeTaskMetrics(base: Int) = { + def makeTaskMetrics(base: Int): TaskMetrics = { val taskMetrics = new TaskMetrics() val shuffleReadMetrics = new ShuffleReadMetrics() val shuffleWriteMetrics = new ShuffleWriteMetrics() @@ -291,7 +291,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskMetrics } - def makeTaskInfo(taskId: Long, finishTime: Int = 0) = { + def makeTaskInfo(taskId: 
Long, finishTime: Int = 0): TaskInfo = { val taskInfo = new TaskInfo(taskId, 0, 1, 0L, execId, "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = finishTime diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index e1bc1379b5d80..3744e479d2f05 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -107,7 +107,8 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { val myRddInfo0 = rddInfo0 val myRddInfo1 = rddInfo1 val myRddInfo2 = rddInfo2 - val stageInfo0 = new StageInfo(0, 0, "0", 100, Seq(myRddInfo0, myRddInfo1, myRddInfo2), "details") + val stageInfo0 = new StageInfo( + 0, 0, "0", 100, Seq(myRddInfo0, myRddInfo1, myRddInfo2), "details") bus.postToAll(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) assert(storageListener._rddInfoMap.size === 3) diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 6250d50fb7036..bec79fc4dc8f7 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -19,14 +19,11 @@ package org.apache.spark.util import java.util.concurrent.TimeoutException -import scala.concurrent.Await -import scala.util.{Failure, Try} - -import akka.actor._ - +import akka.actor.ActorNotFound import org.scalatest.FunSuite import org.apache.spark._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage.BlockManagerId import org.apache.spark.SSLSampleConfigs._ @@ -39,39 +36,37 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro test("remote fetch security bad password") { val conf = new SparkConf + conf.set("spark.rpc", "akka") conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val badconf = new SparkConf + badconf.set("spark.rpc", "akka") badconf.set("spark.authenticate", "true") badconf.set("spark.authenticate.secret", "bad") val securityManagerBad = new SecurityManager(badconf) assert(securityManagerBad.isAuthenticationEnabled() === true) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = conf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, conf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), 
"spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) } - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch security off") { @@ -81,28 +76,24 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === false) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val badconf = new SparkConf badconf.set("spark.authenticate", "false") badconf.set("spark.authenticate.secret", "good") val securityManagerBad = new SecurityManager(badconf) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = badconf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, badconf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) assert(securityManagerBad.isAuthenticationEnabled() === false) @@ -120,8 +111,8 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(slaveTracker.getServerStatuses(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000))) - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch security pass") { @@ -131,15 +122,14 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, 
masterTracker, conf)) val goodconf = new SparkConf goodconf.set("spark.authenticate", "true") @@ -148,13 +138,10 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(securityManagerGood.isAuthenticationEnabled() === true) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = goodconf, securityManager = securityManagerGood) + val slaveRpcEnv =RpcEnv.create("spark-slave", hostname, 0, goodconf, securityManagerGood) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() @@ -170,47 +157,45 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(slaveTracker.getServerStatuses(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000))) - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch security off client") { val conf = new SparkConf + conf.set("spark.rpc", "akka") conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val badconf = new SparkConf + badconf.set("spark.rpc", "akka") badconf.set("spark.authenticate", "false") badconf.set("spark.authenticate.secret", "bad") val securityManagerBad = new SecurityManager(badconf) assert(securityManagerBad.isAuthenticationEnabled() === false) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = badconf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, badconf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) } - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch ssl on") { @@ -218,26 +203,22 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val 
securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === false) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val slaveConf = sparkSSLConfig() val securityManagerBad = new SecurityManager(slaveConf) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = slaveConf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slaves", hostname, 0, slaveConf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) assert(securityManagerBad.isAuthenticationEnabled() === false) @@ -255,8 +236,8 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(slaveTracker.getServerStatuses(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000))) - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } @@ -267,28 +248,24 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val slaveConf = sparkSSLConfig() slaveConf.set("spark.authenticate", "true") slaveConf.set("spark.authenticate.secret", "good") val securityManagerBad = new SecurityManager(slaveConf) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = slaveConf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - 
slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) assert(securityManagerBad.isAuthenticationEnabled() === true) @@ -305,45 +282,43 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(slaveTracker.getServerStatuses(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000))) - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch ssl on and security enabled - bad credentials") { val conf = sparkSSLConfig() + conf.set("spark.rpc", "akka") conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val slaveConf = sparkSSLConfig() + slaveConf.set("spark.rpc", "akka") slaveConf.set("spark.authenticate", "true") slaveConf.set("spark.authenticate.secret", "bad") val securityManagerBad = new SecurityManager(slaveConf) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = slaveConf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) } - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } @@ -352,35 +327,30 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === false) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, 
conf)) val slaveConf = sparkSSLConfig() val securityManagerBad = new SecurityManager(slaveConf) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = slaveConf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - val result = Try(Await.result(selection.resolveOne(timeout * 2), timeout)) - - result match { - case Failure(ex: ActorNotFound) => - case Failure(ex: TimeoutException) => - case r => fail(s"$r is neither Failure(ActorNotFound) nor Failure(TimeoutException)") + try { + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + fail("should receive either ActorNotFound or TimeoutException") + } catch { + case e: ActorNotFound => + case e: TimeoutException => } - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } } diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 054ef54e746a5..c47162779bbba 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -83,7 +83,7 @@ object TestObject { class TestClass extends Serializable { var x = 5 - def getX = x + def getX: Int = x def run(): Int = { var nonSer = new NonSerializable @@ -95,7 +95,7 @@ class TestClass extends Serializable { } class TestClassWithoutDefaultConstructor(x: Int) extends Serializable { - def getX = x + def getX: Int = x def run(): Int = { var nonSer = new NonSerializable @@ -164,7 +164,7 @@ object TestObjectWithNesting { } class TestClassWithNesting(val y: Int) extends Serializable { - def getY = y + def getY: Int = y def run(): Int = { var nonSer = new NonSerializable diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala index 1026cb2aa7cae..47b535206c949 100644 --- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala @@ -203,4 +203,76 @@ class EventLoopSuite extends FunSuite with Timeouts { assert(!eventLoop.isActive) } } + + test("EventLoop: stop() in onStart should call onStop") { + @volatile var onStopCalled: Boolean = false + val eventLoop = new EventLoop[Int]("test") { + + override def onStart(): Unit = { + stop() + } + + override def onReceive(event: Int): Unit = { + } + + override def onError(e: Throwable): Unit = { + } + + override def onStop(): Unit = { + onStopCalled = true + } + } + eventLoop.start() + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) + } + assert(onStopCalled) + } + + test("EventLoop: stop() in onReceive should call onStop") { + @volatile var onStopCalled: Boolean = false + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + stop() + } + + override def onError(e: Throwable): Unit = { + } + + override def onStop(): Unit = { + onStopCalled = true + } + } + eventLoop.start() + eventLoop.post(1) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) + } + assert(onStopCalled) + } + + 
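The EventLoopSuite tests added here (and the onError variant that follows) pin down a small lifecycle contract: onStart runs before the loop begins, onReceive failures are routed to onError, and calling stop() from any of these hooks must still invoke onStop exactly once and leave the loop inactive. As a rough, self-contained sketch of that contract only — not Spark's actual EventLoop, and with the invented name SimpleEventLoop — the pattern looks roughly like this:

```
import java.util.concurrent.LinkedBlockingDeque
import java.util.concurrent.atomic.AtomicBoolean

import scala.util.control.NonFatal

// Minimal illustration of the event-loop contract exercised by the tests above:
// a daemon thread drains a queue, failures in onReceive go to onError, and
// stop() is idempotent and always ends in onStop.
abstract class SimpleEventLoop[E](name: String) {
  private val eventQueue = new LinkedBlockingDeque[E]()
  private val stopped = new AtomicBoolean(false)

  private val eventThread = new Thread(name) {
    setDaemon(true)
    override def run(): Unit = {
      try {
        while (!stopped.get) {
          val event = eventQueue.take()
          try {
            onReceive(event)
          } catch {
            case NonFatal(t) => onError(t)
          }
        }
      } catch {
        case _: InterruptedException => // stop() interrupted a blocked take(); exit quietly
      }
    }
  }

  /** Runs onStart on the caller's thread, then starts draining posted events. */
  def start(): Unit = {
    onStart()
    eventThread.start()
  }

  /** Idempotent; safe to call from onStart, onReceive or onError, as the tests do. */
  def stop(): Unit = {
    if (stopped.compareAndSet(false, true)) {
      eventThread.interrupt()
      onStop()
    }
  }

  def post(event: E): Unit = eventQueue.put(event)

  def isActive: Boolean = eventThread.isAlive

  protected def onStart(): Unit = {}
  protected def onStop(): Unit = {}
  protected def onReceive(event: E): Unit
  protected def onError(e: Throwable): Unit
}
```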
test("EventLoop: stop() in onError should call onStop") { + @volatile var onStopCalled: Boolean = false + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + throw new RuntimeException("Oops") + } + + override def onError(e: Throwable): Unit = { + stop() + } + + override def onStop(): Unit = { + onStopCalled = true + } + } + eventLoop.start() + eventLoop.post(1) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) + } + assert(onStopCalled) + } } diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 43b6a405cb68c..c05317534cddf 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -109,7 +109,8 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { // verify whether the earliest file has been deleted val rolledOverFiles = allGeneratedFiles.filter { _ != testFile.toString }.toArray.sorted - logInfo(s"All rolled over files generated:${rolledOverFiles.size}\n" + rolledOverFiles.mkString("\n")) + logInfo(s"All rolled over files generated:${rolledOverFiles.size}\n" + + rolledOverFiles.mkString("\n")) assert(rolledOverFiles.size > 2) val earliestRolledOverFile = rolledOverFiles.head val existingRolledOverFiles = RollingFileAppender.getSortedRolledOverFiles( @@ -135,7 +136,7 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { val testOutputStream = new PipedOutputStream() val testInputStream = new PipedInputStream(testOutputStream) val appender = FileAppender(testInputStream, testFile, conf) - //assert(appender.getClass === classTag[ExpectedAppender].getClass) + // assert(appender.getClass === classTag[ExpectedAppender].getClass) assert(appender.getClass.getSimpleName === classTag[ExpectedAppender].runtimeClass.getSimpleName) if (appender.isInstanceOf[RollingFileAppender]) { @@ -153,9 +154,11 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { import RollingFileAppender._ - def rollingStrategy(strategy: String) = Seq(STRATEGY_PROPERTY -> strategy) - def rollingSize(size: String) = Seq(SIZE_PROPERTY -> size) - def rollingInterval(interval: String) = Seq(INTERVAL_PROPERTY -> interval) + def rollingStrategy(strategy: String): Seq[(String, String)] = + Seq(STRATEGY_PROPERTY -> strategy) + def rollingSize(size: String): Seq[(String, String)] = Seq(SIZE_PROPERTY -> size) + def rollingInterval(interval: String): Seq[(String, String)] = + Seq(INTERVAL_PROPERTY -> interval) val msInDay = 24 * 60 * 60 * 1000L val msInHour = 60 * 60 * 1000L diff --git a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala index 72e81f3f1a884..403dcb03bd6e5 100644 --- a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala @@ -71,7 +71,7 @@ class NextIteratorSuite extends FunSuite with Matchers { class StubIterator(ints: Buffer[Int]) extends NextIterator[Int] { var closeCalled = 0 - override def getNext() = { + override def getNext(): Int = { if (ints.size == 0) { finished = true 0 diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 7424c2e91d4f2..67a9f75ff2187 100644 --- 
a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -98,8 +98,10 @@ class SizeEstimatorSuite // If an array contains the *same* element many times, we should only count it once. val d1 = new DummyClass1 - assertResult(72)(SizeEstimator.estimate(Array.fill(10)(d1))) // 10 pointers plus 8-byte object - assertResult(432)(SizeEstimator.estimate(Array.fill(100)(d1))) // 100 pointers plus 8-byte object + // 10 pointers plus 8-byte object + assertResult(72)(SizeEstimator.estimate(Array.fill(10)(d1))) + // 100 pointers plus 8-byte object + assertResult(432)(SizeEstimator.estimate(Array.fill(100)(d1))) // Same thing with huge array containing the same element many times. Note that this won't // return exactly 4032 because it can't tell that *all* the elements will equal the first diff --git a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala index c1c605cdb487b..8b72fe665c214 100644 --- a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala @@ -63,7 +63,7 @@ class TimeStampedHashMapSuite extends FunSuite { assert(map1.getTimestamp("k1").get < threshTime1) assert(map1.getTimestamp("k2").isDefined) assert(map1.getTimestamp("k2").get >= threshTime1) - map1.clearOldValues(threshTime1) //should only clear k1 + map1.clearOldValues(threshTime1) // should only clear k1 assert(map1.get("k1") === None) assert(map1.get("k2").isDefined) } @@ -93,7 +93,7 @@ class TimeStampedHashMapSuite extends FunSuite { assert(map1.getTimestamp("k1").get < threshTime1) assert(map1.getTimestamp("k2").isDefined) assert(map1.getTimestamp("k2").get >= threshTime1) - map1.clearOldValues(threshTime1) //should only clear k1 + map1.clearOldValues(threshTime1) // should only clear k1 assert(map1.get("k1") === None) assert(map1.get("k2").isDefined) } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 5d93086082189..fb97e650ff95c 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -23,6 +23,7 @@ import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStr import java.net.{BindException, ServerSocket, URI} import java.nio.{ByteBuffer, ByteOrder} import java.text.DecimalFormatSymbols +import java.util.concurrent.TimeUnit import java.util.Locale import com.google.common.base.Charsets.UTF_8 @@ -35,7 +36,50 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkConf class UtilsSuite extends FunSuite with ResetSystemProperties { + + test("timeConversion") { + // Test -1 + assert(Utils.timeStringAsSeconds("-1") === -1) + + // Test zero + assert(Utils.timeStringAsSeconds("0") === 0) + + assert(Utils.timeStringAsSeconds("1") === 1) + assert(Utils.timeStringAsSeconds("1s") === 1) + assert(Utils.timeStringAsSeconds("1000ms") === 1) + assert(Utils.timeStringAsSeconds("1000000us") === 1) + assert(Utils.timeStringAsSeconds("1m") === TimeUnit.MINUTES.toSeconds(1)) + assert(Utils.timeStringAsSeconds("1min") === TimeUnit.MINUTES.toSeconds(1)) + assert(Utils.timeStringAsSeconds("1h") === TimeUnit.HOURS.toSeconds(1)) + assert(Utils.timeStringAsSeconds("1d") === TimeUnit.DAYS.toSeconds(1)) + + assert(Utils.timeStringAsMs("1") === 1) + assert(Utils.timeStringAsMs("1ms") === 
1) + assert(Utils.timeStringAsMs("1000us") === 1) + assert(Utils.timeStringAsMs("1s") === TimeUnit.SECONDS.toMillis(1)) + assert(Utils.timeStringAsMs("1m") === TimeUnit.MINUTES.toMillis(1)) + assert(Utils.timeStringAsMs("1min") === TimeUnit.MINUTES.toMillis(1)) + assert(Utils.timeStringAsMs("1h") === TimeUnit.HOURS.toMillis(1)) + assert(Utils.timeStringAsMs("1d") === TimeUnit.DAYS.toMillis(1)) + + // Test invalid strings + intercept[NumberFormatException] { + Utils.timeStringAsMs("This breaks 600s") + } + + intercept[NumberFormatException] { + Utils.timeStringAsMs("This breaks 600ds") + } + intercept[NumberFormatException] { + Utils.timeStringAsMs("600s This breaks") + } + + intercept[NumberFormatException] { + Utils.timeStringAsMs("This 123s breaks") + } + } + test("bytesToString") { assert(Utils.bytesToString(10) === "10.0 B") assert(Utils.bytesToString(1500) === "1500.0 B") @@ -106,7 +150,7 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { val second = 1000 val minute = second * 60 val hour = minute * 60 - def str = Utils.msDurationToString(_) + def str: (Long) => String = Utils.msDurationToString(_) val sep = new DecimalFormatSymbols(Locale.getDefault()).getDecimalSeparator() @@ -199,7 +243,8 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { test("doesDirectoryContainFilesNewerThan") { // create some temporary directories and files val parent: File = Utils.createTempDir() - val child1: File = Utils.createTempDir(parent.getCanonicalPath) // The parent directory has two child directories + // The parent directory has two child directories + val child1: File = Utils.createTempDir(parent.getCanonicalPath) val child2: File = Utils.createTempDir(parent.getCanonicalPath) val child3: File = Utils.createTempDir(child1.getCanonicalPath) // set the last modified time of child1 to 30 secs old diff --git a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala index 794a55d61750b..ce2968728a996 100644 --- a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.FunSuite @deprecated("suppress compile time deprecation warning", "1.0.0") class VectorSuite extends FunSuite { - def verifyVector(vector: Vector, expectedLength: Int) = { + def verifyVector(vector: Vector, expectedLength: Int): Unit = { assert(vector.length == expectedLength) assert(vector.elements.min > 0.0) assert(vector.elements.max < 1.0) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 48f79ea651018..dff8f3ddc816f 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -185,7 +185,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { // reduceByKey val rdd = sc.parallelize(1 to 10).map(i => (i%2, 1)) - val result1 = rdd.reduceByKey(_+_).collect() + val result1 = rdd.reduceByKey(_ + _).collect() assert(result1.toSet === Set[(Int, Int)]((0, 5), (1, 5))) // groupByKey diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 72d96798b1141..9ff067f86af44 100644 --- 
a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -553,10 +553,10 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) - def createCombiner(i: String) = ArrayBuffer[String](i) - def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i - def mergeCombiners(buffer1: ArrayBuffer[String], buffer2: ArrayBuffer[String]) = - buffer1 ++= buffer2 + def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) + def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i + def mergeCombiners(buffer1: ArrayBuffer[String], buffer2: ArrayBuffer[String]) + : ArrayBuffer[String] = buffer1 ++= buffer2 val agg = new Aggregator[String, String, ArrayBuffer[String]]( createCombiner _, mergeValue _, mergeCombiners _) @@ -633,14 +633,17 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) - def createCombiner(i: Int) = ArrayBuffer[Int](i) - def mergeValue(buffer: ArrayBuffer[Int], i: Int) = buffer += i - def mergeCombiners(buf1: ArrayBuffer[Int], buf2: ArrayBuffer[Int]) = buf1 ++= buf2 + def createCombiner(i: Int): ArrayBuffer[Int] = ArrayBuffer[Int](i) + def mergeValue(buffer: ArrayBuffer[Int], i: Int): ArrayBuffer[Int] = buffer += i + def mergeCombiners(buf1: ArrayBuffer[Int], buf2: ArrayBuffer[Int]): ArrayBuffer[Int] = { + buf1 ++= buf2 + } val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners) val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), None, None, None) - sorter.insertAll((1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) + sorter.insertAll( + (1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) val it = sorter.iterator while (it.hasNext) { @@ -654,9 +657,10 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) - def createCombiner(i: String) = ArrayBuffer[String](i) - def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i - def mergeCombiners(buf1: ArrayBuffer[String], buf2: ArrayBuffer[String]) = buf1 ++= buf2 + def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) + def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i + def mergeCombiners(buf1: ArrayBuffer[String], buf2: ArrayBuffer[String]): ArrayBuffer[String] = + buf1 ++= buf2 val agg = new Aggregator[String, String, ArrayBuffer[String]]( createCombiner, mergeValue, mergeCombiners) @@ -720,7 +724,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe // Using wrongOrdering to show integer overflow introduced exception. 
val rand = new Random(100L) val wrongOrdering = new Ordering[String] { - override def compare(a: String, b: String) = { + override def compare(a: String, b: String): Int = { val h1 = if (a == null) 0 else a.hashCode() val h2 = if (b == null) 0 else b.hashCode() h1 - h2 @@ -742,9 +746,10 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe // Using aggregation and external spill to make sure ExternalSorter using // partitionKeyComparator. - def createCombiner(i: String) = ArrayBuffer(i) - def mergeValue(c: ArrayBuffer[String], i: String) = c += i - def mergeCombiners(c1: ArrayBuffer[String], c2: ArrayBuffer[String]) = c1 ++= c2 + def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer(i) + def mergeValue(c: ArrayBuffer[String], i: String): ArrayBuffer[String] = c += i + def mergeCombiners(c1: ArrayBuffer[String], c2: ArrayBuffer[String]): ArrayBuffer[String] = + c1 ++= c2 val agg = new Aggregator[String, String, ArrayBuffer[String]]( createCombiner, mergeValue, mergeCombiners) diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala index ef7178bcdf5c2..03f5f2d1b8528 100644 --- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala @@ -28,7 +28,7 @@ import scala.language.reflectiveCalls class XORShiftRandomSuite extends FunSuite with Matchers { - def fixture = new { + def fixture: Object {val seed: Long; val hundMil: Int; val xorRand: XORShiftRandom} = new { val seed = 1L val xorRand = new XORShiftRandom(seed) val hundMil = 1e8.toInt diff --git a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala index d888de929fdda..cc86ef45858c9 100644 --- a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala +++ b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala @@ -36,8 +36,10 @@ object SparkSqlExample { val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ import sqlContext._ - val people = sc.makeRDD(1 to 100, 10).map(x => Person(s"Name$x", x)) + + val people = sc.makeRDD(1 to 100, 10).map(x => Person(s"Name$x", x)).toDF() people.registerTempTable("people") val teenagers = sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") val teenagerNames = teenagers.map(t => "Name: " + t(0)).collect() diff --git a/dev/check-license b/dev/check-license index 39943f882b6ca..10740cfdc5242 100755 --- a/dev/check-license +++ b/dev/check-license @@ -24,29 +24,27 @@ acquire_rat_jar () { JAR="$rat_jar" - if [[ ! -f "$rat_jar" ]]; then - # Download rat launch jar if it hasn't been downloaded yet - if [ ! -f "$JAR" ]; then - # Download - printf "Attempting to fetch rat\n" - JAR_DL="${JAR}.part" - if [ $(command -v curl) ]; then - curl -L --silent "${URL}" > "$JAR_DL" && mv "$JAR_DL" "$JAR" - elif [ $(command -v wget) ]; then - wget --quiet ${URL} -O "$JAR_DL" && mv "$JAR_DL" "$JAR" - else - printf "You do not have curl or wget installed, please install rat manually.\n" - exit -1 - fi - fi - - unzip -tq $JAR &> /dev/null - if [ $? -ne 0 ]; then - # We failed to download - printf "Our attempt to download rat locally to ${JAR} failed. Please install rat manually.\n" + # Download rat launch jar if it hasn't been downloaded yet + if [ ! 
-f "$JAR" ]; then + # Download + printf "Attempting to fetch rat\n" + JAR_DL="${JAR}.part" + if [ $(command -v curl) ]; then + curl -L --silent "${URL}" > "$JAR_DL" && mv "$JAR_DL" "$JAR" + elif [ $(command -v wget) ]; then + wget --quiet ${URL} -O "$JAR_DL" && mv "$JAR_DL" "$JAR" + else + printf "You do not have curl or wget installed, please install rat manually.\n" exit -1 fi - printf "Launching rat from ${JAR}\n" + fi + + unzip -tq "$JAR" &> /dev/null + if [ $? -ne 0 ]; then + # We failed to download + rm "$JAR" + printf "Our attempt to download rat locally to ${JAR} failed. Please install rat manually.\n" + exit -1 fi } @@ -71,6 +69,11 @@ mkdir -p "$FWDIR"/lib $java_cmd -jar "$rat_jar" -E "$FWDIR"/.rat-excludes -d "$FWDIR" > rat-results.txt +if [ $? -ne 0 ]; then + echo "RAT exited abnormally" + exit 1 +fi + ERRORS="$(cat rat-results.txt | grep -e "??")" if test ! -z "$ERRORS"; then diff --git a/dev/run-tests b/dev/run-tests index 561d7fc9e7b1f..bb21ab6c9aa04 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -173,7 +173,7 @@ CURRENT_BLOCK=$BLOCK_BUILD build/mvn $HIVE_BUILD_ARGS clean package -DskipTests else echo -e "q\n" \ - | build/sbt $HIVE_BUILD_ARGS package assembly/assembly \ + | build/sbt $HIVE_BUILD_ARGS package assembly/assembly streaming-kafka-assembly/assembly \ | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" fi } @@ -236,3 +236,18 @@ echo "=========================================================================" CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS ./python/run-tests + +echo "" +echo "=========================================================================" +echo "Running SparkR tests" +echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_SPARKR_UNIT_TESTS + +if [ $(command -v R) ]; then + ./R/install-dev.sh + ./R/run-tests.sh +else + echo "Ignoring SparkR tests as R was not found in PATH" +fi + diff --git a/dev/run-tests-codes.sh b/dev/run-tests-codes.sh index 8ab6db6925d6e..154e01255b2ef 100644 --- a/dev/run-tests-codes.sh +++ b/dev/run-tests-codes.sh @@ -25,3 +25,4 @@ readonly BLOCK_BUILD=14 readonly BLOCK_MIMA=15 readonly BLOCK_SPARK_UNIT_TESTS=16 readonly BLOCK_PYSPARK_UNIT_TESTS=17 +readonly BLOCK_SPARKR_UNIT_TESTS=18 diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 3a937b637e003..3c1c91a111357 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -55,13 +55,14 @@ TESTS_TIMEOUT="120m" # format: http://linux.die.net/man/1/timeout # To write a PR test: #+ * the file must reside within the dev/tests directory #+ * be an executable bash script -#+ * accept two arguments on the command line, the first being the Github PR long commit -#+ hash and the second the Github SHA1 hash +#+ * accept three arguments on the command line, the first being the Github PR long commit +#+ hash, the second the Github SHA1 hash, and the final the current PR hash #+ * and, lastly, return string output to be included in the pr message output that will #+ be posted to Github PR_TESTS=( "pr_merge_ability" "pr_public_classes" + "pr_new_dependencies" ) function post_message () { @@ -146,34 +147,42 @@ function send_archived_logs () { fi } +# post start message +{ + start_message="\ + [Test build ${BUILD_DISPLAY_NAME} has started](${BUILD_URL}consoleFull) for \ + PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." 
+ + post_message "$start_message" +} + # Environment variable to capture PR test output pr_message="" +# Ensure we save off the current HEAD to revert to +current_pr_head="`git rev-parse HEAD`" + +echo "HEAD: `git rev-parse HEAD`" +echo "GHPRB: $ghprbActualCommit" +echo "SHA1: $sha1" # Run pull request tests for t in "${PR_TESTS[@]}"; do this_test="${FWDIR}/dev/tests/${t}.sh" - # Ensure the test is a file and is executable - if [ -x "$this_test" ]; then - echo "ghprb: $ghprbActualCommit sha1: $sha1" - this_mssg="`bash \"${this_test}\" \"${ghprbActualCommit}\" \"${sha1}\" 2>/dev/null`" + # Ensure the test can be found and is a file + if [ -f "${this_test}" ]; then + echo "Running test: $t" + this_mssg="$(bash "${this_test}" "${ghprbActualCommit}" "${sha1}" "${current_pr_head}")" # Check if this is the merge test as we submit that note *before* and *after* # the tests run [ "$t" == "pr_merge_ability" ] && merge_note="${this_mssg}" pr_message="${pr_message}\n${this_mssg}" + # Ensure, after each test, that we're back on the current PR + git checkout -f "${current_pr_head}" &>/dev/null + else + echo "Cannot find test ${this_test}." fi done -# post start message -{ - start_message="\ - [Test build ${BUILD_DISPLAY_NAME} has started](${BUILD_URL}consoleFull) for \ - PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." - - start_message="${start_message}\n${merge_note}" - - post_message "$start_message" -} - # run tests { timeout "${TESTS_TIMEOUT}" ./dev/run-tests @@ -205,6 +214,8 @@ done failing_test="Spark unit tests" elif [ "$test_result" -eq "$BLOCK_PYSPARK_UNIT_TESTS" ]; then failing_test="PySpark unit tests" + elif [ "$test_result" -eq "$BLOCK_SPARKR_UNIT_TESTS" ]; then + failing_test="SparkR unit tests" else failing_test="some tests" fi @@ -222,7 +233,7 @@ done PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." result_message="${result_message}\n${test_result_note}" - result_message="${result_message}\n${pr_message}" + result_message="${result_message}${pr_message}" post_message "$result_message" } diff --git a/dev/scalastyle b/dev/scalastyle index 86919227ed1ab..4e03f89ed5d5d 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -18,9 +18,10 @@ # echo -e "q\n" | build/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt +echo -e "q\n" | build/sbt -Phive -Phive-thriftserver test:scalastyle >> scalastyle.txt # Check style with YARN built too -echo -e "q\n" | build/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 scalastyle \ - >> scalastyle.txt +echo -e "q\n" | build/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 scalastyle >> scalastyle.txt +echo -e "q\n" | build/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 test:scalastyle >> scalastyle.txt ERRORS=$(cat scalastyle.txt | awk '{if($1~/error/)print}') rm scalastyle.txt diff --git a/dev/tests/pr_new_dependencies.sh b/dev/tests/pr_new_dependencies.sh new file mode 100755 index 0000000000000..fdfb3c62aff58 --- /dev/null +++ b/dev/tests/pr_new_dependencies.sh @@ -0,0 +1,117 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# This script follows the base format for testing pull requests against +# another branch and returning results to be published. More details can be +# found at dev/run-tests-jenkins. +# +# Arg1: The Github Pull Request Actual Commit +#+ known as `ghprbActualCommit` in `run-tests-jenkins` +# Arg2: The SHA1 hash +#+ known as `sha1` in `run-tests-jenkins` +# Arg3: Current PR Commit Hash +#+ the PR hash for the current commit +# + +ghprbActualCommit="$1" +sha1="$2" +current_pr_head="$3" + +MVN_BIN="build/mvn" +CURR_CP_FILE="my-classpath.txt" +MASTER_CP_FILE="master-classpath.txt" + +# First switch over to the master branch +git checkout -f master +# Find and copy all pom.xml files into a *.gate file that we can check +# against through various `git` changes +find -name "pom.xml" -exec cp {} {}.gate \; +# Switch back to the current PR +git checkout -f "${current_pr_head}" + +# Check if any *.pom files from the current branch are different from the master +difference_q="" +for p in $(find -name "pom.xml"); do + [[ -f "${p}" && -f "${p}.gate" ]] && \ + difference_q="${difference_q}$(diff $p.gate $p)" +done + +# If no pom files were changed we can easily say no new dependencies were added +if [ -z "${difference_q}" ]; then + echo " * This patch does not change any dependencies." +else + # Else we need to manually build spark to determine what, if any, dependencies + # were added into the Spark assembly jar + ${MVN_BIN} clean package dependency:build-classpath -DskipTests 2>/dev/null | \ + sed -n -e '/Building Spark Project Assembly/,$p' | \ + grep --context=1 -m 2 "Dependencies classpath:" | \ + head -n 3 | \ + tail -n 1 | \ + tr ":" "\n" | \ + rev | \ + cut -d "/" -f 1 | \ + rev | \ + sort > ${CURR_CP_FILE} + + # Checkout the master branch to compare against + git checkout -f master + + ${MVN_BIN} clean package dependency:build-classpath -DskipTests 2>/dev/null | \ + sed -n -e '/Building Spark Project Assembly/,$p' | \ + grep --context=1 -m 2 "Dependencies classpath:" | \ + head -n 3 | \ + tail -n 1 | \ + tr ":" "\n" | \ + rev | \ + cut -d "/" -f 1 | \ + rev | \ + sort > ${MASTER_CP_FILE} + + DIFF_RESULTS="`diff ${CURR_CP_FILE} ${MASTER_CP_FILE}`" + + if [ -z "${DIFF_RESULTS}" ]; then + echo " * This patch does not change any dependencies." 
+ else + # Pretty print the new dependencies + added_deps=$(echo "${DIFF_RESULTS}" | grep "<" | cut -d' ' -f2 | awk '{printf " * \`"$1"\`\\n"}') + removed_deps=$(echo "${DIFF_RESULTS}" | grep ">" | cut -d' ' -f2 | awk '{printf " * \`"$1"\`\\n"}') + added_deps_text=" * This patch **adds the following new dependencies:**\n${added_deps}" + removed_deps_text=" * This patch **removes the following dependencies:**\n${removed_deps}" + + # Construct the final returned message with proper + return_mssg="" + [ -n "${added_deps}" ] && return_mssg="${added_deps_text}" + if [ -n "${removed_deps}" ]; then + if [ -n "${return_mssg}" ]; then + return_mssg="${return_mssg}\n${removed_deps_text}" + else + return_mssg="${removed_deps_text}" + fi + fi + echo "${return_mssg}" + fi + + # Remove the files we've left over + [ -f "${CURR_CP_FILE}" ] && rm -f "${CURR_CP_FILE}" + [ -f "${MASTER_CP_FILE}" ] && rm -f "${MASTER_CP_FILE}" + + # Clean up our mess from the Maven builds just in case + ${MVN_BIN} clean &>/dev/null +fi diff --git a/docs/README.md b/docs/README.md index 8a54724c4beae..5852f972a051d 100644 --- a/docs/README.md +++ b/docs/README.md @@ -58,19 +58,25 @@ phase, use the following sytax: We use Sphinx to generate Python API docs, so you will need to install it by running `sudo pip install sphinx`. -## API Docs (Scaladoc and Sphinx) +## knitr, devtools -You can build just the Spark scaladoc by running `build/sbt doc` from the SPARK_PROJECT_ROOT directory. +SparkR documentation is written using `roxygen2` and we use `knitr`, `devtools` to generate +documentation. To install these packages you can run `install.packages(c("knitr", "devtools"))` from a +R console. + +## API Docs (Scaladoc, Sphinx, roxygen2) + +You can build just the Spark scaladoc by running `build/sbt unidoc` from the SPARK_PROJECT_ROOT directory. Similarly, you can build just the PySpark docs by running `make html` from the SPARK_PROJECT_ROOT/python/docs directory. Documentation is only generated for classes that are listed as -public in `__init__.py`. +public in `__init__.py`. The SparkR docs can be built by running SPARK_PROJECT_ROOT/R/create-docs.sh. When you run `jekyll` in the `docs` directory, it will also copy over the scaladoc for the various Spark subprojects into the `docs` directory (and then also into the `_site` directory). We use a -jekyll plugin to run `build/sbt doc` before building the site so if you haven't run it (recently) it +jekyll plugin to run `build/sbt unidoc` before building the site so if you haven't run it (recently) it may take some time as it generates all of the scaladoc. The jekyll plugin also generates the PySpark docs [Sphinx](http://sphinx-doc.org/). -NOTE: To skip the step of building and copying over the Scala and Python API docs, run `SKIP_API=1 +NOTE: To skip the step of building and copying over the Scala, Python, R API docs, run `SKIP_API=1 jekyll`. diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 2e88b3093652d..b92c75f90b11c 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -84,6 +84,7 @@
 Scala
 Java
 Python
+R
  • diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 3c626a0b7f54b..0ea3f8eab461b 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -78,5 +78,18 @@ puts "cp -r python/docs/_build/html/. docs/api/python" cp_r("python/docs/_build/html/.", "docs/api/python") - cd("..") + # Build SparkR API docs + puts "Moving to R directory and building roxygen docs." + cd("R") + puts `./create-docs.sh` + + puts "Moving back into home dir." + cd("../") + + puts "Making directory api/R" + mkdir_p "docs/api/R" + + puts "cp -r R/pkg/html/. docs/api/R" + cp_r("R/pkg/html/.", "docs/api/R") + end diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index 6a75d5c457f02..7079de546e2f5 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -33,7 +33,11 @@ There are several useful things to note about this architecture: 2. Spark is agnostic to the underlying cluster manager. As long as it can acquire executor processes, and these communicate with each other, it is relatively easy to run it even on a cluster manager that also supports other applications (e.g. Mesos/YARN). -3. Because the driver schedules tasks on the cluster, it should be run close to the worker +3. The driver program must listen for and accept incoming connections from its executors throughout + its lifetime (e.g., see [spark.driver.port and spark.fileserver.port in the network config + section](configuration.html#networking)). As such, the driver program must be network + addressable from the worker nodes. +4. Because the driver schedules tasks on the cluster, it should be run close to the worker nodes, preferably on the same local area network. If you'd like to send requests to the cluster remotely, it's better to open an RPC to the driver and have it submit operations from nearby than to run a driver far away from the worker nodes. diff --git a/docs/configuration.md b/docs/configuration.md index 7fe11475212b3..d9e9e67026cbb 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -35,9 +35,19 @@ val conf = new SparkConf() val sc = new SparkContext(conf) {% endhighlight %} -Note that we can have more than 1 thread in local mode, and in cases like spark streaming, we may actually -require one to prevent any sort of starvation issues. +Note that we can have more than 1 thread in local mode, and in cases like Spark Streaming, we may +actually require one to prevent any sort of starvation issues. +Properties that specify some time duration should be configured with a unit of time. +The following format is accepted: + + 25ms (milliseconds) + 5s (seconds) + 10m or 10min (minutes) + 3h (hours) + 5d (days) + 1y (years) + ## Dynamically Loading Spark Properties In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For instance, if you'd like to run the same application with different masters or different @@ -429,10 +439,10 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.io.retryWait - 5 + 5s - (Netty only) Seconds to wait between retries of fetches. The maximum delay caused by retrying - is simply maxRetries * retryWait, by default 15 seconds. + (Netty only) How long to wait between retries of fetches. The maximum delay caused by retrying + is 15 seconds by default, calculated as maxRetries * retryWait. @@ -713,6 +723,17 @@ Apart from these, the following properties are also available, and may be useful this duration will be cleared as well. 
+ + spark.executor.cores + 1 in YARN mode, all the available cores on the worker in standalone mode. + + The number of cores to use on each executor. For YARN and standalone mode only. + + In standalone mode, setting this parameter allows an application to run multiple executors on + the same worker, provided that there are enough cores on that worker. Otherwise, only one + executor per application will run on each worker. + + spark.default.parallelism @@ -732,17 +753,17 @@ Apart from these, the following properties are also available, and may be useful spark.executor.heartbeatInterval - 10000 - Interval (milliseconds) between each executor's heartbeats to the driver. Heartbeats let + 10s + Interval between each executor's heartbeats to the driver. Heartbeats let the driver know that the executor is still alive and update it with metrics for in-progress tasks. spark.files.fetchTimeout - 60 + 60s Communication timeout to use when fetching files added through SparkContext.addFile() from - the driver, in seconds. + the driver. @@ -853,11 +874,11 @@ Apart from these, the following properties are also available, and may be useful spark.akka.heartbeat.interval - 1000 + 1000s This is set to a larger value to disable the transport failure detector that comes built in to Akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger - interval value in seconds reduces network overhead and a smaller value ( ~ 1 s) might be more + interval value reduces network overhead and a smaller value ( ~ 1 s) might be more informative for Akka's failure detector. Tune this in combination of `spark.akka.heartbeat.pauses` if you need to. A likely positive use case for using failure detector would be: a sensistive failure detector can help evict rogue executors quickly. However this is usually not the case @@ -868,11 +889,11 @@ Apart from these, the following properties are also available, and may be useful spark.akka.heartbeat.pauses - 6000 + 6000s This is set to a larger value to disable the transport failure detector that comes built in to Akka. It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart - beat pause in seconds for Akka. This can be used to control sensitivity to GC pauses. Tune + beat pause for Akka. This can be used to control sensitivity to GC pauses. Tune this along with `spark.akka.heartbeat.interval` if you need to. @@ -886,9 +907,9 @@ Apart from these, the following properties are also available, and may be useful spark.akka.timeout - 100 + 100s - Communication timeout between Spark nodes, in seconds. + Communication timeout between Spark nodes. @@ -938,10 +959,10 @@ Apart from these, the following properties are also available, and may be useful spark.network.timeout - 120 + 120s - Default timeout for all network interactions, in seconds. This config will be used in - place of spark.core.connection.ack.wait.timeout, spark.akka.timeout, + Default timeout for all network interactions. This config will be used in place of + spark.core.connection.ack.wait.timeout, spark.akka.timeout, spark.storage.blockManagerSlaveTimeoutMs or spark.shuffle.io.connectionTimeout, if they are not configured. @@ -989,9 +1010,9 @@ Apart from these, the following properties are also available, and may be useful spark.locality.wait - 3000 + 3s - Number of milliseconds to wait to launch a data-local task before giving up and launching it + How long to wait to launch a data-local task before giving up and launching it on a less-local node. 
The same wait will be used to step through multiple locality levels (process-local, node-local, rack-local and then any). It is also possible to customize the waiting time for each level by setting spark.locality.wait.node, etc. @@ -1024,10 +1045,9 @@ Apart from these, the following properties are also available, and may be useful spark.scheduler.maxRegisteredResourcesWaitingTime - 30000 + 30s - Maximum amount of time to wait for resources to register before scheduling begins - (in milliseconds). + Maximum amount of time to wait for resources to register before scheduling begins. @@ -1054,10 +1074,9 @@ Apart from these, the following properties are also available, and may be useful spark.scheduler.revive.interval - 1000 + 1s - The interval length for the scheduler to revive the worker resource offers to run tasks - (in milliseconds). + The interval length for the scheduler to revive the worker resource offers to run tasks. @@ -1070,9 +1089,9 @@ Apart from these, the following properties are also available, and may be useful spark.speculation.interval - 100 + 100ms - How often Spark will check for tasks to speculate, in milliseconds. + How often Spark will check for tasks to speculate. @@ -1127,10 +1146,10 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.executorIdleTimeout - 600 + 600s - If dynamic allocation is enabled and an executor has been idle for more than this duration - (in seconds), the executor will be removed. For more detail, see this + If dynamic allocation is enabled and an executor has been idle for more than this duration, + the executor will be removed. For more detail, see this description. @@ -1157,10 +1176,10 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.schedulerBacklogTimeout - 5 + 5s If dynamic allocation is enabled and there have been pending tasks backlogged for more than - this duration (in seconds), new executors will be requested. For more detail, see this + this duration, new executors will be requested. For more detail, see this description. @@ -1215,18 +1234,18 @@ Apart from these, the following properties are also available, and may be useful spark.core.connection.ack.wait.timeout - 60 + 60s - Number of seconds for the connection to wait for ack to occur before timing + How long for the connection to wait for ack to occur before timing out and giving up. To avoid unwilling timeout caused by long pause like GC, you can set larger value. spark.core.connection.auth.wait.timeout - 30 + 30s - Number of seconds for the connection to wait for authentication to occur before timing + How long for the connection to wait for authentication to occur before timing out and giving up. @@ -1347,9 +1366,9 @@ Apart from these, the following properties are also available, and may be useful Property NameDefaultMeaning spark.streaming.blockInterval - 200 + 200ms - Interval (milliseconds) at which data received by Spark Streaming receivers is chunked + Interval at which data received by Spark Streaming receivers is chunked into blocks of data before storing them in Spark. Minimum recommended - 50 ms. See the performance tuning section in the Spark Streaming programing guide for more details. 
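The configuration entries above all move from bare numbers to suffixed durations (25ms, 5s, 10m/10min, 3h, 5d), matching the timeConversion test added to UtilsSuite earlier in this patch. As a sketch of how such strings map onto TimeUnit values — an illustration covering only the suffixes the new test exercises, with the invented name DurationStrings, not Spark's Utils.timeStringAsSeconds/timeStringAsMs implementation:

```
import java.util.concurrent.TimeUnit

object DurationStrings {
  private val suffixes: Map[String, TimeUnit] = Map(
    "us" -> TimeUnit.MICROSECONDS,
    "ms" -> TimeUnit.MILLISECONDS,
    "s" -> TimeUnit.SECONDS,
    "m" -> TimeUnit.MINUTES,
    "min" -> TimeUnit.MINUTES,
    "h" -> TimeUnit.HOURS,
    "d" -> TimeUnit.DAYS)

  private val DurationRe = "(-?[0-9]+)([a-z]*)".r

  /** Parse strings like "1000ms", "10min" or "3h" into the requested target unit. */
  def parse(str: String, defaultUnit: TimeUnit, targetUnit: TimeUnit): Long =
    str.trim.toLowerCase match {
      case DurationRe(value, "") =>
        // A bare number is interpreted in the caller's default unit.
        targetUnit.convert(value.toLong, defaultUnit)
      case DurationRe(value, suffix) if suffixes.contains(suffix) =>
        targetUnit.convert(value.toLong, suffixes(suffix))
      case other =>
        throw new NumberFormatException(s"Invalid duration string: $other")
    }

  def asSeconds(str: String): Long = parse(str, TimeUnit.SECONDS, TimeUnit.SECONDS)
  def asMs(str: String): Long = parse(str, TimeUnit.MILLISECONDS, TimeUnit.MILLISECONDS)
}

// e.g. DurationStrings.asSeconds("1000ms") == 1 and DurationStrings.asMs("1min") == 60000
```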
diff --git a/docs/img/cluster-overview.png b/docs/img/cluster-overview.png index 368274068e754..317554c5f2a5b 100644 Binary files a/docs/img/cluster-overview.png and b/docs/img/cluster-overview.png differ diff --git a/docs/img/cluster-overview.pptx b/docs/img/cluster-overview.pptx index af3c462cd904d..1b90d7ec5a7ae 100644 Binary files a/docs/img/cluster-overview.pptx and b/docs/img/cluster-overview.pptx differ diff --git a/docs/ml-guide.md b/docs/ml-guide.md index c08c76d226713..771a07183e26f 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -493,7 +493,7 @@ from pyspark.ml.feature import HashingTF, Tokenizer from pyspark.sql import Row, SQLContext sc = SparkContext(appName="SimpleTextClassificationPipeline") -sqlCtx = SQLContext(sc) +sqlContext = SQLContext(sc) # Prepare training documents, which are labeled. LabeledDocument = Row("id", "text", "label") diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index a83472f5be52e..9780ea52c4994 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -13,12 +13,15 @@ compute the conditional probability distribution of label given an observation and use it for prediction. MLlib supports [multinomial naive -Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes), -which is typically used for [document -classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). +Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) +and [Bernoulli naive Bayes] (http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). +These models are typically used for [document classification] +(http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). Within that context, each observation is a document and each -feature represents a term whose value is the frequency of the term. -Feature values must be nonnegative to represent term frequencies. +feature represents a term whose value is the frequency of the term (in multinomial naive Bayes) or +a zero or one indicating whether the term was found in the document (in Bernoulli naive Bayes). +Feature values must be nonnegative. The model type is selected with an optional parameter +"Multinomial" or "Bernoulli" with "Multinomial" as the default. [Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature vectors are usually sparse, and sparse vectors should be supplied as input to take advantage of @@ -32,7 +35,7 @@ sparsity. Since the training data is only used once, it is not necessary to cach [NaiveBayes](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayes$) implements multinomial naive Bayes. It takes an RDD of [LabeledPoint](api/scala/index.html#org.apache.spark.mllib.regression.LabeledPoint) and an optional -smoothing parameter `lambda` as input, and output a +smoothing parameter `lambda` as input, an optional model type parameter (default is Multinomial), and outputs a [NaiveBayesModel](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayesModel), which can be used for evaluation and prediction. 
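To make the multinomial-versus-Bernoulli feature encoding described above concrete, here is a small sketch (not from the patch itself); the five-term vocabulary, counts, and label are invented for illustration:

```
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Hypothetical 5-term vocabulary; term counts for one document.
val termCounts = Array(0.0, 3.0, 0.0, 1.0, 2.0)

// Multinomial naive Bayes: features are the raw (nonnegative) term frequencies.
val multinomialDoc = LabeledPoint(1.0, Vectors.dense(termCounts))

// Bernoulli naive Bayes: features are 0/1 indicators of whether each term appears.
val bernoulliDoc =
  LabeledPoint(1.0, Vectors.dense(termCounts.map(c => if (c > 0) 1.0 else 0.0)))
```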
@@ -51,7 +54,7 @@ val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) val training = splits(0) val test = splits(1) -val model = NaiveBayes.train(training, lambda = 1.0) +val model = NaiveBayes.train(training, lambda = 1.0, model = "Multinomial") val predictionAndLabel = test.map(p => (model.predict(p.features), p.label)) val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count() diff --git a/docs/programming-guide.md b/docs/programming-guide.md index f5b775da7930a..f4fabb0927b66 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -937,7 +937,7 @@ for details. Similar to map, but each input item can be mapped to 0 or more output items (so func should return a Seq rather than a single item). - mapPartitions(func) + mapPartitions(func) Similar to map, but runs separately on each partition (block) of the RDD, so func must be of type Iterator<T> => Iterator<U> when running on an RDD of type T. @@ -964,7 +964,7 @@ for details. Return a new dataset that contains the distinct elements of the source dataset. - groupByKey([numTasks]) + groupByKey([numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, Iterable<V>) pairs.
    Note: If you are grouping in order to perform an aggregation (such as a sum or average) over each key, using reduceByKey or aggregateByKey will yield much better @@ -975,25 +975,25 @@ for details. - reduceByKey(func, [numTasks]) + reduceByKey(func, [numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, V) pairs where the values for each key are aggregated using the given reduce function func, which must be of type (V,V) => V. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. - aggregateByKey(zeroValue)(seqOp, combOp, [numTasks]) + aggregateByKey(zeroValue)(seqOp, combOp, [numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, U) pairs where the values for each key are aggregated using the given combine functions and a neutral "zero" value. Allows an aggregated value type that is different than the input value type, while avoiding unnecessary allocations. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. - sortByKey([ascending], [numTasks]) + sortByKey([ascending], [numTasks]) When called on a dataset of (K, V) pairs where K implements Ordered, returns a dataset of (K, V) pairs sorted by keys in ascending or descending order, as specified in the boolean ascending argument. - join(otherDataset, [numTasks]) + join(otherDataset, [numTasks]) When called on datasets of type (K, V) and (K, W), returns a dataset of (K, (V, W)) pairs with all pairs of elements for each key. Outer joins are supported through leftOuterJoin, rightOuterJoin, and fullOuterJoin. - cogroup(otherDataset, [numTasks]) + cogroup(otherDataset, [numTasks]) When called on datasets of type (K, V) and (K, W), returns a dataset of (K, (Iterable<V>, Iterable<W>)) tuples. This operation is also called groupWith. @@ -1006,17 +1006,17 @@ for details. process's stdin and lines output to its stdout are returned as an RDD of strings. - coalesce(numPartitions) + coalesce(numPartitions) Decrease the number of partitions in the RDD to numPartitions. Useful for running operations more efficiently after filtering down a large dataset. repartition(numPartitions) Reshuffle the data in the RDD randomly to create either more or fewer partitions and balance it across them. - This always shuffles all data over the network. + This always shuffles all data over the network. - repartitionAndSortWithinPartitions(partitioner) + repartitionAndSortWithinPartitions(partitioner) Repartition the RDD according to the given partitioner and, within each resulting partition, sort records by their keys. This is more efficient than calling repartition and then sorting within each partition because it can push the sorting down into the shuffle machinery. @@ -1080,7 +1080,7 @@ for details. SparkContext.objectFile(). - countByKey() + countByKey() Only available on RDDs of type (K, V). Returns a hashmap of (K, Int) pairs with the count of each key. @@ -1090,6 +1090,67 @@ for details. +### Shuffle operations + +Certain operations within Spark trigger an event known as the shuffle. The shuffle is Spark's +mechanism for re-distributing data so that is grouped differently across partitions. This typically +involves copying data across executors and machines, making the shuffle a complex and +costly operation. + +#### Background + +To understand what happens during the shuffle we can consider the example of the +[`reduceByKey`](#ReduceByLink) operation. 
The `reduceByKey` operation generates a new RDD where all +values for a single key are combined into a tuple - the key and the result of executing a reduce +function against all values associated with that key. The challenge is that not all values for a +single key necessarily reside on the same partition, or even the same machine, but they must be +co-located to compute the result. + +In Spark, data is generally not distributed across partitions to be in the necessary place for a +specific operation. During computations, a single task will operate on a single partition - thus, to +organize all the data for a single `reduceByKey` reduce task to execute, Spark needs to perform an +all-to-all operation. It must read from all partitions to find all the values for all keys, +and then bring together values across partitions to compute the final result for each key - +this is called the **shuffle**. + +Although the set of elements in each partition of newly shuffled data will be deterministic, and so +is the ordering of partitions themselves, the ordering of these elements is not. If one desires predictably +ordered data following shuffle then it's possible to use: + +* `mapPartitions` to sort each partition using, for example, `.sorted` +* `repartitionAndSortWithinPartitions` to efficiently sort partitions while simultaneously repartitioning +* `sortBy` to make a globally ordered RDD + +Operations which can cause a shuffle include **repartition** operations like +[`repartition`](#RepartitionLink), and [`coalesce`](#CoalesceLink), **'ByKey** operations +(except for counting) like [`groupByKey`](#GroupByLink) and [`reduceByKey`](#ReduceByLink), and +**join** operations like [`cogroup`](#CogroupLink) and [`join`](#JoinLink). + +#### Performance Impact +The **Shuffle** is an expensive operation since it involves disk I/O, data serialization, and +network I/O. To organize data for the shuffle, Spark generates sets of tasks - *map* tasks to +organize the data, and a set of *reduce* tasks to aggregate it. This nomenclature comes from +MapReduce and does not directly relate to Spark's `map` and `reduce` operations. + +Internally, results from individual map tasks are kept in memory until they can't fit. Then, these +are sorted based on the target partition and written to a single file. On the reduce side, tasks +read the relevant sorted blocks. + +Certain shuffle operations can consume significant amounts of heap memory since they employ +in-memory data structures to organize records before or after transferring them. Specifically, +`reduceByKey` and `aggregateByKey` create these structures on the map side and `'ByKey` operations +generate these on the reduce side. When data does not fit in memory Spark will spill these tables +to disk, incurring the additional overhead of disk I/O and increased garbage collection. + +Shuffle also generates a large number of intermediate files on disk. As of Spark 1.3, these files +are not cleaned up from Spark's temporary storage until Spark is stopped, which means that +long-running Spark jobs may consume available disk space. This is done so the shuffle doesn't need +to be re-computed if the lineage is re-computed. The temporary storage directory is specified by the +`spark.local.dir` configuration parameter when configuring the Spark context. + +Shuffle behavior can be tuned by adjusting a variety of configuration parameters. See the +'Shuffle Behavior' section within the [Spark Configuration Guide](configuration.html). 
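As a concrete illustration of the shuffle discussion above (a sketch, not part of the patch), assuming an existing `SparkContext` named `sc` and invented HDFS paths:

```
// Word count: reduceByKey must co-locate every value for a given key,
// which triggers the all-to-all shuffle described above.
val counts = sc.textFile("hdfs:///tmp/input.txt")       // hypothetical input path
  .flatMap(line => line.split(" "))
  .map(word => (word, 1))
  .reduceByKey(_ + _)

// Element order within shuffled partitions is not deterministic; sorting
// (itself a shuffle) is one way to get predictably ordered output.
counts.sortByKey().saveAsTextFile("hdfs:///tmp/counts") // hypothetical output path
```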
+ ## RDD Persistence One of the most important capabilities in Spark is *persisting* (or *caching*) a dataset in memory diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index d9f3eb2b74b18..853c9f26b0ec9 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -48,9 +48,9 @@ Most of the configs are the same for Spark on YARN as for other deployment modes spark.yarn.am.waitTime - 100000 + 100s - In yarn-cluster mode, time in milliseconds for the application master to wait for the + In yarn-cluster mode, time for the application master to wait for the SparkContext to be initialized. In yarn-client mode, time for the application master to wait for the driver to connect to it. @@ -87,7 +87,8 @@ Most of the configs are the same for Spark on YARN as for other deployment modes spark.yarn.historyServer.address (none) - The address of the Spark history server (i.e. host.com:18080). The address should not contain a scheme (http://). Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. + The address of the Spark history server (i.e. host.com:18080). The address should not contain a scheme (http://). Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. + For this property, YARN properties can be used as variables, and these are substituted by Spark at runtime. For eg, if the Spark history server runs on the same node as the YARN ResourceManager, it can be set to `${hadoopconf-yarn.resourcemanager.hostname}:18080`. @@ -196,6 +197,15 @@ Most of the configs are the same for Spark on YARN as for other deployment modes It should be no larger than the global number of max attempts in the YARN configuration. + + spark.yarn.submit.waitAppCompletion + true + + In YARN cluster mode, controls whether the client waits to exit until the application completes. + If set to true, the client process will stay alive reporting the application's status. + Otherwise, the client process will exit after submission. + + # Launching Spark on YARN diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 74d8653a8b845..0eed9adacf123 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -24,7 +24,7 @@ the master's web UI, which is [http://localhost:8080](http://localhost:8080) by Similarly, you can start one or more workers and connect them to the master via: - ./bin/spark-class org.apache.spark.deploy.worker.Worker spark://IP:PORT + ./sbin/start-slave.sh Once you have started a worker, look at the master's web UI ([http://localhost:8080](http://localhost:8080) by default). You should see the new node listed there, along with its number of CPUs and memory (minus one gigabyte left for the OS). @@ -81,6 +81,7 @@ Once you've set up this file, you can launch or stop your cluster with the follo - `sbin/start-master.sh` - Starts a master instance on the machine the script is executed on. - `sbin/start-slaves.sh` - Starts a slave instance on each machine specified in the `conf/slaves` file. +- `sbin/start-slave.sh` - Starts a slave instance on the machine the script is executed on. - `sbin/start-all.sh` - Starts both a master and a number of slaves as described above. 
- `sbin/stop-master.sh` - Stops the master that was started via the `bin/start-master.sh` script. - `sbin/stop-slaves.sh` - Stops all slave instances on the machines specified in the `conf/slaves` file. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 4441d6a000a02..03500867df70f 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1371,7 +1371,10 @@ the Data Sources API. The following options are supported: These options must all be specified if any of them is specified. They describe how to partition the table when reading in parallel from multiple workers. - partitionColumn must be a numeric column from the table in question. + partitionColumn must be a numeric column from the table in question. Notice + that lowerBound and upperBound are just used to decide the + partition stride, not for filtering the rows in table. So all rows in the table will be + partitioned and returned. @@ -1642,7 +1645,7 @@ moved into the udf object in `SQLContext`.
    {% highlight java %} -sqlCtx.udf.register("strLen", (s: String) => s.length()) +sqlContext.udf.register("strLen", (s: String) => s.length()) {% endhighlight %}
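For context on the rename above (a sketch, not from the patch): once registered on the `SQLContext`, the UDF can be invoked from SQL, assuming a `people` temporary table already exists:

```
// Register the UDF on the SQLContext, then call it from a SQL query.
sqlContext.udf.register("strLen", (s: String) => s.length())
sqlContext.sql("SELECT name, strLen(name) FROM people").show()
```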
    @@ -1650,7 +1653,7 @@ sqlCtx.udf.register("strLen", (s: String) => s.length())
    {% highlight java %} -sqlCtx.udf().register("strLen", (String s) -> { s.length(); }); +sqlContext.udf().register("strLen", (String s) -> { s.length(); }); {% endhighlight %}
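Relating to the JDBC partitioning note earlier in this file's hunks (that `lowerBound` and `upperBound` only set the partition stride and do not filter rows), a hedged sketch using the options-map load API; the connection URL, table, and column names are invented, and the exact reader API varies by Spark version:

```
// lowerBound/upperBound split the id range into numPartitions strides; rows with
// ids outside [1, 100000] are still read, just by the first or last partition.
val jdbcDF = sqlContext.load("jdbc", Map(
  "url" -> "jdbc:postgresql://dbserver:5432/mydb",   // hypothetical database URL
  "dbtable" -> "schema.people",                      // hypothetical table
  "partitionColumn" -> "id",
  "lowerBound" -> "1",
  "upperBound" -> "100000",
  "numPartitions" -> "10"))
```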
    @@ -1784,6 +1787,7 @@ in Hive deployments. **Esoteric Hive Features** + * `UNION` type * Unique join * Column statistics collecting: Spark SQL does not piggyback scans to collect column statistics at diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 6d6229625f3f9..262512a639046 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -704,7 +704,7 @@ create a DStream using data from Twitter's stream of tweets, you have to do the {% highlight scala %} import org.apache.spark.streaming.twitter._ -TwitterUtils.createStream(ssc) +TwitterUtils.createStream(ssc, None) {% endhighlight %}
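Expanding slightly on the corrected `createStream(ssc, None)` call above (a sketch, not part of the patch): it assumes twitter4j OAuth credentials are supplied through system properties such as `twitter4j.oauth.consumerKey`, so the explicit `Authorization` argument can be `None`, and it assumes an existing `SparkConf` named `sparkConf`:

```
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.twitter._

val ssc = new StreamingContext(sparkConf, Seconds(2))  // sparkConf assumed to exist

// None falls back to twitter4j's default authorization; the filter list is optional.
val tweets = TwitterUtils.createStream(ssc, None, filters = Seq("spark"))
tweets.map(status => status.getText).print()

ssc.start()
ssc.awaitTermination()
```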
    diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index c467cd08ed742..0c1f24761d0de 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -282,6 +282,10 @@ def parse_args(): parser.add_option( "--vpc-id", default=None, help="VPC to launch instances in") + parser.add_option( + "--private-ips", action="store_true", default=False, + help="Use private IPs for instances rather than public if VPC/subnet " + + "requires that.") (opts, args) = parser.parse_args() if len(args) != 2: @@ -456,6 +460,13 @@ def launch_cluster(conn, opts, cluster_name): master_group.authorize('tcp', 50070, 50070, authorized_address) master_group.authorize('tcp', 60070, 60070, authorized_address) master_group.authorize('tcp', 4040, 4045, authorized_address) + # HDFS NFS gateway requires 111,2049,4242 for tcp & udp + master_group.authorize('tcp', 111, 111, authorized_address) + master_group.authorize('udp', 111, 111, authorized_address) + master_group.authorize('tcp', 2049, 2049, authorized_address) + master_group.authorize('udp', 2049, 2049, authorized_address) + master_group.authorize('tcp', 4242, 4242, authorized_address) + master_group.authorize('udp', 4242, 4242, authorized_address) if opts.ganglia: master_group.authorize('tcp', 5080, 5080, authorized_address) if slave_group.rules == []: # Group was just now created @@ -700,7 +711,7 @@ def get_instances(group_names): # Deploy configuration files and run setup scripts on a newly launched # or started EC2 cluster. def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): - master = master_nodes[0].public_dns_name + master = get_dns_name(master_nodes[0], opts.private_ips) if deploy_ssh_key: print "Generating cluster's SSH key on master..." key_setup = """ @@ -712,8 +723,9 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh']) print "Transferring cluster's SSH key to slaves..." for slave in slave_nodes: - print slave.public_dns_name - ssh_write(slave.public_dns_name, opts, ['tar', 'x'], dot_ssh_tar) + slave_address = get_dns_name(slave, opts.private_ips) + print slave_address + ssh_write(slave_address, opts, ['tar', 'x'], dot_ssh_tar) modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs', 'mapreduce', 'spark-standalone', 'tachyon'] @@ -802,7 +814,8 @@ def is_cluster_ssh_available(cluster_instances, opts): Check if SSH is available on all the instances in a cluster. """ for i in cluster_instances: - if not is_ssh_available(host=i.ip_address, opts=opts): + dns_name = get_dns_name(i, opts.private_ips) + if not is_ssh_available(host=dns_name, opts=opts): return False else: return True @@ -916,7 +929,7 @@ def get_num_disks(instance_type): # # root_dir should be an absolute path to the directory with the files we want to deploy. 
def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): - active_master = master_nodes[0].public_dns_name + active_master = get_dns_name(master_nodes[0], opts.private_ips) num_disks = get_num_disks(opts.instance_type) hdfs_data_dirs = "/mnt/ephemeral-hdfs/data" @@ -941,10 +954,12 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): print "Deploying Spark via git hash; Tachyon won't be set up" modules = filter(lambda x: x != "tachyon", modules) + master_addresses = [get_dns_name(i, opts.private_ips) for i in master_nodes] + slave_addresses = [get_dns_name(i, opts.private_ips) for i in slave_nodes] template_vars = { - "master_list": '\n'.join([i.public_dns_name for i in master_nodes]), + "master_list": '\n'.join(master_addresses), "active_master": active_master, - "slave_list": '\n'.join([i.public_dns_name for i in slave_nodes]), + "slave_list": '\n'.join(slave_addresses), "cluster_url": cluster_url, "hdfs_data_dirs": hdfs_data_dirs, "mapred_local_dirs": mapred_local_dirs, @@ -1004,7 +1019,7 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): # # root_dir should be an absolute path. def deploy_user_files(root_dir, opts, master_nodes): - active_master = master_nodes[0].public_dns_name + active_master = get_dns_name(master_nodes[0], opts.private_ips) command = [ 'rsync', '-rv', '-e', stringify_command(ssh_command(opts)), @@ -1115,6 +1130,20 @@ def get_partition(total, num_partitions, current_partitions): return num_slaves_this_zone +# Gets the IP address, taking into account the --private-ips flag +def get_ip_address(instance, private_ips=False): + ip = instance.ip_address if not private_ips else \ + instance.private_ip_address + return ip + + +# Gets the DNS name, taking into account the --private-ips flag +def get_dns_name(instance, private_ips=False): + dns = instance.public_dns_name if not private_ips else \ + instance.private_ip_address + return dns + + def real_main(): (opts, action, cluster_name) = parse_args() @@ -1223,7 +1252,7 @@ def real_main(): if any(master_nodes + slave_nodes): print "The following instances will be terminated:" for inst in master_nodes + slave_nodes: - print "> %s" % inst.public_dns_name + print "> %s" % get_dns_name(inst, opts.private_ips) print "ALL DATA ON ALL NODES WILL BE LOST!!" msg = "Are you sure you want to destroy the cluster {c}? (y/N) ".format(c=cluster_name) @@ -1287,13 +1316,17 @@ def real_main(): elif action == "login": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - master = master_nodes[0].public_dns_name - print "Logging into master " + master + "..." - proxy_opt = [] - if opts.proxy_port is not None: - proxy_opt = ['-D', opts.proxy_port] - subprocess.check_call( - ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)]) + if not master_nodes[0].public_dns_name and not opts.private_ips: + print "Master has no public DNS name. Maybe you meant to specify " \ + "--private-ips?" + else: + master = get_dns_name(master_nodes[0], opts.private_ips) + print "Logging into master " + master + "..." 
+ proxy_opt = [] + if opts.proxy_port is not None: + proxy_opt = ['-D', opts.proxy_port] + subprocess.check_call( + ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)]) elif action == "reboot-slaves": response = raw_input( @@ -1311,7 +1344,11 @@ def real_main(): elif action == "get-master": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - print master_nodes[0].public_dns_name + if not master_nodes[0].public_dns_name and not opts.private_ips: + print "Master has no public DNS name. Maybe you meant to specify " \ + "--private-ips?" + else: + print get_dns_name(master_nodes[0], opts.private_ips) elif action == "stop": response = raw_input( diff --git a/examples/pom.xml b/examples/pom.xml index 7e93f0eec0b91..afd7c6d52f0dd 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -90,6 +90,12 @@ org.apache.spark spark-streaming-zeromq_${scala.binary.version} ${project.version} + + + org.spark-project.protobuf + protobuf-java + + org.apache.hbase @@ -234,6 +240,7 @@ org.apache.commons commons-math3 + provided com.twitter @@ -262,6 +269,22 @@ com.ning compress-lzf + + commons-cli + commons-cli + + + commons-codec + commons-codec + + + commons-lang + commons-lang + + + commons-logging + commons-logging + io.netty netty @@ -270,10 +293,22 @@ jline jline + + net.jpountz.lz4 + lz4 + org.apache.cassandra.deps avro + + org.apache.commons + commons-math3 + + + org.apache.thrift + libthrift + @@ -281,6 +316,17 @@ scopt_${scala.binary.version} 3.2.0 + + + + org.scala-lang + scala-library + provided + + @@ -322,12 +368,6 @@ - - - org.apache.commons.math3 - org.spark-project.commons.math3 - - diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala index 1c8a20bf8f1ae..11a8cf09533ce 100644 --- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala +++ b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -41,7 +41,7 @@ object DirectKafkaWordCount { | is a list of one or more Kafka brokers | is a list of one or more kafka topics to consume from | - """".stripMargin) + """.stripMargin) System.exit(1) } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index 19d0eb216848e..eaf00d09f550d 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -116,7 +116,7 @@ class MyJavaLogisticRegression */ IntParam maxIter = new IntParam(this, "maxIter", "max number of iterations"); - int getMaxIter() { return (Integer) get(maxIter); } + int getMaxIter() { return (Integer) getOrDefault(maxIter); } public MyJavaLogisticRegression() { setMaxIter(100); @@ -211,7 +211,7 @@ public Vector predictRaw(Vector features) { public MyJavaLogisticRegressionModel copy() { MyJavaLogisticRegressionModel m = new MyJavaLogisticRegressionModel(parent_, fittingParamMap_, weights_); - Params$.MODULE$.inheritValues(this.paramMap(), this, m); + Params$.MODULE$.inheritValues(this.extractParamMap(), this, m); return m; } } diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index 
dee794840a3e1..8159ffbe2d269 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -55,7 +55,7 @@ public void setAge(int age) { public static void main(String[] args) throws Exception { SparkConf sparkConf = new SparkConf().setAppName("JavaSparkSQL"); JavaSparkContext ctx = new JavaSparkContext(sparkConf); - SQLContext sqlCtx = new SQLContext(ctx); + SQLContext sqlContext = new SQLContext(ctx); System.out.println("=== Data source: RDD ==="); // Load a text file and convert each line to a Java Bean. @@ -74,11 +74,11 @@ public Person call(String line) { }); // Apply a schema to an RDD of Java Beans and register it as a table. - DataFrame schemaPeople = sqlCtx.createDataFrame(people, Person.class); + DataFrame schemaPeople = sqlContext.createDataFrame(people, Person.class); schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. - DataFrame teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); // The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. @@ -99,12 +99,12 @@ public String call(Row row) { // Read in the parquet file created above. // Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a DataFrame. - DataFrame parquetFile = sqlCtx.parquetFile("people.parquet"); + DataFrame parquetFile = sqlContext.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); DataFrame teenagers2 = - sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); + sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); teenagerNames = teenagers2.toJavaRDD().map(new Function() { @Override public String call(Row row) { @@ -120,7 +120,7 @@ public String call(Row row) { // The path can be either a single text file or a directory storing text files. String path = "examples/src/main/resources/people.json"; // Create a DataFrame from the file(s) pointed by path - DataFrame peopleFromJsonFile = sqlCtx.jsonFile(path); + DataFrame peopleFromJsonFile = sqlContext.jsonFile(path); // Because the schema of a JSON dataset is automatically inferred, to write queries, // it is better to take a look at what is the schema. @@ -133,8 +133,8 @@ public String call(Row row) { // Register this DataFrame as a table. peopleFromJsonFile.registerTempTable("people"); - // SQL statements can be run by using the sql methods provided by sqlCtx. - DataFrame teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + // SQL statements can be run by using the sql methods provided by sqlContext. + DataFrame teenagers3 = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); // The results of SQL queries are DataFrame and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. 
@@ -151,7 +151,7 @@ public String call(Row row) { List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = ctx.parallelize(jsonData); - DataFrame peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD.rdd()); + DataFrame peopleFromJsonRDD = sqlContext.jsonRDD(anotherPeopleRDD.rdd()); // Take a look at the schema of this new DataFrame. peopleFromJsonRDD.printSchema(); @@ -164,7 +164,7 @@ public String call(Row row) { peopleFromJsonRDD.registerTempTable("people2"); - DataFrame peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); + DataFrame peopleWithCity = sqlContext.sql("SELECT name, address.city FROM people2"); List nameAndCity = peopleWithCity.toJavaRDD().map(new Function() { @Override public String call(Row row) { diff --git a/examples/src/main/python/ml/simple_text_classification_pipeline.py b/examples/src/main/python/ml/simple_text_classification_pipeline.py index d281f4fa44282..c73edb7fd6b20 100644 --- a/examples/src/main/python/ml/simple_text_classification_pipeline.py +++ b/examples/src/main/python/ml/simple_text_classification_pipeline.py @@ -33,7 +33,7 @@ if __name__ == "__main__": sc = SparkContext(appName="SimpleTextClassificationPipeline") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) # Prepare training documents, which are labeled. LabeledDocument = Row("id", "text", "label") diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py index b5a70db2b9a3c..fcbf56cbf0c52 100644 --- a/examples/src/main/python/mllib/dataset_example.py +++ b/examples/src/main/python/mllib/dataset_example.py @@ -44,19 +44,19 @@ def summarize(dataset): print >> sys.stderr, "Usage: dataset_example.py " exit(-1) sc = SparkContext(appName="DatasetExample") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) if len(sys.argv) == 2: input = sys.argv[1] else: input = "data/mllib/sample_libsvm_data.txt" points = MLUtils.loadLibSVMFile(sc, input) - dataset0 = sqlCtx.inferSchema(points).setName("dataset0").cache() + dataset0 = sqlContext.inferSchema(points).setName("dataset0").cache() summarize(dataset0) tempdir = tempfile.NamedTemporaryFile(delete=False).name os.unlink(tempdir) print "Save dataset as a Parquet file to %s." % tempdir dataset0.saveAsParquetFile(tempdir) print "Load it back and summarize it again." - dataset1 = sqlCtx.parquetFile(tempdir).setName("dataset1").cache() + dataset1 = sqlContext.parquetFile(tempdir).setName("dataset1").cache() summarize(dataset1) shutil.rmtree(tempdir) diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py index 47202fde7510b..d89361f324917 100644 --- a/examples/src/main/python/sql.py +++ b/examples/src/main/python/sql.py @@ -19,7 +19,7 @@ from pyspark import SparkContext from pyspark.sql import SQLContext -from pyspark.sql import Row, StructField, StructType, StringType, IntegerType +from pyspark.sql.types import Row, StructField, StructType, StringType, IntegerType if __name__ == "__main__": diff --git a/examples/src/main/r/kmeans.R b/examples/src/main/r/kmeans.R new file mode 100644 index 0000000000000..6e6b5cb93789c --- /dev/null +++ b/examples/src/main/r/kmeans.R @@ -0,0 +1,93 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +# Logistic regression in Spark. +# Note: unlike the example in Scala, a point here is represented as a vector of +# doubles. + +parseVectors <- function(lines) { + lines <- strsplit(as.character(lines) , " ", fixed = TRUE) + list(matrix(as.numeric(unlist(lines)), ncol = length(lines[[1]]))) +} + +dist.fun <- function(P, C) { + apply( + C, + 1, + function(x) { + colSums((t(P) - x)^2) + } + ) +} + +closestPoint <- function(P, C) { + max.col(-dist.fun(P, C)) +} +# Main program + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 3) { + print("Usage: kmeans ") + q("no") +} + +sc <- sparkR.init(appName = "RKMeans") +K <- as.integer(args[[2]]) +convergeDist <- as.double(args[[3]]) + +lines <- textFile(sc, args[[1]]) +points <- cache(lapplyPartition(lines, parseVectors)) +# kPoints <- take(points, K) +kPoints <- do.call(rbind, takeSample(points, FALSE, K, 16189L)) +tempDist <- 1.0 + +while (tempDist > convergeDist) { + closest <- lapplyPartition( + lapply(points, + function(p) { + cp <- closestPoint(p, kPoints); + mapply(list, unique(cp), split.data.frame(cbind(1, p), cp), SIMPLIFY=FALSE) + }), + function(x) {do.call(c, x) + }) + + pointStats <- reduceByKey(closest, + function(p1, p2) { + t(colSums(rbind(p1, p2))) + }, + 2L) + + newPoints <- do.call( + rbind, + collect(lapply(pointStats, + function(tup) { + point.sum <- tup[[2]][, -1] + point.count <- tup[[2]][, 1] + point.sum/point.count + }))) + + D <- dist.fun(kPoints, newPoints) + tempDist <- sum(D[cbind(1:3, max.col(-D))]) + kPoints <- newPoints + cat("Finished iteration (delta = ", tempDist, ")\n") +} + +cat("Final centers:\n") +writeLines(unlist(lapply(kPoints, paste, collapse = " "))) diff --git a/examples/src/main/r/linear_solver_mnist.R b/examples/src/main/r/linear_solver_mnist.R new file mode 100644 index 0000000000000..c864a4232d010 --- /dev/null +++ b/examples/src/main/r/linear_solver_mnist.R @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# Instructions: https://github.com/amplab-extras/SparkR-pkg/wiki/SparkR-Example:-Digit-Recognition-on-EC2 + +library(SparkR) +library(Matrix) + +args <- commandArgs(trailing = TRUE) + +# number of random features; default to 1100 +D <- ifelse(length(args) > 0, as.integer(args[[1]]), 1100) +# number of partitions for training dataset +trainParts <- 12 +# dimension of digits +d <- 784 +# number of test examples +NTrain <- 60000 +# number of training examples +NTest <- 10000 +# scale of features +gamma <- 4e-4 + +sc <- sparkR.init(appName = "SparkR-LinearSolver") + +# You can also use HDFS path to speed things up: +# hdfs:///train-mnist-dense-with-labels.data +file <- textFile(sc, "/data/train-mnist-dense-with-labels.data", trainParts) + +W <- gamma * matrix(nrow=D, ncol=d, data=rnorm(D*d)) +b <- 2 * pi * matrix(nrow=D, ncol=1, data=runif(D)) +broadcastW <- broadcast(sc, W) +broadcastB <- broadcast(sc, b) + +includePackage(sc, Matrix) +numericLines <- lapplyPartitionsWithIndex(file, + function(split, part) { + matList <- sapply(part, function(line) { + as.numeric(strsplit(line, ",", fixed=TRUE)[[1]]) + }, simplify=FALSE) + mat <- Matrix(ncol=d+1, data=unlist(matList, F, F), + sparse=T, byrow=T) + mat + }) + +featureLabels <- cache(lapplyPartition( + numericLines, + function(part) { + label <- part[,1] + mat <- part[,-1] + ones <- rep(1, nrow(mat)) + features <- cos( + mat %*% t(value(broadcastW)) + (matrix(ncol=1, data=ones) %*% t(value(broadcastB)))) + onesMat <- Matrix(ones) + featuresPlus <- cBind(features, onesMat) + labels <- matrix(nrow=nrow(mat), ncol=10, data=-1) + for (i in 1:nrow(mat)) { + labels[i, label[i]] <- 1 + } + list(label=labels, features=featuresPlus) + })) + +FTF <- Reduce("+", collect(lapplyPartition(featureLabels, + function(part) { + t(part$features) %*% part$features + }), flatten=F)) + +FTY <- Reduce("+", collect(lapplyPartition(featureLabels, + function(part) { + t(part$features) %*% part$label + }), flatten=F)) + +# solve for the coefficient matrix +C <- solve(FTF, FTY) + +test <- Matrix(as.matrix(read.csv("/data/test-mnist-dense-with-labels.data", + header=F), sparse=T)) +testData <- test[,-1] +testLabels <- matrix(ncol=1, test[,1]) + +err <- 0 + +# contstruct the feature maps for all examples from this digit +featuresTest <- cos(testData %*% t(value(broadcastW)) + + (matrix(ncol=1, data=rep(1, NTest)) %*% t(value(broadcastB)))) +featuresTest <- cBind(featuresTest, Matrix(rep(1, NTest))) + +# extract the one vs. all assignment +results <- featuresTest %*% C +labelsGot <- apply(results, 1, which.max) +err <- sum(testLabels != labelsGot) / nrow(testLabels) + +cat("\nFinished running. The error rate is: ", err, ".\n") diff --git a/examples/src/main/r/logistic_regression.R b/examples/src/main/r/logistic_regression.R new file mode 100644 index 0000000000000..2a86aa98160d3 --- /dev/null +++ b/examples/src/main/r/logistic_regression.R @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 3) { + print("Usage: logistic_regression ") + q("no") +} + +# Initialize Spark context +sc <- sparkR.init(appName = "LogisticRegressionR") +iterations <- as.integer(args[[2]]) +D <- as.integer(args[[3]]) + +readPartition <- function(part){ + part = strsplit(part, " ", fixed = T) + list(matrix(as.numeric(unlist(part)), ncol = length(part[[1]]))) +} + +# Read data points and convert each partition to a matrix +points <- cache(lapplyPartition(textFile(sc, args[[1]]), readPartition)) + +# Initialize w to a random value +w <- runif(n=D, min = -1, max = 1) +cat("Initial w: ", w, "\n") + +# Compute logistic regression gradient for a matrix of data points +gradient <- function(partition) { + partition = partition[[1]] + Y <- partition[, 1] # point labels (first column of input file) + X <- partition[, -1] # point coordinates + + # For each point (x, y), compute gradient function + dot <- X %*% w + logit <- 1 / (1 + exp(-Y * dot)) + grad <- t(X) %*% ((logit - 1) * Y) + list(grad) +} + +for (i in 1:iterations) { + cat("On iteration ", i, "\n") + w <- w - reduce(lapplyPartition(points, gradient), "+") +} + +cat("Final w: ", w, "\n") diff --git a/examples/src/main/r/pi.R b/examples/src/main/r/pi.R new file mode 100644 index 0000000000000..aa7a833e147a0 --- /dev/null +++ b/examples/src/main/r/pi.R @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +sc <- sparkR.init(appName = "PiR") + +slices <- ifelse(length(args) > 1, as.integer(args[[2]]), 2) + +n <- 100000 * slices + +piFunc <- function(elem) { + rands <- runif(n = 2, min = -1, max = 1) + val <- ifelse((rands[1]^2 + rands[2]^2) < 1, 1.0, 0.0) + val +} + + +piFuncVec <- function(elems) { + message(length(elems)) + rands1 <- runif(n = length(elems), min = -1, max = 1) + rands2 <- runif(n = length(elems), min = -1, max = 1) + val <- ifelse((rands1^2 + rands2^2) < 1, 1.0, 0.0) + sum(val) +} + +rdd <- parallelize(sc, 1:n, slices) +count <- reduce(lapplyPartition(rdd, piFuncVec), sum) +cat("Pi is roughly", 4.0 * count / n, "\n") +cat("Num elements in RDD ", count(rdd), "\n") diff --git a/examples/src/main/r/wordcount.R b/examples/src/main/r/wordcount.R new file mode 100644 index 0000000000000..b734cb0ecf55b --- /dev/null +++ b/examples/src/main/r/wordcount.R @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 1) { + print("Usage: wordcount ") + q("no") +} + +# Initialize Spark context +sc <- sparkR.init(appName = "RwordCount") +lines <- textFile(sc, args[[1]]) + +words <- flatMap(lines, + function(line) { + strsplit(line, " ")[[1]] + }) +wordCount <- lapply(words, function(word) { list(word, 1L) }) + +counts <- reduceByKey(wordCount, "+", 2L) +output <- collect(counts) + +for (wordcount in output) { + cat(wordcount[[1]], ": ", wordcount[[2]], "\n") +} diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala index 17624c20cff3d..f73eac1e2b906 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala @@ -40,8 +40,8 @@ object LocalKMeans { val convergeDist = 0.001 val rand = new Random(42) - def generateData = { - def generatePoint(i: Int) = { + def generateData: Array[DenseVector[Double]] = { + def generatePoint(i: Int): DenseVector[Double] = { DenseVector.fill(D){rand.nextDouble * R} } Array.tabulate(N)(generatePoint) diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index 92a683ad57ea1..a55e0dc8d36c2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -37,8 +37,8 @@ object LocalLR { case class DataPoint(x: Vector[Double], y: Double) - def generateData = { - def generatePoint(i: Int) = { + def generateData: Array[DataPoint] = { + def generatePoint(i: Int): DataPoint = { val y = if(i % 2 == 0) -1 else 1 val x = 
DenseVector.fill(D){rand.nextGaussian + y * R} DataPoint(x, y) diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala index 74620ad007d83..32e02eab8b031 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -54,8 +54,8 @@ object LogQuery { // scalastyle:on /** Tracks the total query count and number of aggregate bytes for a particular group. */ class Stats(val count: Int, val numBytes: Int) extends Serializable { - def merge(other: Stats) = new Stats(count + other.count, numBytes + other.numBytes) - override def toString = "bytes=%s\tn=%s".format(numBytes, count) + def merge(other: Stats): Stats = new Stats(count + other.count, numBytes + other.numBytes) + override def toString: String = "bytes=%s\tn=%s".format(numBytes, count) } def extractKey(line: String): (String, String, String) = { diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index 257a7d29f922a..8c01a60844620 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -42,8 +42,8 @@ object SparkLR { case class DataPoint(x: Vector[Double], y: Double) - def generateData = { - def generatePoint(i: Int) = { + def generateData: Array[DataPoint] = { + def generatePoint(i: Int): DataPoint = { val y = if(i % 2 == 0) -1 else 1 val x = DenseVector.fill(D){rand.nextGaussian + y * R} DataPoint(x, y) diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index f7f83086df3db..772cd897f5140 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -31,7 +31,7 @@ object SparkTC { val numVertices = 100 val rand = new Random(42) - def generateGraph = { + def generateGraph: Seq[(Int, Int)] = { val edges: mutable.Set[(Int, Int)] = mutable.Set.empty while (edges.size < numEdges) { val from = rand.nextInt(numVertices) diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala index e322d4ce5a745..ab6e63deb3c95 100644 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala @@ -90,7 +90,7 @@ class PRMessage() extends Message[String] with Serializable { } class CustomPartitioner(partitions: Int) extends Partitioner { - def numPartitions = partitions + def numPartitions: Int = partitions def getPartition(key: Any): Int = { val hash = key match { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index df26798e41b7b..2245fa429fda3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -99,7 +99,7 @@ private trait MyLogisticRegressionParams extends ClassifierParams { * class since the maxIter parameter is only used during training (not in the Model). 
*/ val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") - def getMaxIter: Int = get(maxIter) + def getMaxIter: Int = getOrDefault(maxIter) } /** @@ -174,11 +174,11 @@ private class MyLogisticRegressionModel( * Create a copy of the model. * The copy is shallow, except for the embedded paramMap, which gets a deep copy. * - * This is used for the defaul implementation of [[transform()]]. + * This is used for the default implementation of [[transform()]]. */ override protected def copy(): MyLogisticRegressionModel = { val m = new MyLogisticRegressionModel(parent, fittingParamMap, weights) - Params.inheritValues(this.paramMap, this, m) + Params.inheritValues(extractParamMap(), this, m) m } } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 1f4ca4fbe7778..0bc36ea65e1ab 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -178,7 +178,9 @@ object MovieLensALS { def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating], implicitPrefs: Boolean) : Double = { - def mapPredictedRating(r: Double) = if (implicitPrefs) math.max(math.min(r, 1.0), 0.0) else r + def mapPredictedRating(r: Double): Double = { + if (implicitPrefs) math.max(math.min(r, 1.0), 0.0) else r + } val predictions: RDD[Rating] = model.predict(data.map(x => (x.user, x.product))) val predictionsAndRatings = predictions.map{ x => diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala index 9f22d40c15f3f..6d8b806569dfd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala @@ -65,7 +65,7 @@ object PowerIterationClusteringExample { def main(args: Array[String]) { val defaultParams = Params() - val parser = new OptionParser[Params]("PIC Circles") { + val parser = new OptionParser[Params]("PowerIterationClusteringExample") { head("PowerIterationClusteringExample: an example PIC app using concentric circles.") opt[Int]('k', "k") .text(s"number of circles (/clusters), default: ${defaultParams.k}") @@ -76,9 +76,9 @@ object PowerIterationClusteringExample { opt[Int]("maxIterations") .text(s"number of iterations, default: ${defaultParams.maxIterations}") .action((x, c) => c.copy(maxIterations = x)) - opt[Int]('r', "r") + opt[Double]('r', "r") .text(s"radius of outermost circle, default: ${defaultParams.outerRadius}") - .action((x, c) => c.copy(numPoints = x)) + .action((x, c) => c.copy(outerRadius = x)) } parser.parse(args, defaultParams).map { params => @@ -154,3 +154,4 @@ object PowerIterationClusteringExample { coeff * math.exp(expCoeff * ssquares) } } + diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index b433082dce1a2..92867b44be138 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -85,13 +85,13 @@ extends Actor with ActorHelper { lazy private val remotePublisher = context.actorSelection(urlOfPublisher) - override 
def preStart = remotePublisher ! SubscribeReceiver(context.self) + override def preStart(): Unit = remotePublisher ! SubscribeReceiver(context.self) - def receive = { + def receive: PartialFunction[Any, Unit] = { case msg => store(msg.asInstanceOf[T]) } - override def postStop() = remotePublisher ! UnsubscribeReceiver(context.self) + override def postStop(): Unit = remotePublisher ! UnsubscribeReceiver(context.self) } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index c3a05c89d817e..751b30ea15782 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -55,7 +55,8 @@ import org.apache.spark.util.IntParam */ object RecoverableNetworkWordCount { - def createContext(ip: String, port: Int, outputPath: String, checkpointDirectory: String) = { + def createContext(ip: String, port: Int, outputPath: String, checkpointDirectory: String) + : StreamingContext = { // If you do not see this printed, that means the StreamingContext has been loaded // from the new checkpoint diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala index 6510c70bd1866..e99d1baa72b9f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala @@ -35,7 +35,7 @@ import org.apache.spark.SparkConf */ object SimpleZeroMQPublisher { - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { if (args.length < 2) { System.err.println("Usage: SimpleZeroMQPublisher ") System.exit(1) @@ -45,7 +45,7 @@ object SimpleZeroMQPublisher { val acs: ActorSystem = ActorSystem() val pubSocket = ZeroMQExtension(acs).newSocket(SocketType.Pub, Bind(url)) - implicit def stringToByteString(x: String) = ByteString(x) + implicit def stringToByteString(x: String): ByteString = ByteString(x) val messages: List[ByteString] = List("words ", "may ", "count ") while (true) { Thread.sleep(1000) @@ -86,7 +86,7 @@ object ZeroMQWordCount { // Create the context and set the batch size val ssc = new StreamingContext(sparkConf, Seconds(2)) - def bytesToStringIterator(x: Seq[ByteString]) = (x.map(_.utf8String)).iterator + def bytesToStringIterator(x: Seq[ByteString]): Iterator[String] = x.map(_.utf8String).iterator // For this stream, a zeroMQ publisher should be running. 
val lines = ZeroMQUtils.createStream(ssc, url, Subscribe(topic), bytesToStringIterator _) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 8402491b62671..54d996b8ac990 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -94,7 +94,7 @@ object PageViewGenerator { while (true) { val socket = listener.accept() new Thread() { - override def run = { + override def run(): Unit = { println("Got client connected from: " + socket.getInetAddress) val out = new PrintWriter(socket.getOutputStream(), true) diff --git a/external/flume-sink/src/test/resources/log4j.properties b/external/flume-sink/src/test/resources/log4j.properties index 2a58e99817224..42df8792f147f 100644 --- a/external/flume-sink/src/test/resources/log4j.properties +++ b/external/flume-sink/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 2de2a7926bfd1..60e2994431b38 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -37,8 +37,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.receiver.Receiver -import org.jboss.netty.channel.ChannelPipelineFactory -import org.jboss.netty.channel.Channels +import org.jboss.netty.channel.{ChannelPipeline, ChannelPipelineFactory, Channels} import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory import org.jboss.netty.handler.codec.compression._ @@ -187,8 +186,8 @@ class FlumeReceiver( logInfo("Flume receiver stopped") } - override def preferredLocation = Some(host) - + override def preferredLocation: Option[String] = Option(host) + /** A Netty Pipeline factory that will decompress incoming data from * and the Netty client and compress data going back to the client. 
* @@ -198,13 +197,12 @@ class FlumeReceiver( */ private[streaming] class CompressionChannelPipelineFactory extends ChannelPipelineFactory { - - def getPipeline() = { + def getPipeline(): ChannelPipeline = { val pipeline = Channels.pipeline() val encoder = new ZlibEncoder(6) pipeline.addFirst("deflater", encoder) pipeline.addFirst("inflater", new ZlibDecoder()) pipeline + } } } -} diff --git a/external/flume/src/test/resources/log4j.properties b/external/flume/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/external/flume/src/test/resources/log4j.properties +++ b/external/flume/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index e04d4088df7dc..2edea9b5b69ba 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -1,21 +1,20 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + package org.apache.spark.streaming.flume import java.net.InetSocketAddress @@ -213,7 +212,7 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging assert(counter === totalEventsPerChannel * channels.size) } - def assertChannelIsEmpty(channel: MemoryChannel) = { + def assertChannelIsEmpty(channel: MemoryChannel): Unit = { val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") queueRemaining.setAccessible(true) val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 51d273af8da84..39e6754c81dbf 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -151,7 +151,9 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L } /** Class to create socket channel with compression */ - private class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { + private class CompressionChannelFactory(compressionLevel: Int) + extends NioClientSocketChannelFactory { + override def newChannel(pipeline: ChannelPipeline): SocketChannel = { val encoder = new ZlibEncoder(compressionLevel) pipeline.addFirst("deflater", encoder) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 04e65cb3d708c..1b1fc8051d052 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -129,8 +129,9 @@ class DirectKafkaInputDStream[ private[streaming] class DirectKafkaInputDStreamCheckpointData extends DStreamCheckpointData(this) { - def batchForTime = data.asInstanceOf[mutable.HashMap[ - Time, Array[OffsetRange.OffsetRangeTuple]]] + def batchForTime: mutable.HashMap[Time, Array[(String, Int, Long, Long)]] = { + data.asInstanceOf[mutable.HashMap[Time, Array[OffsetRange.OffsetRangeTuple]]] + } override def update(time: Time) { batchForTime.clear() diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index 2f7e0ab39fefd..bd767031c1849 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -123,9 +123,17 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { val errs = new Err withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer => val resp: TopicMetadataResponse = consumer.send(req) - // error codes here indicate missing / just created topic, - // repeating on a different broker wont be useful - return Right(resp.topicsMetadata.toSet) + val respErrs = resp.topicsMetadata.filter(m => m.errorCode != ErrorMapping.NoError) + + if (respErrs.isEmpty) { + return Right(resp.topicsMetadata.toSet) + } else { + respErrs.foreach { m => + val cause = ErrorMapping.exceptionFor(m.errorCode) + val msg = s"Error getting partition metadata for '${m.topic}'. Does the topic exist?" 
+ errs.append(new SparkException(msg, cause)) + } + } } Left(errs) } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index 6d465bcb6bfc0..a1b4a12e5d6a0 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -23,10 +23,9 @@ import org.apache.spark.{Logging, Partition, SparkContext, SparkException, TaskC import org.apache.spark.rdd.RDD import org.apache.spark.util.NextIterator -import java.util.Properties import kafka.api.{FetchRequestBuilder, FetchResponse} import kafka.common.{ErrorMapping, TopicAndPartition} -import kafka.consumer.{ConsumerConfig, SimpleConsumer} +import kafka.consumer.SimpleConsumer import kafka.message.{MessageAndMetadata, MessageAndOffset} import kafka.serializer.Decoder import kafka.utils.VerifiableProperties @@ -86,7 +85,7 @@ class KafkaRDD[ val part = thePart.asInstanceOf[KafkaRDDPartition] assert(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part)) if (part.fromOffset == part.untilOffset) { - log.warn(s"Beginning offset ${part.fromOffset} is the same as ending offset " + + log.info(s"Beginning offset ${part.fromOffset} is the same as ending offset " + s"skipping ${part.topic} ${part.partition}") Iterator.empty } else { @@ -155,7 +154,7 @@ class KafkaRDD[ .dropWhile(_.offset < requestOffset) } - override def close() = consumer.close() + override def close(): Unit = consumer.close() override def getNext(): R = { if (iter == null || !iter.hasNext) { @@ -207,7 +206,7 @@ object KafkaRDD { fromOffsets: Map[TopicAndPartition, Long], untilOffsets: Map[TopicAndPartition, LeaderOffset], messageHandler: MessageAndMetadata[K, V] => R - ): KafkaRDD[K, V, U, T, R] = { + ): KafkaRDD[K, V, U, T, R] = { val leaders = untilOffsets.map { case (tp, lo) => tp -> (lo.host, lo.port) }.toMap diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala new file mode 100644 index 0000000000000..13e9475065979 --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.kafka + +import java.io.File +import java.lang.{Integer => JInt} +import java.net.InetSocketAddress +import java.util.{Map => JMap} +import java.util.Properties +import java.util.concurrent.TimeoutException + +import scala.annotation.tailrec +import scala.language.postfixOps +import scala.util.control.NonFatal + +import kafka.admin.AdminUtils +import kafka.producer.{KeyedMessage, Producer, ProducerConfig} +import kafka.serializer.StringEncoder +import kafka.server.{KafkaConfig, KafkaServer} +import kafka.utils.ZKStringSerializer +import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} +import org.I0Itec.zkclient.ZkClient + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.streaming.Time +import org.apache.spark.util.Utils + +/** + * This is a helper class for Kafka test suites. This has the functionality to set up + * and tear down local Kafka servers, and to push data using Kafka producers. + * + * The reason to put Kafka test utility class in src is to test Python related Kafka APIs. + */ +private class KafkaTestUtils extends Logging { + + // Zookeeper related configurations + private val zkHost = "localhost" + private var zkPort: Int = 0 + private val zkConnectionTimeout = 6000 + private val zkSessionTimeout = 6000 + + private var zookeeper: EmbeddedZookeeper = _ + + private var zkClient: ZkClient = _ + + // Kafka broker related configurations + private val brokerHost = "localhost" + private var brokerPort = 9092 + private var brokerConf: KafkaConfig = _ + + // Kafka broker server + private var server: KafkaServer = _ + + // Kafka producer + private var producer: Producer[String, String] = _ + + // Flag to test whether the system is correctly started + private var zkReady = false + private var brokerReady = false + + def zkAddress: String = { + assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address") + s"$zkHost:$zkPort" + } + + def brokerAddress: String = { + assert(brokerReady, "Kafka not setup yet or already torn down, cannot get broker address") + s"$brokerHost:$brokerPort" + } + + def zookeeperClient: ZkClient = { + assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper client") + Option(zkClient).getOrElse( + throw new IllegalStateException("Zookeeper client is not yet initialized")) + } + + // Set up the Embedded Zookeeper server and get the proper Zookeeper port + private def setupEmbeddedZookeeper(): Unit = { + // Zookeeper server startup + zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort") + // Get the actual zookeeper binding port + zkPort = zookeeper.actualPort + zkClient = new ZkClient(s"$zkHost:$zkPort", zkSessionTimeout, zkConnectionTimeout, + ZKStringSerializer) + zkReady = true + } + + // Set up the Embedded Kafka server + private def setupEmbeddedKafkaServer(): Unit = { + assert(zkReady, "Zookeeper should be set up beforehand") + + // Kafka broker startup + Utils.startServiceOnPort(brokerPort, port => { + brokerPort = port + brokerConf = new KafkaConfig(brokerConfiguration) + server = new KafkaServer(brokerConf) + server.startup() + (server, port) + }, new SparkConf(), "KafkaBroker") + + brokerReady = true + } + + /** setup the whole embedded servers, including Zookeeper and Kafka brokers */ + def setup(): Unit = { + setupEmbeddedZookeeper() + setupEmbeddedKafkaServer() + } + + /** Teardown the whole servers, including Kafka broker and Zookeeper */ + def teardown(): Unit = { + brokerReady = false + zkReady = false + + 
if (producer != null) { + producer.close() + producer = null + } + + if (server != null) { + server.shutdown() + server = null + } + + brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } + + if (zkClient != null) { + zkClient.close() + zkClient = null + } + + if (zookeeper != null) { + zookeeper.shutdown() + zookeeper = null + } + } + + /** Create a Kafka topic and wait until it propagated to the whole cluster */ + def createTopic(topic: String): Unit = { + AdminUtils.createTopic(zkClient, topic, 1, 1) + // wait until metadata is propagated + waitUntilMetadataIsPropagated(topic, 0) + } + + /** Java-friendly function for sending messages to the Kafka broker */ + def sendMessages(topic: String, messageToFreq: JMap[String, JInt]): Unit = { + import scala.collection.JavaConversions._ + sendMessages(topic, Map(messageToFreq.mapValues(_.intValue()).toSeq: _*)) + } + + /** Send the messages to the Kafka broker */ + def sendMessages(topic: String, messageToFreq: Map[String, Int]): Unit = { + val messages = messageToFreq.flatMap { case (s, freq) => Seq.fill(freq)(s) }.toArray + sendMessages(topic, messages) + } + + /** Send the array of messages to the Kafka broker */ + def sendMessages(topic: String, messages: Array[String]): Unit = { + producer = new Producer[String, String](new ProducerConfig(producerConfiguration)) + producer.send(messages.map { new KeyedMessage[String, String](topic, _ ) }: _*) + producer.close() + producer = null + } + + private def brokerConfiguration: Properties = { + val props = new Properties() + props.put("broker.id", "0") + props.put("host.name", "localhost") + props.put("port", brokerPort.toString) + props.put("log.dir", Utils.createTempDir().getAbsolutePath) + props.put("zookeeper.connect", zkAddress) + props.put("log.flush.interval.messages", "1") + props.put("replica.socket.timeout.ms", "1500") + props + } + + private def producerConfiguration: Properties = { + val props = new Properties() + props.put("metadata.broker.list", brokerAddress) + props.put("serializer.class", classOf[StringEncoder].getName) + props + } + + // A simplified version of scalatest eventually, rewritten here to avoid adding extra test + // dependency + def eventually[T](timeout: Time, interval: Time)(func: => T): T = { + def makeAttempt(): Either[Throwable, T] = { + try { + Right(func) + } catch { + case e if NonFatal(e) => Left(e) + } + } + + val startTime = System.currentTimeMillis() + @tailrec + def tryAgain(attempt: Int): T = { + makeAttempt() match { + case Right(result) => result + case Left(e) => + val duration = System.currentTimeMillis() - startTime + if (duration < timeout.milliseconds) { + Thread.sleep(interval.milliseconds) + } else { + throw new TimeoutException(e.getMessage) + } + + tryAgain(attempt + 1) + } + } + + tryAgain(1) + } + + private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = { + eventually(Time(10000), Time(100)) { + assert( + server.apis.metadataCache.containsTopicAndPartition(topic, partition), + s"Partition [$topic, $partition] metadata not propagated after timeout" + ) + } + } + + private class EmbeddedZookeeper(val zkConnect: String) { + val snapshotDir = Utils.createTempDir() + val logDir = Utils.createTempDir() + + val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500) + val (ip, port) = { + val splits = zkConnect.split(":") + (splits(0), splits(1).toInt) + } + val factory = new NIOServerCnxnFactory() + factory.configure(new InetSocketAddress(ip, port), 16) + factory.startup(zookeeper) + + val actualPort 
= factory.getLocalPort + + def shutdown() { + factory.shutdown() + Utils.deleteRecursively(snapshotDir) + Utils.deleteRecursively(logDir) + } + } +} + diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java index d6ca6d58b5665..4c1d6a03eb2b8 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java @@ -41,24 +41,28 @@ public class JavaDirectKafkaStreamSuite implements Serializable { private transient JavaStreamingContext ssc = null; - private transient KafkaStreamSuiteBase suiteBase = null; + private transient KafkaTestUtils kafkaTestUtils = null; @Before public void setUp() { - suiteBase = new KafkaStreamSuiteBase() { }; - suiteBase.setupKafka(); - System.clearProperty("spark.driver.port"); - SparkConf sparkConf = new SparkConf() - .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); - ssc = new JavaStreamingContext(sparkConf, Durations.milliseconds(200)); + kafkaTestUtils = new KafkaTestUtils(); + kafkaTestUtils.setup(); + SparkConf sparkConf = new SparkConf() + .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); + ssc = new JavaStreamingContext(sparkConf, Durations.milliseconds(200)); } @After public void tearDown() { + if (ssc != null) { ssc.stop(); ssc = null; - System.clearProperty("spark.driver.port"); - suiteBase.tearDownKafka(); + } + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown(); + kafkaTestUtils = null; + } } @Test @@ -74,7 +78,7 @@ public void testKafkaStream() throws InterruptedException { sent.addAll(Arrays.asList(topic2data)); HashMap kafkaParams = new HashMap(); - kafkaParams.put("metadata.broker.list", suiteBase.brokerAddress()); + kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); kafkaParams.put("auto.offset.reset", "smallest"); JavaDStream stream1 = KafkaUtils.createDirectStream( @@ -147,8 +151,8 @@ private HashMap topicOffsetToMap(String topic, Long off private String[] createTopicAndSendData(String topic) { String[] data = { topic + "-1", topic + "-2", topic + "-3"}; - suiteBase.createTopic(topic); - suiteBase.sendMessages(topic, data); + kafkaTestUtils.createTopic(topic); + kafkaTestUtils.sendMessages(topic, data); return data; } } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java index 4477b81827c70..a9dc6e50613ca 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java @@ -37,13 +37,12 @@ public class JavaKafkaRDDSuite implements Serializable { private transient JavaSparkContext sc = null; - private transient KafkaStreamSuiteBase suiteBase = null; + private transient KafkaTestUtils kafkaTestUtils = null; @Before public void setUp() { - suiteBase = new KafkaStreamSuiteBase() { }; - suiteBase.setupKafka(); - System.clearProperty("spark.driver.port"); + kafkaTestUtils = new KafkaTestUtils(); + kafkaTestUtils.setup(); SparkConf sparkConf = new SparkConf() .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); sc = new JavaSparkContext(sparkConf); @@ -51,10 +50,15 @@ public void setUp() { @After public void tearDown() { - sc.stop(); - sc = 
null; - System.clearProperty("spark.driver.port"); - suiteBase.tearDownKafka(); + if (sc != null) { + sc.stop(); + sc = null; + } + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown(); + kafkaTestUtils = null; + } } @Test @@ -66,7 +70,7 @@ public void testKafkaRDD() throws InterruptedException { String[] topic2data = createTopicAndSendData(topic2); HashMap kafkaParams = new HashMap(); - kafkaParams.put("metadata.broker.list", suiteBase.brokerAddress()); + kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); OffsetRange[] offsetRanges = { OffsetRange.create(topic1, 0, 0, 1), @@ -75,7 +79,7 @@ public void testKafkaRDD() throws InterruptedException { HashMap emptyLeaders = new HashMap(); HashMap leaders = new HashMap(); - String[] hostAndPort = suiteBase.brokerAddress().split(":"); + String[] hostAndPort = kafkaTestUtils.brokerAddress().split(":"); Broker broker = Broker.create(hostAndPort[0], Integer.parseInt(hostAndPort[1])); leaders.put(new TopicAndPartition(topic1, 0), broker); leaders.put(new TopicAndPartition(topic2, 0), broker); @@ -144,8 +148,8 @@ public String call(MessageAndMetadata msgAndMd) throws Exception private String[] createTopicAndSendData(String topic) { String[] data = { topic + "-1", topic + "-2", topic + "-3"}; - suiteBase.createTopic(topic); - suiteBase.sendMessages(topic, data); + kafkaTestUtils.createTopic(topic); + kafkaTestUtils.sendMessages(topic, data); return data; } } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index bad0a93eb2e84..540f4ceabab47 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -22,9 +22,7 @@ import java.util.List; import java.util.Random; -import scala.Predef; import scala.Tuple2; -import scala.collection.JavaConverters; import kafka.serializer.StringDecoder; import org.junit.After; @@ -44,13 +42,12 @@ public class JavaKafkaStreamSuite implements Serializable { private transient JavaStreamingContext ssc = null; private transient Random random = new Random(); - private transient KafkaStreamSuiteBase suiteBase = null; + private transient KafkaTestUtils kafkaTestUtils = null; @Before public void setUp() { - suiteBase = new KafkaStreamSuiteBase() { }; - suiteBase.setupKafka(); - System.clearProperty("spark.driver.port"); + kafkaTestUtils = new KafkaTestUtils(); + kafkaTestUtils.setup(); SparkConf sparkConf = new SparkConf() .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); ssc = new JavaStreamingContext(sparkConf, new Duration(500)); @@ -58,10 +55,15 @@ public void setUp() { @After public void tearDown() { - ssc.stop(); - ssc = null; - System.clearProperty("spark.driver.port"); - suiteBase.tearDownKafka(); + if (ssc != null) { + ssc.stop(); + ssc = null; + } + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown(); + kafkaTestUtils = null; + } } @Test @@ -75,15 +77,11 @@ public void testKafkaStream() throws InterruptedException { sent.put("b", 3); sent.put("c", 10); - suiteBase.createTopic(topic); - HashMap tmp = new HashMap(sent); - suiteBase.sendMessages(topic, - JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap( - Predef.>conforms()) - ); + kafkaTestUtils.createTopic(topic); + kafkaTestUtils.sendMessages(topic, sent); HashMap kafkaParams = new HashMap(); - kafkaParams.put("zookeeper.connect", 
suiteBase.zkAddress()); + kafkaParams.put("zookeeper.connect", kafkaTestUtils.zkAddress()); kafkaParams.put("group.id", "test-consumer-" + random.nextInt(10000)); kafkaParams.put("auto.offset.reset", "smallest"); @@ -126,6 +124,7 @@ public Void call(JavaPairRDD rdd) throws Exception { ); ssc.start(); + long startTime = System.currentTimeMillis(); boolean sizeMatches = false; while (!sizeMatches && System.currentTimeMillis() - startTime < 20000) { @@ -136,6 +135,5 @@ public Void call(JavaPairRDD rdd) throws Exception { for (String k : sent.keySet()) { Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue()); } - ssc.stop(); } } diff --git a/external/kafka/src/test/resources/log4j.properties b/external/kafka/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/external/kafka/src/test/resources/log4j.properties +++ b/external/kafka/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 17ca9d145d665..415730f5559c5 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -27,31 +27,41 @@ import scala.language.postfixOps import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata import kafka.serializer.StringDecoder -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.Utils -class DirectKafkaStreamSuite extends KafkaStreamSuiteBase - with BeforeAndAfter with BeforeAndAfterAll with Eventually { +class DirectKafkaStreamSuite + extends FunSuite + with BeforeAndAfter + with BeforeAndAfterAll + with Eventually + with Logging { val sparkConf = new SparkConf() .setMaster("local[4]") .setAppName(this.getClass.getSimpleName) - var sc: SparkContext = _ - var ssc: StreamingContext = _ - var testDir: File = _ + private var sc: SparkContext = _ + private var ssc: StreamingContext = _ + private var testDir: File = _ + + private var kafkaTestUtils: KafkaTestUtils = _ override def beforeAll { - setupKafka() + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() } override def afterAll { - tearDownKafka() + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } } after { @@ -72,12 +82,12 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase val topics = Set("basic1", "basic2", "basic3") val data = Map("a" -> 7, "b" -> 9) topics.foreach { t => - createTopic(t) - sendMessages(t, data) + kafkaTestUtils.createTopic(t) + kafkaTestUtils.sendMessages(t, data) } val totalSent = data.values.sum * topics.size val kafkaParams = Map( - "metadata.broker.list" -> 
s"$brokerAddress", + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, "auto.offset.reset" -> "smallest" ) @@ -121,9 +131,9 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase val topic = "largest" val topicPartition = TopicAndPartition(topic, 0) val data = Map("a" -> 10) - createTopic(topic) + kafkaTestUtils.createTopic(topic) val kafkaParams = Map( - "metadata.broker.list" -> s"$brokerAddress", + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, "auto.offset.reset" -> "largest" ) val kc = new KafkaCluster(kafkaParams) @@ -132,7 +142,7 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase } // Send some initial messages before starting context - sendMessages(topic, data) + kafkaTestUtils.sendMessages(topic, data) eventually(timeout(10 seconds), interval(20 milliseconds)) { assert(getLatestOffset() > 3) } @@ -154,7 +164,7 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase stream.map { _._2 }.foreachRDD { rdd => collectedData ++= rdd.collect() } ssc.start() val newData = Map("b" -> 10) - sendMessages(topic, newData) + kafkaTestUtils.sendMessages(topic, newData) eventually(timeout(10 seconds), interval(50 milliseconds)) { collectedData.contains("b") } @@ -166,9 +176,9 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase val topic = "offset" val topicPartition = TopicAndPartition(topic, 0) val data = Map("a" -> 10) - createTopic(topic) + kafkaTestUtils.createTopic(topic) val kafkaParams = Map( - "metadata.broker.list" -> s"$brokerAddress", + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, "auto.offset.reset" -> "largest" ) val kc = new KafkaCluster(kafkaParams) @@ -177,7 +187,7 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase } // Send some initial messages before starting context - sendMessages(topic, data) + kafkaTestUtils.sendMessages(topic, data) eventually(timeout(10 seconds), interval(20 milliseconds)) { assert(getLatestOffset() >= 10) } @@ -200,7 +210,7 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase stream.foreachRDD { rdd => collectedData ++= rdd.collect() } ssc.start() val newData = Map("b" -> 10) - sendMessages(topic, newData) + kafkaTestUtils.sendMessages(topic, newData) eventually(timeout(10 seconds), interval(50 milliseconds)) { collectedData.contains("b") } @@ -210,18 +220,18 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase // Test to verify the offset ranges can be recovered from the checkpoints test("offset recovery") { val topic = "recovery" - createTopic(topic) + kafkaTestUtils.createTopic(topic) testDir = Utils.createTempDir() val kafkaParams = Map( - "metadata.broker.list" -> s"$brokerAddress", + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, "auto.offset.reset" -> "smallest" ) // Send data to Kafka and wait for it to be received def sendDataAndWaitForReceive(data: Seq[Int]) { val strings = data.map { _.toString} - sendMessages(topic, strings.map { _ -> 1}.toMap) + kafkaTestUtils.sendMessages(topic, strings.map { _ -> 1}.toMap) eventually(timeout(10 seconds), interval(50 milliseconds)) { assert(strings.forall { DirectKafkaStreamSuite.collectedData.contains }) } diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala index fc9275b7207be..7fb841b79cb65 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala @@ 
-20,31 +20,41 @@ package org.apache.spark.streaming.kafka import scala.util.Random import kafka.common.TopicAndPartition -import org.scalatest.BeforeAndAfterAll +import org.scalatest.{BeforeAndAfterAll, FunSuite} -class KafkaClusterSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll { - val topic = "kcsuitetopic" + Random.nextInt(10000) - val topicAndPartition = TopicAndPartition(topic, 0) - var kc: KafkaCluster = null +class KafkaClusterSuite extends FunSuite with BeforeAndAfterAll { + private val topic = "kcsuitetopic" + Random.nextInt(10000) + private val topicAndPartition = TopicAndPartition(topic, 0) + private var kc: KafkaCluster = null + + private var kafkaTestUtils: KafkaTestUtils = _ override def beforeAll() { - setupKafka() - createTopic(topic) - sendMessages(topic, Map("a" -> 1)) - kc = new KafkaCluster(Map("metadata.broker.list" -> s"$brokerAddress")) + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() + + kafkaTestUtils.createTopic(topic) + kafkaTestUtils.sendMessages(topic, Map("a" -> 1)) + kc = new KafkaCluster(Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress)) } override def afterAll() { - tearDownKafka() + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } } test("metadata apis") { val leader = kc.findLeaders(Set(topicAndPartition)).right.get(topicAndPartition) val leaderAddress = s"${leader._1}:${leader._2}" - assert(leaderAddress === brokerAddress, "didn't get leader") + assert(leaderAddress === kafkaTestUtils.brokerAddress, "didn't get leader") val parts = kc.getPartitions(Set(topic)).right.get assert(parts(topicAndPartition), "didn't get partitions") + + val err = kc.getPartitions(Set(topic + "BAD")) + assert(err.isLeft, "getPartitions for a nonexistant topic should be an error") } test("leader offset apis") { diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala index a223da70b043f..7d26ce50875b3 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -22,18 +22,22 @@ import scala.util.Random import kafka.serializer.StringDecoder import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata -import org.scalatest.BeforeAndAfterAll +import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.apache.spark._ -import org.apache.spark.SparkContext._ -class KafkaRDDSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll { - val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName) - var sc: SparkContext = _ +class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { + + private var kafkaTestUtils: KafkaTestUtils = _ + + private val sparkConf = new SparkConf().setMaster("local[4]") + .setAppName(this.getClass.getSimpleName) + private var sc: SparkContext = _ + override def beforeAll { sc = new SparkContext(sparkConf) - - setupKafka() + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() } override def afterAll { @@ -41,17 +45,21 @@ class KafkaRDDSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll { sc.stop sc = null } - tearDownKafka() + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } } test("basic usage") { val topic = "topicbasic" - createTopic(topic) + kafkaTestUtils.createTopic(topic) val messages = Set("the", "quick", "brown", "fox") - 
sendMessages(topic, messages.toArray) + kafkaTestUtils.sendMessages(topic, messages.toArray) - val kafkaParams = Map("metadata.broker.list" -> brokerAddress, + val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress, "group.id" -> s"test-consumer-${Random.nextInt(10000)}") val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size)) @@ -67,15 +75,15 @@ class KafkaRDDSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll { // the idea is to find e.g. off-by-one errors between what kafka has available and the rdd val topic = "topic1" val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) - createTopic(topic) + kafkaTestUtils.createTopic(topic) - val kafkaParams = Map("metadata.broker.list" -> brokerAddress, + val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress, "group.id" -> s"test-consumer-${Random.nextInt(10000)}") val kc = new KafkaCluster(kafkaParams) // this is the "lots of messages" case - sendMessages(topic, sent) + kafkaTestUtils.sendMessages(topic, sent) // rdd defined from leaders after sending messages, should get the number sent val rdd = getRdd(kc, Set(topic)) @@ -92,14 +100,14 @@ class KafkaRDDSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll { // shouldn't get anything, since message is sent after rdd was defined val sentOnlyOne = Map("d" -> 1) - sendMessages(topic, sentOnlyOne) + kafkaTestUtils.sendMessages(topic, sentOnlyOne) assert(rdd2.isDefined) assert(rdd2.get.count === 0, "got messages when there shouldn't be any") // this is the "exactly 1 message" case, namely the single message from sentOnlyOne above val rdd3 = getRdd(kc, Set(topic)) // send lots of messages after rdd was defined, they shouldn't show up - sendMessages(topic, Map("extra" -> 22)) + kafkaTestUtils.sendMessages(topic, Map("extra" -> 22)) assert(rdd3.isDefined) assert(rdd3.get.count === sentOnlyOne.values.sum, "didn't get exactly one message") diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index e4966eebb9b34..24699dfc33adb 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -17,209 +17,38 @@ package org.apache.spark.streaming.kafka -import java.io.File -import java.net.InetSocketAddress -import java.util.Properties - import scala.collection.mutable import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import kafka.admin.AdminUtils -import kafka.common.{KafkaException, TopicAndPartition} -import kafka.producer.{KeyedMessage, Producer, ProducerConfig} -import kafka.serializer.{StringDecoder, StringEncoder} -import kafka.server.{KafkaConfig, KafkaServer} -import kafka.utils.ZKStringSerializer -import org.I0Itec.zkclient.ZkClient -import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} -import org.scalatest.{BeforeAndAfter, FunSuite} +import kafka.serializer.StringDecoder +import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.scalatest.concurrent.Eventually -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.SparkConf import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} -import org.apache.spark.util.Utils - -/** - * This is an abstract base class for Kafka testsuites. 
This has the functionality to set up - * and tear down local Kafka servers, and to push data using Kafka producers. - */ -abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Logging { - - private val zkHost = "localhost" - private var zkPort: Int = 0 - private val zkConnectionTimeout = 6000 - private val zkSessionTimeout = 6000 - private var zookeeper: EmbeddedZookeeper = _ - private val brokerHost = "localhost" - private var brokerPort = 9092 - private var brokerConf: KafkaConfig = _ - private var server: KafkaServer = _ - private var producer: Producer[String, String] = _ - private var zkReady = false - private var brokerReady = false - - protected var zkClient: ZkClient = _ - - def zkAddress: String = { - assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address") - s"$zkHost:$zkPort" - } - def brokerAddress: String = { - assert(brokerReady, "Kafka not setup yet or already torn down, cannot get broker address") - s"$brokerHost:$brokerPort" - } - - def setupKafka() { - // Zookeeper server startup - zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort") - // Get the actual zookeeper binding port - zkPort = zookeeper.actualPort - zkReady = true - logInfo("==================== Zookeeper Started ====================") +class KafkaStreamSuite extends FunSuite with Eventually with BeforeAndAfterAll { + private var ssc: StreamingContext = _ + private var kafkaTestUtils: KafkaTestUtils = _ - zkClient = new ZkClient(zkAddress, zkSessionTimeout, zkConnectionTimeout, ZKStringSerializer) - logInfo("==================== Zookeeper Client Created ====================") - - // Kafka broker startup - var bindSuccess: Boolean = false - while(!bindSuccess) { - try { - val brokerProps = getBrokerConfig() - brokerConf = new KafkaConfig(brokerProps) - server = new KafkaServer(brokerConf) - server.startup() - logInfo("==================== Kafka Broker Started ====================") - bindSuccess = true - } catch { - case e: KafkaException => - if (e.getMessage != null && e.getMessage.contains("Socket server failed to bind to")) { - brokerPort += 1 - } - case e: Exception => throw new Exception("Kafka server create failed", e) - } - } - - Thread.sleep(2000) - logInfo("==================== Kafka + Zookeeper Ready ====================") - brokerReady = true + override def beforeAll(): Unit = { + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() } - def tearDownKafka() { - brokerReady = false - zkReady = false - if (producer != null) { - producer.close() - producer = null - } - - if (server != null) { - server.shutdown() - server = null - } - - brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } - - if (zkClient != null) { - zkClient.close() - zkClient = null - } - - if (zookeeper != null) { - zookeeper.shutdown() - zookeeper = null - } - } - - def createTopic(topic: String) { - AdminUtils.createTopic(zkClient, topic, 1, 1) - // wait until metadata is propagated - waitUntilMetadataIsPropagated(topic, 0) - logInfo(s"==================== Topic $topic Created ====================") - } - - def sendMessages(topic: String, messageToFreq: Map[String, Int]) { - val messages = messageToFreq.flatMap { case (s, freq) => Seq.fill(freq)(s) }.toArray - sendMessages(topic, messages) - } - - def sendMessages(topic: String, messages: Array[String]) { - producer = new Producer[String, String](new ProducerConfig(getProducerConfig())) - producer.send(messages.map { new KeyedMessage[String, String](topic, _ ) }: _*) - producer.close() - 
logInfo(s"==================== Sent Messages: ${messages.mkString(", ")} ====================") - } - - private def getBrokerConfig(): Properties = { - val props = new Properties() - props.put("broker.id", "0") - props.put("host.name", "localhost") - props.put("port", brokerPort.toString) - props.put("log.dir", Utils.createTempDir().getAbsolutePath) - props.put("zookeeper.connect", zkAddress) - props.put("log.flush.interval.messages", "1") - props.put("replica.socket.timeout.ms", "1500") - props - } - - private def getProducerConfig(): Properties = { - val brokerAddr = brokerConf.hostName + ":" + brokerConf.port - val props = new Properties() - props.put("metadata.broker.list", brokerAddr) - props.put("serializer.class", classOf[StringEncoder].getName) - props - } - - private def waitUntilMetadataIsPropagated(topic: String, partition: Int) { - eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { - assert( - server.apis.metadataCache.containsTopicAndPartition(topic, partition), - s"Partition [$topic, $partition] metadata not propagated after timeout" - ) - } - } - - class EmbeddedZookeeper(val zkConnect: String) { - val random = new Random() - val snapshotDir = Utils.createTempDir() - val logDir = Utils.createTempDir() - - val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500) - val (ip, port) = { - val splits = zkConnect.split(":") - (splits(0), splits(1).toInt) - } - val factory = new NIOServerCnxnFactory() - factory.configure(new InetSocketAddress(ip, port), 16) - factory.startup(zookeeper) - - val actualPort = factory.getLocalPort - - def shutdown() { - factory.shutdown() - Utils.deleteRecursively(snapshotDir) - Utils.deleteRecursively(logDir) - } - } -} - - -class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter { - var ssc: StreamingContext = _ - - before { - setupKafka() - } - - after { + override def afterAll(): Unit = { if (ssc != null) { ssc.stop() ssc = null } - tearDownKafka() + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } } test("Kafka input stream") { @@ -227,10 +56,10 @@ class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter { ssc = new StreamingContext(sparkConf, Milliseconds(500)) val topic = "topic1" val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) - createTopic(topic) - sendMessages(topic, sent) + kafkaTestUtils.createTopic(topic) + kafkaTestUtils.sendMessages(topic, sent) - val kafkaParams = Map("zookeeper.connect" -> zkAddress, + val kafkaParams = Map("zookeeper.connect" -> kafkaTestUtils.zkAddress, "group.id" -> s"test-consumer-${Random.nextInt(10000)}", "auto.offset.reset" -> "smallest") @@ -244,14 +73,14 @@ class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter { result.put(kv._1, count) } } + ssc.start() + eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { assert(sent.size === result.size) sent.keys.foreach { k => assert(sent(k) === result(k).toInt) } } - ssc.stop() } } - diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala index 3cd960d1fd1d4..38548dd73b82c 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.streaming.kafka - import java.io.File import scala.collection.mutable @@ -27,7 +26,7 
@@ import scala.util.Random import kafka.serializer.StringDecoder import kafka.utils.{ZKGroupTopicDirs, ZkUtils} -import org.scalatest.BeforeAndAfter +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.scalatest.concurrent.Eventually import org.apache.spark.SparkConf @@ -35,47 +34,61 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.util.Utils -class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter with Eventually { +class ReliableKafkaStreamSuite extends FunSuite + with BeforeAndAfterAll with BeforeAndAfter with Eventually { - val sparkConf = new SparkConf() + private val sparkConf = new SparkConf() .setMaster("local[4]") .setAppName(this.getClass.getSimpleName) .set("spark.streaming.receiver.writeAheadLog.enable", "true") - val data = Map("a" -> 10, "b" -> 10, "c" -> 10) + private val data = Map("a" -> 10, "b" -> 10, "c" -> 10) + private var kafkaTestUtils: KafkaTestUtils = _ - var groupId: String = _ - var kafkaParams: Map[String, String] = _ - var ssc: StreamingContext = _ - var tempDirectory: File = null + private var groupId: String = _ + private var kafkaParams: Map[String, String] = _ + private var ssc: StreamingContext = _ + private var tempDirectory: File = null + + override def beforeAll() : Unit = { + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() - before { - setupKafka() groupId = s"test-consumer-${Random.nextInt(10000)}" kafkaParams = Map( - "zookeeper.connect" -> zkAddress, + "zookeeper.connect" -> kafkaTestUtils.zkAddress, "group.id" -> groupId, "auto.offset.reset" -> "smallest" ) - ssc = new StreamingContext(sparkConf, Milliseconds(500)) tempDirectory = Utils.createTempDir() + } + + override def afterAll(): Unit = { + Utils.deleteRecursively(tempDirectory) + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } + + before { + ssc = new StreamingContext(sparkConf, Milliseconds(500)) ssc.checkpoint(tempDirectory.getAbsolutePath) } after { if (ssc != null) { ssc.stop() + ssc = null } - Utils.deleteRecursively(tempDirectory) - tearDownKafka() } - test("Reliable Kafka input stream with single topic") { - var topic = "test-topic" - createTopic(topic) - sendMessages(topic, data) + val topic = "test-topic" + kafkaTestUtils.createTopic(topic) + kafkaTestUtils.sendMessages(topic, data) // Verify whether the offset of this group/topic/partition is 0 before starting. assert(getCommitOffset(groupId, topic, 0) === None) @@ -91,6 +104,7 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter } } ssc.start() + eventually(timeout(20000 milliseconds), interval(200 milliseconds)) { // A basic process verification for ReliableKafkaReceiver. // Verify whether received message number is equal to the sent message number. @@ -100,14 +114,13 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter // Verify the offset number whether it is equal to the total message number. assert(getCommitOffset(groupId, topic, 0) === Some(29L)) } - ssc.stop() } test("Reliable Kafka input stream with multiple topics") { val topics = Map("topic1" -> 1, "topic2" -> 1, "topic3" -> 1) topics.foreach { case (t, _) => - createTopic(t) - sendMessages(t, data) + kafkaTestUtils.createTopic(t) + kafkaTestUtils.sendMessages(t, data) } // Before started, verify all the group/topic/partition offsets are 0. 
@@ -118,19 +131,18 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter ssc, kafkaParams, topics, StorageLevel.MEMORY_ONLY) stream.foreachRDD(_ => Unit) ssc.start() + eventually(timeout(20000 milliseconds), interval(100 milliseconds)) { // Verify the offset for each group/topic to see whether they are equal to the expected one. topics.foreach { case (t, _) => assert(getCommitOffset(groupId, t, 0) === Some(29L)) } } - ssc.stop() } /** Getting partition offset from Zookeeper. */ private def getCommitOffset(groupId: String, topic: String, partition: Int): Option[Long] = { - assert(zkClient != null, "Zookeeper client is not initialized") val topicDirs = new ZKGroupTopicDirs(groupId, topic) val zkPath = s"${topicDirs.consumerOffsetDir}/$partition" - ZkUtils.readDataMaybeNull(zkClient, zkPath)._1.map(_.toLong) + ZkUtils.readDataMaybeNull(kafkaTestUtils.zookeeperClient, zkPath)._1.map(_.toLong) } } diff --git a/external/mqtt/src/test/resources/log4j.properties b/external/mqtt/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/external/mqtt/src/test/resources/log4j.properties +++ b/external/mqtt/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index 24d78ecb3a97d..a19a72c58a705 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -139,7 +139,8 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { msgTopic.publish(message) } catch { case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => - Thread.sleep(50) // wait for Spark streaming to consume something from the message queue + // wait for Spark streaming to consume something from the message queue + Thread.sleep(50) } } } diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala index 4eacc47da5699..7cf02d85d73d3 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala @@ -70,7 +70,7 @@ class TwitterReceiver( try { val newTwitterStream = new TwitterStreamFactory().getInstance(twitterAuth) newTwitterStream.addListener(new StatusListener { - def onStatus(status: Status) = { + def onStatus(status: Status): Unit = { store(status) } // Unimplemented diff --git a/external/twitter/src/test/resources/log4j.properties b/external/twitter/src/test/resources/log4j.properties index 64bfc5745088f..9a3569789d2e0 100644 --- a/external/twitter/src/test/resources/log4j.properties +++ b/external/twitter/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, 
because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala index 554705878ee78..588e6bac7b14a 100644 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala +++ b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala @@ -29,13 +29,16 @@ import org.apache.spark.streaming.receiver.ActorHelper /** * A receiver to subscribe to ZeroMQ stream. */ -private[streaming] class ZeroMQReceiver[T: ClassTag](publisherUrl: String, - subscribe: Subscribe, - bytesToObjects: Seq[ByteString] => Iterator[T]) +private[streaming] class ZeroMQReceiver[T: ClassTag]( + publisherUrl: String, + subscribe: Subscribe, + bytesToObjects: Seq[ByteString] => Iterator[T]) extends Actor with ActorHelper with Logging { - override def preStart() = ZeroMQExtension(context.system) - .newSocket(SocketType.Sub, Listener(self), Connect(publisherUrl), subscribe) + override def preStart(): Unit = { + ZeroMQExtension(context.system) + .newSocket(SocketType.Sub, Listener(self), Connect(publisherUrl), subscribe) + } def receive: Receive = { diff --git a/external/zeromq/src/test/resources/log4j.properties b/external/zeromq/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/external/zeromq/src/test/resources/log4j.properties +++ b/external/zeromq/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/extras/java8-tests/src/test/resources/log4j.properties b/extras/java8-tests/src/test/resources/log4j.properties index 287c8e3563503..eb3b1999eb996 100644 --- a/extras/java8-tests/src/test/resources/log4j.properties +++ b/extras/java8-tests/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/extras/kinesis-asl/src/main/resources/log4j.properties b/extras/kinesis-asl/src/main/resources/log4j.properties index 97348fb5b6123..6cdc9286c5d76 100644 --- a/extras/kinesis-asl/src/main/resources/log4j.properties +++ b/extras/kinesis-asl/src/main/resources/log4j.properties @@ -31,7 +31,7 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n # Settings to quiet third party logs that are too verbose -log4j.logger.org.eclipse.jetty=WARN -log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO \ No newline at end of file diff --git 
a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 1bd1f324298e7..a7fe4476cacb8 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -23,6 +23,7 @@ import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.Duration import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.util.Utils import com.amazonaws.auth.AWSCredentialsProvider import com.amazonaws.auth.DefaultAWSCredentialsProviderChain @@ -118,7 +119,7 @@ private[kinesis] class KinesisReceiver( * method. */ override def onStart() { - workerId = InetAddress.getLocalHost.getHostAddress() + ":" + UUID.randomUUID() + workerId = Utils.localHostName() + ":" + UUID.randomUUID() credentialsProvider = new DefaultAWSCredentialsProviderChain() kinesisClientLibConfiguration = new KinesisClientLibConfiguration(appName, streamName, credentialsProvider, workerId).withKinesisEndpoint(endpointUrl) diff --git a/extras/kinesis-asl/src/test/resources/log4j.properties b/extras/kinesis-asl/src/test/resources/log4j.properties index 853ef0ed2986f..edbecdae92096 100644 --- a/extras/kinesis-asl/src/test/resources/log4j.properties +++ b/extras/kinesis-asl/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala index d8be02e2023d5..23430179f12ec 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala @@ -62,7 +62,6 @@ object EdgeContext { * , _ + _) * }}} */ - def unapply[VD, ED, A](edge: EdgeContext[VD, ED, A]) = + def unapply[VD, ED, A](edge: EdgeContext[VD, ED, A]): Some[(VertexId, VertexId, VD, VD, ED)] = Some(edge.srcId, edge.dstId, edge.srcAttr, edge.dstAttr, edge.attr) } - diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala index 6f03eb1439773..058c8c8aa1b24 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala @@ -34,12 +34,12 @@ class EdgeDirection private (private val name: String) extends Serializable { override def toString: String = "EdgeDirection." + name - override def equals(o: Any) = o match { + override def equals(o: Any): Boolean = o match { case other: EdgeDirection => other.name == name case _ => false } - override def hashCode = name.hashCode + override def hashCode: Int = name.hashCode } @@ -48,14 +48,14 @@ class EdgeDirection private (private val name: String) extends Serializable { */ object EdgeDirection { /** Edges arriving at a vertex. */ - final val In = new EdgeDirection("In") + final val In: EdgeDirection = new EdgeDirection("In") /** Edges originating from a vertex. 
*/ - final val Out = new EdgeDirection("Out") + final val Out: EdgeDirection = new EdgeDirection("Out") /** Edges originating from *or* arriving at a vertex of interest. */ - final val Either = new EdgeDirection("Either") + final val Either: EdgeDirection = new EdgeDirection("Either") /** Edges originating from *and* arriving at a vertex of interest. */ - final val Both = new EdgeDirection("Both") + final val Both: EdgeDirection = new EdgeDirection("Both") } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala index 9d473d5ebda44..c8790cac3d8a0 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala @@ -62,7 +62,7 @@ class EdgeTriplet[VD, ED] extends Edge[ED] { def vertexAttr(vid: VertexId): VD = if (srcId == vid) srcAttr else { assert(dstId == vid); dstAttr } - override def toString = ((srcId, srcAttr), (dstId, dstAttr), attr).toString() + override def toString: String = ((srcId, srcAttr), (dstId, dstAttr), attr).toString() def toTuple: ((VertexId, VD), (VertexId, VD), ED) = ((srcId, srcAttr), (dstId, dstAttr), attr) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index 8494d06b1cdb7..36dc7b0f86c89 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -409,7 +409,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * {{{ * val rawGraph: Graph[_, _] = Graph.textFile("twittergraph") * val inDeg: RDD[(VertexId, Int)] = - * aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _) + * rawGraph.aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _) * }}} * * @note By expressing computation at the edge level we achieve diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index 373af75448374..c561570809253 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -324,7 +324,7 @@ class EdgePartition[ * * @return an iterator over edges in the partition */ - def iterator = new Iterator[Edge[ED]] { + def iterator: Iterator[Edge[ED]] = new Iterator[Edge[ED]] { private[this] val edge = new Edge[ED] private[this] var pos = 0 @@ -351,7 +351,7 @@ class EdgePartition[ override def hasNext: Boolean = pos < EdgePartition.this.size - override def next() = { + override def next(): EdgeTriplet[VD, ED] = { val triplet = new EdgeTriplet[VD, ED] val localSrcId = localSrcIds(pos) val localDstId = localDstIds(pos) @@ -518,11 +518,11 @@ private class AggregatingEdgeContext[VD, ED, A]( _attr = attr } - override def srcId = _srcId - override def dstId = _dstId - override def srcAttr = _srcAttr - override def dstAttr = _dstAttr - override def attr = _attr + override def srcId: VertexId = _srcId + override def dstId: VertexId = _dstId + override def srcAttr: VD = _srcAttr + override def dstAttr: VD = _dstAttr + override def attr: ED = _attr override def sendToSrc(msg: A) { send(_localSrcId, msg) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala index 43a3aea0f6196..c88b2f65a86cd 100644 --- 
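The Scaladoc fix above (calling `aggregateMessages` on `rawGraph` rather than as a bare method) is worth seeing end to end. Below is a minimal, hedged sketch of the corrected example, assuming a local GraphX setup; the edge data and the object name are illustrative, not part of the patch.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD

object InDegreeSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("in-degree"))
    // Three edges: 1 -> 2, 3 -> 2, 2 -> 4, each created with edge attribute 1.
    val rawEdges: RDD[(VertexId, VertexId)] =
      sc.parallelize(Seq((1L, 2L), (3L, 2L), (2L, 4L)))
    val rawGraph: Graph[Int, Int] = Graph.fromEdgeTuples(rawEdges, defaultValue = 1)
    // Each edge sends 1 to its destination; per-vertex messages are summed.
    val inDeg: VertexRDD[Int] =
      rawGraph.aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _)
    inDeg.collect().foreach(println)  // e.g. (2,2) and (4,1), in no particular order
    sc.stop()
  }
}
```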
a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala @@ -70,9 +70,9 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( this } - override def getStorageLevel = partitionsRDD.getStorageLevel + override def getStorageLevel: StorageLevel = partitionsRDD.getStorageLevel - override def checkpoint() = { + override def checkpoint(): Unit = { partitionsRDD.checkpoint() } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala index 8ab255bd4038c..1df86449fa0c2 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala @@ -50,7 +50,7 @@ class ReplicatedVertexView[VD: ClassTag, ED: ClassTag]( * Return a new `ReplicatedVertexView` where edges are reversed and shipping levels are swapped to * match. */ - def reverse() = { + def reverse(): ReplicatedVertexView[VD, ED] = { val newEdges = edges.mapEdgePartitions((pid, part) => part.reverse) new ReplicatedVertexView(newEdges, hasDstId, hasSrcId) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala index 349c8545bf201..33ac7b0ed6095 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala @@ -71,9 +71,9 @@ class VertexRDDImpl[VD] private[graphx] ( this } - override def getStorageLevel = partitionsRDD.getStorageLevel + override def getStorageLevel: StorageLevel = partitionsRDD.getStorageLevel - override def checkpoint() = { + override def checkpoint(): Unit = { partitionsRDD.checkpoint() } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala index e2f6cc138958e..859f896039047 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala @@ -37,7 +37,7 @@ object ConnectedComponents { */ def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[VertexId, ED] = { val ccGraph = graph.mapVertices { case (vid, _) => vid } - def sendMessage(edge: EdgeTriplet[VertexId, ED]) = { + def sendMessage(edge: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, VertexId)] = { if (edge.srcAttr < edge.dstAttr) { Iterator((edge.dstId, edge.srcAttr)) } else if (edge.srcAttr > edge.dstAttr) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala index 82e9e06515179..2bcf8684b8b8e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala @@ -43,7 +43,7 @@ object LabelPropagation { */ def run[VD, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): Graph[VertexId, ED] = { val lpaGraph = graph.mapVertices { case (vid, _) => vid } - def sendMessage(e: EdgeTriplet[VertexId, ED]) = { + def sendMessage(e: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, Map[VertexId, VertexId])] = { Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 1L))) } def mergeMessage(count1: Map[VertexId, Long], count2: 
Map[VertexId, Long]) @@ -54,7 +54,7 @@ object LabelPropagation { i -> (count1Val + count2Val) }.toMap } - def vertexProgram(vid: VertexId, attr: Long, message: Map[VertexId, Long]) = { + def vertexProgram(vid: VertexId, attr: Long, message: Map[VertexId, Long]): VertexId = { if (message.isEmpty) attr else message.maxBy(_._2)._1 } val initialMessage = Map[VertexId, Long]() diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 570440ba4441f..042e366a29f58 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -156,7 +156,7 @@ object PageRank extends Logging { (newPR, newPR - oldPR) } - def sendMessage(edge: EdgeTriplet[(Double, Double), Double]) = { + def sendMessage(edge: EdgeTriplet[(Double, Double), Double]): Iterator[(VertexId, Double)] = { if (edge.srcAttr._2 > tol) { Iterator((edge.dstId, edge.srcAttr._2 * edge.attr)) } else { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index 1a7178b82e3af..3b0e1628d86b5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -93,7 +93,7 @@ object SVDPlusPlus { val gJoinT0 = g.outerJoinVertices(t0) { (vid: VertexId, vd: (Array[Double], Array[Double], Double, Double), msg: Option[(Long, Double)]) => - (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1)) + (vd._1, vd._2, msg.get._2 / msg.get._1 - u, 1.0 / scala.math.sqrt(msg.get._1)) }.cache() materialize(gJoinT0) g.unpersist() diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala index 57b01b6f2e1fb..e2754ea699da9 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala @@ -56,7 +56,7 @@ class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, private var _oldValues: Array[V] = null - override def size = keySet.size + override def size: Int = keySet.size /** Get the value for a given key */ def apply(k: K): V = { @@ -112,7 +112,7 @@ class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, } } - override def iterator = new Iterator[(K, V)] { + override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { var pos = 0 var nextPair: (K, V) = computeNextPair() @@ -128,9 +128,9 @@ class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, } } - def hasNext = nextPair != null + def hasNext: Boolean = nextPair != null - def next() = { + def next(): (K, V) = { val pair = nextPair nextPair = computeNextPair() pair diff --git a/graphx/src/test/resources/log4j.properties b/graphx/src/test/resources/log4j.properties index 287c8e3563503..eb3b1999eb996 100644 --- a/graphx/src/test/resources/log4j.properties +++ b/graphx/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN 
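The `PageRank.sendMessage` hunk above is the core of the tolerance-based (dynamic) variant: a vertex only pushes updates to its out-neighbours while its own rank is still changing by more than `tol`. A standalone sketch of that rule follows; the parameter names are hypothetical, since the real function closes over the graph's triplet type.

```scala
object PageRankToleranceSketch {
  // (rank, delta): `delta` is how much the vertex's rank changed in its last update.
  def sendMessage(
      dstId: Long,
      srcAttr: (Double, Double),
      edgeWeight: Double,
      tol: Double): Iterator[(Long, Double)] = {
    if (srcAttr._2 > tol) {
      Iterator((dstId, srcAttr._2 * edgeWeight))  // still converging: propagate the change
    } else {
      Iterator.empty                              // converged: this edge goes quiet
    }
  }

  def main(args: Array[String]): Unit = {
    println(sendMessage(2L, (0.2, 0.05), 0.5, tol = 0.001).toList)  // List((2,0.025))
    println(sendMessage(2L, (0.2, 1e-6), 0.5, tol = 0.001).toList)  // List()
  }
}
```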
-org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 8d15150458d26..a570e4ed75fc3 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -38,12 +38,12 @@ class GraphSuite extends FunSuite with LocalSparkContext { val doubleRing = ring ++ ring val graph = Graph.fromEdgeTuples(sc.parallelize(doubleRing), 1) assert(graph.edges.count() === doubleRing.size) - assert(graph.edges.collect.forall(e => e.attr == 1)) + assert(graph.edges.collect().forall(e => e.attr == 1)) // uniqueEdges option should uniquify edges and store duplicate count in edge attributes val uniqueGraph = Graph.fromEdgeTuples(sc.parallelize(doubleRing), 1, Some(RandomVertexCut)) assert(uniqueGraph.edges.count() === ring.size) - assert(uniqueGraph.edges.collect.forall(e => e.attr == 2)) + assert(uniqueGraph.edges.collect().forall(e => e.attr == 2)) } } @@ -64,7 +64,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { assert( graph.edges.count() === rawEdges.size ) // Vertices not explicitly provided but referenced by edges should be created automatically assert( graph.vertices.count() === 100) - graph.triplets.collect.map { et => + graph.triplets.collect().map { et => assert((et.srcId < 10 && et.srcAttr) || (et.srcId >= 10 && !et.srcAttr)) assert((et.dstId < 10 && et.dstAttr) || (et.dstId >= 10 && !et.dstAttr)) } @@ -75,15 +75,17 @@ class GraphSuite extends FunSuite with LocalSparkContext { withSpark { sc => val n = 5 val star = starGraph(sc, n) - assert(star.triplets.map(et => (et.srcId, et.dstId, et.srcAttr, et.dstAttr)).collect.toSet === - (1 to n).map(x => (0: VertexId, x: VertexId, "v", "v")).toSet) + assert(star.triplets.map(et => (et.srcId, et.dstId, et.srcAttr, et.dstAttr)).collect().toSet + === (1 to n).map(x => (0: VertexId, x: VertexId, "v", "v")).toSet) } } test("partitionBy") { withSpark { sc => - def mkGraph(edges: List[(Long, Long)]) = Graph.fromEdgeTuples(sc.parallelize(edges, 2), 0) - def nonemptyParts(graph: Graph[Int, Int]) = { + def mkGraph(edges: List[(Long, Long)]): Graph[Int, Int] = { + Graph.fromEdgeTuples(sc.parallelize(edges, 2), 0) + } + def nonemptyParts(graph: Graph[Int, Int]): RDD[List[Edge[Int]]] = { graph.edges.partitionsRDD.mapPartitions { iter => Iterator(iter.next()._2.iterator.toList) }.filter(_.nonEmpty) @@ -102,7 +104,8 @@ class GraphSuite extends FunSuite with LocalSparkContext { assert(nonemptyParts(mkGraph(sameSrcEdges).partitionBy(EdgePartition1D)).count === 1) // partitionBy(CanonicalRandomVertexCut) puts edges that are identical modulo direction into // the same partition - assert(nonemptyParts(mkGraph(canonicalEdges).partitionBy(CanonicalRandomVertexCut)).count === 1) + assert( + nonemptyParts(mkGraph(canonicalEdges).partitionBy(CanonicalRandomVertexCut)).count === 1) // partitionBy(EdgePartition2D) puts identical edges in the same partition assert(nonemptyParts(mkGraph(identicalEdges).partitionBy(EdgePartition2D)).count === 1) @@ -140,10 +143,10 @@ class GraphSuite extends FunSuite with LocalSparkContext { val g = Graph( sc.parallelize(List((0L, "a"), (1L, "b"), (2L, "c"))), sc.parallelize(List(Edge(0L, 1L, 1), Edge(0L, 2L, 1)), 2)) - assert(g.triplets.collect.map(_.toTuple).toSet === + assert(g.triplets.collect().map(_.toTuple).toSet === Set(((0L, "a"), (1L, "b"), 1), ((0L, "a"), 
(2L, "c"), 1))) val gPart = g.partitionBy(EdgePartition2D) - assert(gPart.triplets.collect.map(_.toTuple).toSet === + assert(gPart.triplets.collect().map(_.toTuple).toSet === Set(((0L, "a"), (1L, "b"), 1), ((0L, "a"), (2L, "c"), 1))) } } @@ -154,10 +157,10 @@ class GraphSuite extends FunSuite with LocalSparkContext { val star = starGraph(sc, n) // mapVertices preserving type val mappedVAttrs = star.mapVertices((vid, attr) => attr + "2") - assert(mappedVAttrs.vertices.collect.toSet === (0 to n).map(x => (x: VertexId, "v2")).toSet) + assert(mappedVAttrs.vertices.collect().toSet === (0 to n).map(x => (x: VertexId, "v2")).toSet) // mapVertices changing type val mappedVAttrs2 = star.mapVertices((vid, attr) => attr.length) - assert(mappedVAttrs2.vertices.collect.toSet === (0 to n).map(x => (x: VertexId, 1)).toSet) + assert(mappedVAttrs2.vertices.collect().toSet === (0 to n).map(x => (x: VertexId, 1)).toSet) } } @@ -177,12 +180,12 @@ class GraphSuite extends FunSuite with LocalSparkContext { // Trigger initial vertex replication graph0.triplets.foreach(x => {}) // Change type of replicated vertices, but preserve erased type - val graph1 = graph0.mapVertices { - case (vid, integerOpt) => integerOpt.map((x: java.lang.Integer) => (x.toDouble): java.lang.Double) + val graph1 = graph0.mapVertices { case (vid, integerOpt) => + integerOpt.map((x: java.lang.Integer) => x.toDouble: java.lang.Double) } // Access replicated vertices, exposing the erased type val graph2 = graph1.mapTriplets(t => t.srcAttr.get) - assert(graph2.edges.map(_.attr).collect.toSet === Set[java.lang.Double](1.0, 2.0, 3.0)) + assert(graph2.edges.map(_.attr).collect().toSet === Set[java.lang.Double](1.0, 2.0, 3.0)) } } @@ -202,7 +205,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { withSpark { sc => val n = 5 val star = starGraph(sc, n) - assert(star.mapTriplets(et => et.srcAttr + et.dstAttr).edges.collect.toSet === + assert(star.mapTriplets(et => et.srcAttr + et.dstAttr).edges.collect().toSet === (1L to n).map(x => Edge(0, x, "vv")).toSet) } } @@ -211,7 +214,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { withSpark { sc => val n = 5 val star = starGraph(sc, n) - assert(star.reverse.outDegrees.collect.toSet === (1 to n).map(x => (x: VertexId, 1)).toSet) + assert(star.reverse.outDegrees.collect().toSet === (1 to n).map(x => (x: VertexId, 1)).toSet) } } @@ -221,7 +224,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { val edges: RDD[Edge[Int]] = sc.parallelize(Array(Edge(1L, 2L, 0))) val graph = Graph(vertices, edges).reverse val result = graph.mapReduceTriplets[Int](et => Iterator((et.dstId, et.srcAttr)), _ + _) - assert(result.collect.toSet === Set((1L, 2))) + assert(result.collect().toSet === Set((1L, 2))) } } @@ -237,7 +240,8 @@ class GraphSuite extends FunSuite with LocalSparkContext { assert(subgraph.vertices.collect().toSet === (0 to n by 2).map(x => (x, "v")).toSet) // And 4 edges. 
- assert(subgraph.edges.map(_.copy()).collect().toSet === (2 to n by 2).map(x => Edge(0, x, 1)).toSet) + assert(subgraph.edges.map(_.copy()).collect().toSet === + (2 to n by 2).map(x => Edge(0, x, 1)).toSet) } } @@ -273,9 +277,9 @@ class GraphSuite extends FunSuite with LocalSparkContext { sc.parallelize((1 to n).flatMap(x => List((0: VertexId, x: VertexId), (0: VertexId, x: VertexId))), 1), "v") val star2 = doubleStar.groupEdges { (a, b) => a} - assert(star2.edges.collect.toArray.sorted(Edge.lexicographicOrdering[Int]) === - star.edges.collect.toArray.sorted(Edge.lexicographicOrdering[Int])) - assert(star2.vertices.collect.toSet === star.vertices.collect.toSet) + assert(star2.edges.collect().toArray.sorted(Edge.lexicographicOrdering[Int]) === + star.edges.collect().toArray.sorted(Edge.lexicographicOrdering[Int])) + assert(star2.vertices.collect().toSet === star.vertices.collect().toSet) } } @@ -300,21 +304,23 @@ class GraphSuite extends FunSuite with LocalSparkContext { throw new Exception("map ran on edge with dst vid %d, which is odd".format(et.dstId)) } Iterator((et.srcId, 1)) - }, (a: Int, b: Int) => a + b, Some((active, EdgeDirection.In))).collect.toSet + }, (a: Int, b: Int) => a + b, Some((active, EdgeDirection.In))).collect().toSet assert(numEvenNeighbors === (1 to n).map(x => (x: VertexId, n / 2)).toSet) // outerJoinVertices followed by mapReduceTriplets(activeSetOpt) - val ringEdges = sc.parallelize((0 until n).map(x => (x: VertexId, (x+1) % n: VertexId)), 3) + val ringEdges = sc.parallelize((0 until n).map(x => (x: VertexId, (x + 1) % n: VertexId)), 3) val ring = Graph.fromEdgeTuples(ringEdges, 0) .mapVertices((vid, attr) => vid).cache() val changed = ring.vertices.filter { case (vid, attr) => attr % 2 == 1 }.mapValues(-_).cache() - val changedGraph = ring.outerJoinVertices(changed) { (vid, old, newOpt) => newOpt.getOrElse(old) } + val changedGraph = ring.outerJoinVertices(changed) { (vid, old, newOpt) => + newOpt.getOrElse(old) + } val numOddNeighbors = changedGraph.mapReduceTriplets(et => { // Map function should only run on edges with source in the active set if (et.srcId % 2 != 1) { throw new Exception("map ran on edge with src vid %d, which is even".format(et.dstId)) } Iterator((et.dstId, 1)) - }, (a: Int, b: Int) => a + b, Some(changed, EdgeDirection.Out)).collect.toSet + }, (a: Int, b: Int) => a + b, Some(changed, EdgeDirection.Out)).collect().toSet assert(numOddNeighbors === (2 to n by 2).map(x => (x: VertexId, 1)).toSet) } @@ -340,17 +346,18 @@ class GraphSuite extends FunSuite with LocalSparkContext { val n = 5 val reverseStar = starGraph(sc, n).reverse.cache() // outerJoinVertices changing type - val reverseStarDegrees = - reverseStar.outerJoinVertices(reverseStar.outDegrees) { (vid, a, bOpt) => bOpt.getOrElse(0) } + val reverseStarDegrees = reverseStar.outerJoinVertices(reverseStar.outDegrees) { + (vid, a, bOpt) => bOpt.getOrElse(0) + } val neighborDegreeSums = reverseStarDegrees.mapReduceTriplets( et => Iterator((et.srcId, et.dstAttr), (et.dstId, et.srcAttr)), - (a: Int, b: Int) => a + b).collect.toSet + (a: Int, b: Int) => a + b).collect().toSet assert(neighborDegreeSums === Set((0: VertexId, n)) ++ (1 to n).map(x => (x: VertexId, 0))) // outerJoinVertices preserving type val messages = reverseStar.vertices.mapValues { (vid, attr) => vid.toString } val newReverseStar = reverseStar.outerJoinVertices(messages) { (vid, a, bOpt) => a + bOpt.getOrElse("") } - assert(newReverseStar.vertices.map(_._2).collect.toSet === + 
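The reformatted test above exercises a common GraphX pattern: decorate vertices with `outerJoinVertices`, then aggregate over neighbours with `mapReduceTriplets`. A hedged, self-contained sketch of the same pattern; the star-shaped graph and the object name are illustrative only.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx._

object NeighbourDegreeSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("nbr-deg"))
    // Star graph: hub 0 points at leaves 1..5.
    val star = Graph.fromEdgeTuples(
      sc.parallelize((1L to 5L).map(x => (0L, x))), defaultValue = "v")
    // Attach each vertex's out-degree as its attribute (0 if it has no outgoing edges).
    val withDegrees = star.outerJoinVertices(star.outDegrees) {
      (vid, attr, degOpt) => degOpt.getOrElse(0)
    }
    // Sum the degrees seen across each vertex's incident edges.
    val neighbourDegreeSums = withDegrees.mapReduceTriplets[Int](
      et => Iterator((et.srcId, et.dstAttr), (et.dstId, et.srcAttr)),
      (a, b) => a + b)
    neighbourDegreeSums.collect().foreach(println)  // hub 0 sums to 0, each leaf sees 5
    sc.stop()
  }
}
```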
assert(newReverseStar.vertices.map(_._2).collect().toSet === (0 to n).map(x => "v%d".format(x)).toSet) } } @@ -361,7 +368,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { val edges = sc.parallelize(List(Edge(1, 2, 0), Edge(2, 1, 0)), 2) val graph = Graph(verts, edges) val triplets = graph.triplets.map(et => (et.srcId, et.dstId, et.srcAttr, et.dstAttr)) - .collect.toSet + .collect().toSet assert(triplets === Set((1: VertexId, 2: VertexId, "a", "b"), (2: VertexId, 1: VertexId, "b", "a"))) } @@ -417,7 +424,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { val graph = Graph.fromEdgeTuples(edges, 1) val neighborAttrSums = graph.mapReduceTriplets[Int]( et => Iterator((et.dstId, et.srcAttr)), _ + _) - assert(neighborAttrSums.collect.toSet === Set((0: VertexId, n))) + assert(neighborAttrSums.collect().toSet === Set((0: VertexId, n))) } finally { sc.stop() } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala index a3e28efc75a98..d2ad9be555770 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala @@ -26,7 +26,7 @@ import org.apache.spark.SparkContext */ trait LocalSparkContext { /** Runs `f` on a new SparkContext and ensures that it is stopped afterwards. */ - def withSpark[T](f: SparkContext => T) = { + def withSpark[T](f: SparkContext => T): T = { val conf = new SparkConf() GraphXUtils.registerKryoClasses(conf) val sc = new SparkContext("local", "test", conf) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala index c9443d11c76cf..d0a7198d691d7 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.storage.StorageLevel class VertexRDDSuite extends FunSuite with LocalSparkContext { - def vertices(sc: SparkContext, n: Int) = { + private def vertices(sc: SparkContext, n: Int) = { VertexRDD(sc.parallelize((0 to n).map(x => (x.toLong, x)), 5)) } @@ -52,7 +52,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val vertexA = VertexRDD(sc.parallelize(0 until 75, 2).map(i => (i.toLong, 0))).cache() val vertexB = VertexRDD(sc.parallelize(25 until 100, 2).map(i => (i.toLong, 1))).cache() val vertexC = vertexA.minus(vertexB) - assert(vertexC.map(_._1).collect.toSet === (0 until 25).toSet) + assert(vertexC.map(_._1).collect().toSet === (0 until 25).toSet) } } @@ -62,7 +62,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val vertexB: RDD[(VertexId, Int)] = sc.parallelize(25 until 100, 2).map(i => (i.toLong, 1)).cache() val vertexC = vertexA.minus(vertexB) - assert(vertexC.map(_._1).collect.toSet === (0 until 25).toSet) + assert(vertexC.map(_._1).collect().toSet === (0 until 25).toSet) } } @@ -72,7 +72,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val vertexB = VertexRDD(sc.parallelize(50 until 100, 2).map(i => (i.toLong, 1))) assert(vertexA.partitions.size != vertexB.partitions.size) val vertexC = vertexA.minus(vertexB) - assert(vertexC.map(_._1).collect.toSet === (0 until 50).toSet) + assert(vertexC.map(_._1).collect().toSet === (0 until 50).toSet) } } @@ -106,7 +106,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val vertexB = VertexRDD(sc.parallelize(8 until 16, 
2).map(i => (i.toLong, 1))) assert(vertexA.partitions.size != vertexB.partitions.size) val vertexC = vertexA.diff(vertexB) - assert(vertexC.map(_._1).collect.toSet === (8 until 16).toSet) + assert(vertexC.map(_._1).collect().toSet === (8 until 16).toSet) } } @@ -116,11 +116,11 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val verts = vertices(sc, n).cache() val evens = verts.filter(q => ((q._2 % 2) == 0)).cache() // leftJoin with another VertexRDD - assert(verts.leftJoin(evens) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect.toSet === + assert(verts.leftJoin(evens) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect().toSet === (0 to n by 2).map(x => (x.toLong, 0)).toSet ++ (1 to n by 2).map(x => (x.toLong, x)).toSet) // leftJoin with an RDD val evensRDD = evens.map(identity) - assert(verts.leftJoin(evensRDD) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect.toSet === + assert(verts.leftJoin(evensRDD) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect().toSet === (0 to n by 2).map(x => (x.toLong, 0)).toSet ++ (1 to n by 2).map(x => (x.toLong, x)).toSet) } } @@ -134,7 +134,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val vertexC = vertexA.leftJoin(vertexB) { (vid, old, newOpt) => old - newOpt.getOrElse(0) } - assert(vertexC.filter(v => v._2 != 0).map(_._1).collect.toSet == (1 to 99 by 2).toSet) + assert(vertexC.filter(v => v._2 != 0).map(_._1).collect().toSet == (1 to 99 by 2).toSet) } } @@ -144,11 +144,11 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val verts = vertices(sc, n).cache() val evens = verts.filter(q => ((q._2 % 2) == 0)).cache() // innerJoin with another VertexRDD - assert(verts.innerJoin(evens) { (id, a, b) => a - b }.collect.toSet === + assert(verts.innerJoin(evens) { (id, a, b) => a - b }.collect().toSet === (0 to n by 2).map(x => (x.toLong, 0)).toSet) // innerJoin with an RDD val evensRDD = evens.map(identity) - assert(verts.innerJoin(evensRDD) { (id, a, b) => a - b }.collect.toSet === + assert(verts.innerJoin(evensRDD) { (id, a, b) => a - b }.collect().toSet === (0 to n by 2).map(x => (x.toLong, 0)).toSet) } } @@ -161,7 +161,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val vertexC = vertexA.innerJoin(vertexB) { (vid, old, newVal) => old - newVal } - assert(vertexC.filter(v => v._2 == 0).map(_._1).collect.toSet == (0 to 98 by 2).toSet) + assert(vertexC.filter(v => v._2 == 0).map(_._1).collect().toSet == (0 to 98 by 2).toSet) } } @@ -171,7 +171,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val verts = vertices(sc, n) val messageTargets = (0 to n) ++ (0 to n by 2) val messages = sc.parallelize(messageTargets.map(x => (x.toLong, 1))) - assert(verts.aggregateUsingIndex[Int](messages, _ + _).collect.toSet === + assert(verts.aggregateUsingIndex[Int](messages, _ + _).collect().toSet === (0 to n).map(x => (x.toLong, if (x % 2 == 0) 2 else 1)).toSet) } } @@ -183,7 +183,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val edges = EdgeRDD.fromEdges(sc.parallelize(List.empty[Edge[Int]])) val rdd = VertexRDD(verts, edges, 0, (a: Int, b: Int) => a + b) // test merge function - assert(rdd.collect.toSet == Set((0L, 0), (1L, 3), (2L, 9))) + assert(rdd.collect().toSet == Set((0L, 0), (1L, 3), (2L, 9))) } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala index 3915be15b3434..4cc30a96408f8 100644 --- 
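The VertexRDD join tests above rely on the semantics that `leftJoin` keeps every vertex of the left side (missing values surface as `None`), while `innerJoin` keeps only ids present on both sides. A hedged sketch of the same behaviour, assuming a local Spark context; the data is illustrative.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx._

object VertexJoinSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("vjoin"))
    val verts = VertexRDD(sc.parallelize((0L to 9L).map(id => (id, id.toInt))))
    val evens = verts.filter { case (_, v) => v % 2 == 0 }
    // leftJoin: every id from `verts` survives; odd ids see None on the right.
    val left  = verts.leftJoin(evens) { (id, a, bOpt) => a - bOpt.getOrElse(0) }
    // innerJoin: only ids present in both sides survive.
    val inner = verts.innerJoin(evens) { (id, a, b) => a - b }
    println(left.collect().toSet)   // zeros on even ids, original values on odd ids
    println(inner.collect().toSet)  // only even ids, all zero
    sc.stop()
  }
}
```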
a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala @@ -32,7 +32,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { withSpark { sc => val gridGraph = GraphGenerators.gridGraph(sc, 10, 10) val ccGraph = gridGraph.connectedComponents() - val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum + val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum() assert(maxCCid === 0) } } // end of Grid connected components @@ -42,7 +42,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { withSpark { sc => val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).reverse val ccGraph = gridGraph.connectedComponents() - val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum + val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum() assert(maxCCid === 0) } } // end of Grid connected components @@ -50,8 +50,8 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { test("Chain Connected Components") { withSpark { sc => - val chain1 = (0 until 9).map(x => (x, x+1) ) - val chain2 = (10 until 20).map(x => (x, x+1) ) + val chain1 = (0 until 9).map(x => (x, x + 1)) + val chain2 = (10 until 20).map(x => (x, x + 1)) val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } val twoChains = Graph.fromEdgeTuples(rawEdges, 1.0) val ccGraph = twoChains.connectedComponents() @@ -73,12 +73,12 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { test("Reverse Chain Connected Components") { withSpark { sc => - val chain1 = (0 until 9).map(x => (x, x+1) ) - val chain2 = (10 until 20).map(x => (x, x+1) ) + val chain1 = (0 until 9).map(x => (x, x + 1)) + val chain2 = (10 until 20).map(x => (x, x + 1)) val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } val twoChains = Graph.fromEdgeTuples(rawEdges, true).reverse val ccGraph = twoChains.connectedComponents() - val vertices = ccGraph.vertices.collect + val vertices = ccGraph.vertices.collect() for ( (id, cc) <- vertices ) { if (id < 10) { assert(cc === 0) @@ -120,9 +120,9 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { // Build the initial Graph val graph = Graph(users, relationships, defaultUser) val ccGraph = graph.connectedComponents() - val vertices = ccGraph.vertices.collect + val vertices = ccGraph.vertices.collect() for ( (id, cc) <- vertices ) { - assert(cc == 0) + assert(cc === 0) } } } // end of toy connected components diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala index fc491ae327c2a..95804b07b1db0 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala @@ -19,15 +19,12 @@ package org.apache.spark.graphx.lib import org.scalatest.FunSuite -import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ -import org.apache.spark.graphx.lib._ import org.apache.spark.graphx.util.GraphGenerators -import org.apache.spark.rdd._ + object GridPageRank { - def apply(nRows: Int, nCols: Int, nIter: Int, resetProb: Double) = { + def apply(nRows: Int, nCols: Int, nIter: Int, resetProb: Double): Seq[(VertexId, Double)] = { val inNbrs = Array.fill(nRows * 
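The chain tests being cleaned up above are a compact illustration of what `connectedComponents()` returns: every vertex is labelled with the smallest vertex id reachable from it. A hedged sketch of the two-chain scenario from the suite, assuming a local Spark context:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx._

object TwoChainsSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("cc"))
    // Two disjoint chains: 0 -> 1 -> ... -> 9 and 10 -> 11 -> ... -> 20.
    val chain1 = (0 until 9).map(x => (x.toLong, (x + 1).toLong))
    val chain2 = (10 until 20).map(x => (x.toLong, (x + 1).toLong))
    val graph = Graph.fromEdgeTuples(sc.parallelize(chain1 ++ chain2, 3), 1.0)
    val cc = graph.connectedComponents().vertices
    cc.collect().sorted.foreach(println)  // ids 0-9 map to component 0, ids 10-20 to 10
    sc.stop()
  }
}
```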
nCols)(collection.mutable.MutableList.empty[Int]) val outDegree = Array.fill(nRows * nCols)(0) // Convert row column address into vertex ids (row major order) @@ -35,13 +32,13 @@ object GridPageRank { // Make the grid graph for (r <- 0 until nRows; c <- 0 until nCols) { val ind = sub2ind(r,c) - if (r+1 < nRows) { + if (r + 1 < nRows) { outDegree(ind) += 1 - inNbrs(sub2ind(r+1,c)) += ind + inNbrs(sub2ind(r + 1,c)) += ind } - if (c+1 < nCols) { + if (c + 1 < nCols) { outDegree(ind) += 1 - inNbrs(sub2ind(r,c+1)) += ind + inNbrs(sub2ind(r,c + 1)) += ind } } // compute the pagerank @@ -64,7 +61,7 @@ class PageRankSuite extends FunSuite with LocalSparkContext { def compareRanks(a: VertexRDD[Double], b: VertexRDD[Double]): Double = { a.leftJoin(b) { case (id, a, bOpt) => (a - bOpt.getOrElse(0.0)) * (a - bOpt.getOrElse(0.0)) } - .map { case (id, error) => error }.sum + .map { case (id, error) => error }.sum() } test("Star PageRank") { @@ -80,12 +77,12 @@ class PageRankSuite extends FunSuite with LocalSparkContext { // Static PageRank should only take 2 iterations to converge val notMatching = staticRanks1.innerZipJoin(staticRanks2) { (vid, pr1, pr2) => if (pr1 != pr2) 1 else 0 - }.map { case (vid, test) => test }.sum + }.map { case (vid, test) => test }.sum() assert(notMatching === 0) val staticErrors = staticRanks2.map { case (vid, pr) => - val correct = (vid > 0 && pr == resetProb) || - (vid == 0 && math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * (nVertices - 1)) )) < 1.0E-5) + val p = math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * (nVertices - 1)) )) + val correct = (vid > 0 && pr == resetProb) || (vid == 0L && p < 1.0E-5) if (!correct) 1 else 0 } assert(staticErrors.sum === 0) @@ -95,8 +92,6 @@ class PageRankSuite extends FunSuite with LocalSparkContext { } } // end of test Star PageRank - - test("Grid PageRank") { withSpark { sc => val rows = 10 @@ -109,18 +104,18 @@ class PageRankSuite extends FunSuite with LocalSparkContext { val staticRanks = gridGraph.staticPageRank(numIter, resetProb).vertices.cache() val dynamicRanks = gridGraph.pageRank(tol, resetProb).vertices.cache() - val referenceRanks = VertexRDD(sc.parallelize(GridPageRank(rows, cols, numIter, resetProb))).cache() + val referenceRanks = VertexRDD( + sc.parallelize(GridPageRank(rows, cols, numIter, resetProb))).cache() assert(compareRanks(staticRanks, referenceRanks) < errorTol) assert(compareRanks(dynamicRanks, referenceRanks) < errorTol) } } // end of Grid PageRank - test("Chain PageRank") { withSpark { sc => - val chain1 = (0 until 9).map(x => (x, x+1) ) - val rawEdges = sc.parallelize(chain1, 1).map { case (s,d) => (s.toLong, d.toLong) } + val chain1 = (0 until 9).map(x => (x, x + 1)) + val rawEdges = sc.parallelize(chain1, 1).map { case (s, d) => (s.toLong, d.toLong) } val chain = Graph.fromEdgeTuples(rawEdges, 1.0).cache() val resetProb = 0.15 val tol = 0.0001 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala index df54aa37cad68..1f658c371ffcf 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala @@ -34,8 +34,8 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext { val edges = sc.parallelize(Seq.empty[Edge[Int]]) val graph = Graph(vertices, edges) val sccGraph = 
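The PageRank suite above compares the two entry points: `staticPageRank` runs a fixed number of iterations, while `pageRank` iterates until the per-vertex change drops below a tolerance, and `compareRanks` measures the sum of squared differences between the two rank vectors. A hedged sketch of that comparison, assuming a local Spark context and the GraphX grid generator used by the test:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx._
import org.apache.spark.graphx.util.GraphGenerators

object PageRankComparisonSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("pagerank"))
    val grid = GraphGenerators.gridGraph(sc, 10, 10).cache()
    val staticRanks  = grid.staticPageRank(numIter = 50).vertices   // fixed iteration count
    val dynamicRanks = grid.pageRank(tol = 0.0001).vertices         // run until converged
    // compareRanks-style error: sum of squared differences between the two rank vectors.
    val error = staticRanks.leftJoin(dynamicRanks) { (id, a, bOpt) =>
      val d = a - bOpt.getOrElse(0.0)
      d * d
    }.map(_._2).sum()
    println(s"sum of squared differences: $error")
    sc.stop()
  }
}
```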
graph.stronglyConnectedComponents(5) - for ((id, scc) <- sccGraph.vertices.collect) { - assert(id == scc) + for ((id, scc) <- sccGraph.vertices.collect()) { + assert(id === scc) } } } @@ -45,8 +45,8 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext { val rawEdges = sc.parallelize((0L to 6L).map(x => (x, (x + 1) % 7))) val graph = Graph.fromEdgeTuples(rawEdges, -1) val sccGraph = graph.stronglyConnectedComponents(20) - for ((id, scc) <- sccGraph.vertices.collect) { - assert(0L == scc) + for ((id, scc) <- sccGraph.vertices.collect()) { + assert(0L === scc) } } } @@ -60,13 +60,14 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext { val rawEdges = sc.parallelize(edges) val graph = Graph.fromEdgeTuples(rawEdges, -1) val sccGraph = graph.stronglyConnectedComponents(20) - for ((id, scc) <- sccGraph.vertices.collect) { - if (id < 3) - assert(0L == scc) - else if (id < 6) - assert(3L == scc) - else - assert(id == scc) + for ((id, scc) <- sccGraph.vertices.collect()) { + if (id < 3) { + assert(0L === scc) + } else if (id < 6) { + assert(3L === scc) + } else { + assert(id === scc) + } } } } diff --git a/launcher/pom.xml b/launcher/pom.xml index 0fe2814135d88..182e5f60218db 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -52,11 +52,6 @@ mockito-all test - - org.scalatest - scalatest_${scala.binary.version} - test - org.slf4j slf4j-api diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index d8279145d8e90..b8f02b961113d 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -186,12 +186,24 @@ List buildClassPath(String appClassPath) throws IOException { addToClassPath(cp, String.format("%s/core/target/jars/*", sparkHome)); } - final String assembly = AbstractCommandBuilder.class.getProtectionDomain().getCodeSource(). - getLocation().getPath(); + // We can't rely on the ENV_SPARK_ASSEMBLY variable to be set. Certain situations, such as + // when running unit tests, or user code that embeds Spark and creates a SparkContext + // with a local or local-cluster master, will cause this code to be called from an + // environment where that env variable is not guaranteed to exist. + // + // For the testing case, we rely on the test code to set and propagate the test classpath + // appropriately. + // + // For the user code case, we fall back to looking for the Spark assembly under SPARK_HOME. + // That duplicates some of the code in the shell scripts that look for the assembly, though. + String assembly = getenv(ENV_SPARK_ASSEMBLY); + if (assembly == null && isEmpty(getenv("SPARK_TESTING"))) { + assembly = findAssembly(); + } addToClassPath(cp, assembly); - // Datanucleus jars must be included on the classpath. Datanucleus jars do not work if only - // included in the uber jar as plugin.xml metadata is lost. Both sbt and maven will populate + // Datanucleus jars must be included on the classpath. Datanucleus jars do not work if only + // included in the uber jar as plugin.xml metadata is lost. 
Both sbt and maven will populate // "lib_managed/jars/" with the datanucleus jars when Spark is built with Hive File libdir; if (new File(sparkHome, "RELEASE").isFile()) { @@ -299,6 +311,30 @@ String getenv(String key) { return firstNonEmpty(childEnv.get(key), System.getenv(key)); } + private String findAssembly() { + String sparkHome = getSparkHome(); + File libdir; + if (new File(sparkHome, "RELEASE").isFile()) { + libdir = new File(sparkHome, "lib"); + checkState(libdir.isDirectory(), "Library directory '%s' does not exist.", + libdir.getAbsolutePath()); + } else { + libdir = new File(sparkHome, String.format("assembly/target/scala-%s", getScalaVersion())); + } + + final Pattern re = Pattern.compile("spark-assembly.*hadoop.*\\.jar"); + FileFilter filter = new FileFilter() { + @Override + public boolean accept(File file) { + return file.isFile() && re.matcher(file.getName()).matches(); + } + }; + File[] assemblies = libdir.listFiles(filter); + checkState(assemblies != null && assemblies.length > 0, "No assemblies found in '%s'.", libdir); + checkState(assemblies.length == 1, "Multiple assemblies found in '%s'.", libdir); + return assemblies[0].getAbsolutePath(); + } + private String getConfDir() { String confDir = getenv("SPARK_CONF_DIR"); return confDir != null ? confDir : join(File.separator, getSparkHome(), "conf"); diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 9b04732afee14..8028e42ffb483 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -30,6 +30,7 @@ class CommandBuilderUtils { static final String DEFAULT_MEM = "512m"; static final String DEFAULT_PROPERTIES_FILE = "spark-defaults.conf"; static final String ENV_SPARK_HOME = "SPARK_HOME"; + static final String ENV_SPARK_ASSEMBLY = "_SPARK_ASSEMBLY"; /** Returns whether the given string is null or empty. */ static boolean isEmpty(String s) { @@ -274,14 +275,14 @@ static String quoteForBatchScript(String arg) { } /** - * Quotes a string so that it can be used in a command string and be parsed back into a single - * argument by python's "shlex.split()" function. - * + * Quotes a string so that it can be used in a command string. * Basically, just add simple escapes. E.g.: * original single argument : ab "cd" ef * after: "ab \"cd\" ef" + * + * This can be parsed back into a single argument by python's "shlex.split()" function. */ - static String quoteForPython(String s) { + static String quoteForCommandString(String s) { StringBuilder quoted = new StringBuilder().append('"'); for (int i = 0; i < s.length(); i++) { int cp = s.codePointAt(i); diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index b566507ee6061..d4cfeacb6ef18 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -52,7 +52,7 @@ public class SparkLauncher { /** Configuration key for the executor VM options. */ public static final String EXECUTOR_EXTRA_JAVA_OPTIONS = "spark.executor.extraJavaOptions"; /** Configuration key for the executor native library path. 
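The new `findAssembly()` fallback above encodes a small amount of layout knowledge: release tarballs keep the assembly under `lib/`, development builds put it under `assembly/target/scala-<version>/`, and exactly one `spark-assembly*hadoop*.jar` is expected. A hedged Scala sketch of the same lookup (the launcher itself is Java; names here are illustrative):

```scala
import java.io.File

object AssemblyLookupSketch {
  def findAssembly(sparkHome: String, scalaVersion: String): File = {
    // Release layout ("RELEASE" marker file present) vs. sbt/maven build layout.
    val libdir =
      if (new File(sparkHome, "RELEASE").isFile) new File(sparkHome, "lib")
      else new File(sparkHome, s"assembly/target/scala-$scalaVersion")
    val assemblyJar = "spark-assembly.*hadoop.*\\.jar".r
    val matches = Option(libdir.listFiles()).getOrElse(Array.empty[File])
      .filter(f => f.isFile && assemblyJar.pattern.matcher(f.getName).matches())
    require(matches.length == 1,
      s"Expected exactly one assembly jar in '$libdir', found ${matches.length}.")
    matches.head
  }

  def main(args: Array[String]): Unit = {
    println(findAssembly(sys.env.getOrElse("SPARK_HOME", "."), "2.10"))
  }
}
```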
*/ - public static final String EXECUTOR_EXTRA_LIBRARY_PATH = "spark.executor.extraLibraryOptions"; + public static final String EXECUTOR_EXTRA_LIBRARY_PATH = "spark.executor.extraLibraryPath"; /** Configuration key for the number of executor CPU cores. */ public static final String EXECUTOR_CORES = "spark.executor.cores"; diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 91dcf70f105db..a73c9c87e3126 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -17,14 +17,9 @@ package org.apache.spark.launcher; +import java.io.File; import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Properties; +import java.util.*; import static org.apache.spark.launcher.CommandBuilderUtils.*; @@ -53,6 +48,20 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { */ static final String PYSPARK_SHELL_RESOURCE = "pyspark-shell"; + /** + * Name of the app resource used to identify the SparkR shell. The command line parser expects + * the resource name to be the very first argument to spark-submit in this case. + * + * NOTE: this cannot be "sparkr-shell" since that identifies the SparkR shell to SparkSubmit + * (see sparkR.R), and can cause this code to enter into an infinite loop. + */ + static final String SPARKR_SHELL = "sparkr-shell-main"; + + /** + * This is the actual resource name that identifies the SparkR shell to SparkSubmit. + */ + static final String SPARKR_SHELL_RESOURCE = "sparkr-shell"; + /** * This map must match the class names for available special classes, since this modifies the way * command line parsing works. This maps the class name to the resource to use when calling @@ -87,6 +96,10 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { this.allowsMixedArguments = true; appResource = PYSPARK_SHELL_RESOURCE; submitArgs = args.subList(1, args.size()); + } else if (args.size() > 0 && args.get(0).equals(SPARKR_SHELL)) { + this.allowsMixedArguments = true; + appResource = SPARKR_SHELL_RESOURCE; + submitArgs = args.subList(1, args.size()); } else { this.allowsMixedArguments = false; } @@ -98,6 +111,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { public List buildCommand(Map env) throws IOException { if (PYSPARK_SHELL_RESOURCE.equals(appResource)) { return buildPySparkShellCommand(env); + } else if (SPARKR_SHELL_RESOURCE.equals(appResource)) { + return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); } @@ -213,36 +228,62 @@ private List buildPySparkShellCommand(Map env) throws IO return buildCommand(env); } - // When launching the pyspark shell, the spark-submit arguments should be stored in the - // PYSPARK_SUBMIT_ARGS env variable. The executable is the PYSPARK_DRIVER_PYTHON env variable - // set by the pyspark script, followed by PYSPARK_DRIVER_PYTHON_OPTS. checkArgument(appArgs.isEmpty(), "pyspark does not support any application options."); + // When launching the pyspark shell, the spark-submit arguments should be stored in the + // PYSPARK_SUBMIT_ARGS env variable. 
+ constructEnvVarArgs(env, "PYSPARK_SUBMIT_ARGS"); + + // The executable is the PYSPARK_DRIVER_PYTHON env variable set by the pyspark script, + // followed by PYSPARK_DRIVER_PYTHON_OPTS. + List pyargs = new ArrayList(); + pyargs.add(firstNonEmpty(System.getenv("PYSPARK_DRIVER_PYTHON"), "python")); + String pyOpts = System.getenv("PYSPARK_DRIVER_PYTHON_OPTS"); + if (!isEmpty(pyOpts)) { + pyargs.addAll(parseOptionString(pyOpts)); + } + + return pyargs; + } + + private List buildSparkRCommand(Map env) throws IOException { + if (!appArgs.isEmpty() && appArgs.get(0).endsWith(".R")) { + appResource = appArgs.get(0); + appArgs.remove(0); + return buildCommand(env); + } + // When launching the SparkR shell, store the spark-submit arguments in the SPARKR_SUBMIT_ARGS + // env variable. + constructEnvVarArgs(env, "SPARKR_SUBMIT_ARGS"); + + // Set shell.R as R_PROFILE_USER to load the SparkR package when the shell comes up. + String sparkHome = System.getenv("SPARK_HOME"); + env.put("R_PROFILE_USER", + join(File.separator, sparkHome, "R", "lib", "SparkR", "profile", "shell.R")); + + List args = new ArrayList(); + args.add(firstNonEmpty(System.getenv("SPARKR_DRIVER_R"), "R")); + return args; + } + + private void constructEnvVarArgs( + Map env, + String submitArgsEnvVariable) throws IOException { Properties props = loadPropertiesFile(); mergeEnvPathList(env, getLibPathEnvName(), firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, conf, props)); - // Store spark-submit arguments in an environment variable, since there's no way to pass - // them to shell.py on the comand line. StringBuilder submitArgs = new StringBuilder(); for (String arg : buildSparkSubmitArgs()) { if (submitArgs.length() > 0) { submitArgs.append(" "); } - submitArgs.append(quoteForPython(arg)); + submitArgs.append(quoteForCommandString(arg)); } - env.put("PYSPARK_SUBMIT_ARGS", submitArgs.toString()); - - List pyargs = new ArrayList(); - pyargs.add(firstNonEmpty(System.getenv("PYSPARK_DRIVER_PYTHON"), "python")); - String pyOpts = System.getenv("PYSPARK_DRIVER_PYTHON_OPTS"); - if (!isEmpty(pyOpts)) { - pyargs.addAll(parseOptionString(pyOpts)); - } - - return pyargs; + env.put(submitArgsEnvVariable, submitArgs.toString()); } + private boolean isClientMode(Properties userProps) { String userMaster = firstNonEmpty(master, (String) userProps.get(SparkLauncher.SPARK_MASTER)); // Default master is "local[*]", so assume client mode in that case. 
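The SparkR path above mirrors the existing PySpark path: the process that actually starts is plain `R`, the spark-submit arguments travel through the `SPARKR_SUBMIT_ARGS` environment variable, and `R_PROFILE_USER` points at SparkR's `shell.R` so the package loads on startup. An illustrative Scala sketch of that hand-off, not the launcher's real Java code path; the object name is hypothetical and argument quoting is simplified.

```scala
import java.io.File
import scala.collection.JavaConverters._

object SparkRShellLaunchSketch {
  def launch(sparkHome: String, submitArgs: Seq[String]): Process = {
    val rExecutable = sys.env.getOrElse("SPARKR_DRIVER_R", "R")
    val pb = new ProcessBuilder(Seq(rExecutable).asJava)
    val env = pb.environment()
    // spark-submit arguments are handed to the SparkR side through the environment
    // (the real launcher quotes each argument; see quoteForCommandString).
    env.put("SPARKR_SUBMIT_ARGS", submitArgs.mkString(" "))
    // shell.R is loaded as the R user profile so SparkR attaches when the shell comes up.
    env.put("R_PROFILE_USER",
      Seq(sparkHome, "R", "lib", "SparkR", "profile", "shell.R").mkString(File.separator))
    pb.inheritIO().start()
  }

  def main(args: Array[String]): Unit = {
    launch(sys.env.getOrElse("SPARK_HOME", "."), Seq("--master", "local[2]")).waitFor()
  }
}
```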
diff --git a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java index dba0203867372..1ae42eed8a3af 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java @@ -79,9 +79,9 @@ public void testWindowsBatchQuoting() { @Test public void testPythonArgQuoting() { - assertEquals("\"abc\"", quoteForPython("abc")); - assertEquals("\"a b c\"", quoteForPython("a b c")); - assertEquals("\"a \\\"b\\\" c\"", quoteForPython("a \"b\" c")); + assertEquals("\"abc\"", quoteForCommandString("abc")); + assertEquals("\"a b c\"", quoteForCommandString("a b c")); + assertEquals("\"a \\\"b\\\" c\"", quoteForCommandString("a \"b\" c")); } private void testOpt(String opts, List expected) { diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index 626116a9e7477..97043a76cc612 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -98,7 +98,7 @@ public void testShellCliParser() throws Exception { parser.NAME, "appName"); - List args = new SparkSubmitCommandBuilder(sparkSubmitArgs).buildSparkSubmitArgs(); + List args = newCommandBuilder(sparkSubmitArgs).buildSparkSubmitArgs(); List expected = Arrays.asList("spark-shell", "--app-arg", "bar", "--app-switch"); assertEquals(expected, args.subList(args.size() - expected.size(), args.size())); } @@ -110,7 +110,7 @@ public void testAlternateSyntaxParsing() throws Exception { parser.MASTER + "=foo", parser.DEPLOY_MODE + "=bar"); - List cmd = new SparkSubmitCommandBuilder(sparkSubmitArgs).buildSparkSubmitArgs(); + List cmd = newCommandBuilder(sparkSubmitArgs).buildSparkSubmitArgs(); assertEquals("org.my.Class", findArgValue(cmd, parser.CLASS)); assertEquals("foo", findArgValue(cmd, parser.MASTER)); assertEquals("bar", findArgValue(cmd, parser.DEPLOY_MODE)); @@ -153,7 +153,7 @@ private void testCmdBuilder(boolean isDriver) throws Exception { String deployMode = isDriver ? 
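The renamed `quoteForCommandString` (formerly `quoteForPython`) applies a deliberately simple scheme: wrap the argument in double quotes and backslash-escape characters that would break the quoting, so Python's `shlex.split()` can recover it as a single token. A hedged Scala sketch consistent with the test expectations in the hunk below; it is an approximation, not the Java source.

```scala
object QuoteSketch {
  def quoteForCommandString(s: String): String = {
    val quoted = new StringBuilder("\"")
    s.foreach {
      case c @ ('"' | '\\') => quoted.append('\\').append(c)  // escape quotes and backslashes
      case c                => quoted.append(c)
    }
    quoted.append('"').toString
  }

  def main(args: Array[String]): Unit = {
    assert(quoteForCommandString("abc") == "\"abc\"")
    assert(quoteForCommandString("a b c") == "\"a b c\"")
    assert(quoteForCommandString("a \"b\" c") == "\"a \\\"b\\\" c\"")
    println("quoting matches the expectations in CommandBuilderUtilsSuite")
  }
}
```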
"client" : "cluster"; SparkSubmitCommandBuilder launcher = - new SparkSubmitCommandBuilder(Collections.emptyList()); + newCommandBuilder(Collections.emptyList()); launcher.childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, System.getProperty("spark.test.home")); launcher.master = "yarn"; @@ -273,10 +273,15 @@ private boolean findInStringList(String list, String sep, String needle) { return contains(needle, list.split(sep)); } - private List buildCommand(List args, Map env) throws Exception { + private SparkSubmitCommandBuilder newCommandBuilder(List args) { SparkSubmitCommandBuilder builder = new SparkSubmitCommandBuilder(args); builder.childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, System.getProperty("spark.test.home")); - return builder.buildCommand(env); + builder.childEnv.put(CommandBuilderUtils.ENV_SPARK_ASSEMBLY, "dummy"); + return builder; + } + + private List buildCommand(List args, Map env) throws Exception { + return newCommandBuilder(args).buildCommand(env); } } diff --git a/launcher/src/test/resources/log4j.properties b/launcher/src/test/resources/log4j.properties index 00c20ad69cd4d..67a6a98217118 100644 --- a/launcher/src/test/resources/log4j.properties +++ b/launcher/src/test/resources/log4j.properties @@ -27,5 +27,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/mllib/pom.xml b/mllib/pom.xml index 4c183543e3fa8..5dfab36c76907 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -64,7 +64,7 @@ org.scalanlp breeze_${scala.binary.version} - 0.11.1 + 0.11.2 diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index eff7ef925dfbd..d6b3503ebdd9a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -40,7 +40,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { */ @varargs def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = { - val map = new ParamMap().put(paramPairs: _*) + val map = ParamMap(paramPairs: _*) fit(dataset, map) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala index cd84b05bfb496..a1d49095c24ac 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala @@ -25,9 +25,9 @@ import java.util.UUID private[ml] trait Identifiable extends Serializable { /** - * A unique id for the object. The default implementation concatenates the class name, "-", and 8 + * A unique id for the object. The default implementation concatenates the class name, "_", and 8 * random hex chars. 
*/ private[ml] val uid: String = - this.getClass.getSimpleName + "-" + UUID.randomUUID().toString.take(8) + this.getClass.getSimpleName + "_" + UUID.randomUUID().toString.take(8) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index c4a36103303a2..8eddf79cdfe28 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -47,6 +47,9 @@ abstract class PipelineStage extends Serializable with Logging { /** * Derives the output schema from the input schema and parameters, optionally with logging. + * + * This should be optimistic. If it is unclear whether the schema will be valid, then it should + * be assumed valid until proven otherwise. */ protected def transformSchema( schema: StructType, @@ -81,7 +84,7 @@ class Pipeline extends Estimator[PipelineModel] { /** param for pipeline stages */ val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline") def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } - def getStages: Array[PipelineStage] = get(stages) + def getStages: Array[PipelineStage] = getOrDefault(stages) /** * Fits the pipeline to the input dataset with additional parameters. If a stage is an @@ -98,7 +101,7 @@ class Pipeline extends Estimator[PipelineModel] { */ override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val theStages = map(stages) // Search for the last estimator. var indexOfLastEstimator = -1 @@ -135,7 +138,7 @@ class Pipeline extends Estimator[PipelineModel] { } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val theStages = map(stages) require(theStages.toSet.size == theStages.size, "Cannot have duplicate components in a pipeline.") @@ -174,14 +177,14 @@ class PipelineModel private[ml] ( override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap - val map = (fittingParamMap ++ this.paramMap) ++ paramMap + val map = fittingParamMap ++ extractParamMap(paramMap) transformSchema(dataset.schema, map, logging = true) stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map)) } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap - val map = (fittingParamMap ++ this.paramMap) ++ paramMap + val map = fittingParamMap ++ extractParamMap(paramMap) stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 9a5848684b179..7fb87fe452ee6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -22,6 +22,7 @@ import scala.annotation.varargs import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -86,7 +87,7 @@ private[ml] abstract 
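Several hunks above replace `this.paramMap ++ paramMap` with `extractParamMap(paramMap)`; the intent is a single place that resolves parameter values in precedence order: values supplied at the call site override values set on the instance, which in turn override declared defaults. A hedged usage sketch of that precedence (API details may differ slightly from this snapshot of spark.ml):

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap

object ParamPrecedenceSketch {
  def main(args: Array[String]): Unit = {
    val lr = new LogisticRegression()
    lr.setMaxIter(50)                                    // instance-level setting (default is 100)
    val overrides = ParamMap.empty.put(lr.maxIter, 10)   // call-site override
    // Conceptually, extractParamMap(overrides) resolves in this order:
    //   declared default (100)  <  instance setting (50)  <  call-site override (10)
    println(lr.getMaxIter)        // 50: the instance-level value
    println(overrides(lr.maxIter)) // 10: what fit(dataset, overrides) would actually use
  }
}
```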
class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O protected def validateInputType(inputType: DataType): Unit = {} override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val inputType = schema(map(inputCol)).dataType validateInputType(inputType) if (schema.fieldNames.contains(map(outputCol))) { @@ -99,7 +100,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) dataset.withColumn(map(outputCol), callUDF(this.createTransformFunc(map), outputDataType, dataset(map(inputCol)))) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala index 970e6ad5514d1..aa27a668f1695 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala @@ -106,7 +106,7 @@ class AttributeGroup private ( def getAttr(attrIndex: Int): Attribute = this(attrIndex) /** Converts to metadata without name. */ - private[attribute] def toMetadata: Metadata = { + private[attribute] def toMetadataImpl: Metadata = { import AttributeKeys._ val bldr = new MetadataBuilder() if (attributes.isDefined) { @@ -142,17 +142,24 @@ class AttributeGroup private ( bldr.build() } - /** Converts to a StructField with some existing metadata. */ - def toStructField(existingMetadata: Metadata): StructField = { - val newMetadata = new MetadataBuilder() + /** Converts to ML metadata with some existing metadata. */ + def toMetadata(existingMetadata: Metadata): Metadata = { + new MetadataBuilder() .withMetadata(existingMetadata) - .putMetadata(AttributeKeys.ML_ATTR, toMetadata) + .putMetadata(AttributeKeys.ML_ATTR, toMetadataImpl) .build() - StructField(name, new VectorUDT, nullable = false, newMetadata) + } + + /** Converts to ML metadata */ + def toMetadata: Metadata = toMetadata(Metadata.empty) + + /** Converts to a StructField with some existing metadata. */ + def toStructField(existingMetadata: Metadata): StructField = { + StructField(name, new VectorUDT, nullable = false, toMetadata(existingMetadata)) } /** Converts to a StructField. 
*/ - def toStructField(): StructField = toStructField(Metadata.empty) + def toStructField: StructField = toStructField(Metadata.empty) override def equals(other: Any): Boolean = { other match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index c5fc89f935432..29339c98f51cf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -17,12 +17,14 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.{DeveloperApi, AlphaComponent} +import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams} -import org.apache.spark.ml.param.{Params, ParamMap, HasRawPredictionCol} +import org.apache.spark.ml.param.{ParamMap, Params} +import org.apache.spark.ml.param.shared.HasRawPredictionCol +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.sql.functions._ import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} @@ -42,8 +44,8 @@ private[spark] trait ClassifierParams extends PredictorParams fitting: Boolean, featuresDataType: DataType): StructType = { val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType) - val map = this.paramMap ++ paramMap - addOutputColumn(parentSchema, map(rawPredictionCol), new VectorUDT) + val map = extractParamMap(paramMap) + SchemaUtils.appendColumn(parentSchema, map(rawPredictionCol), new VectorUDT) } } @@ -67,8 +69,7 @@ private[spark] abstract class Classifier[ with ClassifierParams { /** @group setParam */ - def setRawPredictionCol(value: String): E = - set(rawPredictionCol, value).asInstanceOf[E] + def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E] // TODO: defaultEvaluator (follow-up PR) } @@ -109,7 +110,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) // Prepare model val tmpModel = if (paramMap.size != 0) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 21f61d80dd95a..cc8b0721cf2b6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -19,11 +19,11 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.DoubleType import org.apache.spark.storage.StorageLevel @@ -31,8 +31,10 @@ import org.apache.spark.storage.StorageLevel * Params for logistic regression. 
*/ private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams - with HasRegParam with HasMaxIter with HasThreshold + with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold { + setDefault(regParam -> 0.1, maxIter -> 100, threshold -> 0.5) +} /** * :: AlphaComponent :: @@ -45,16 +47,15 @@ class LogisticRegression extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] with LogisticRegressionParams { - setRegParam(0.1) - setMaxIter(100) - setThreshold(0.5) - /** @group setParam */ def setRegParam(value: Double): this.type = set(regParam, value) /** @group setParam */ def setMaxIter(value: Int): this.type = set(maxIter, value) + /** @group setParam */ + def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) + /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) @@ -67,7 +68,8 @@ class LogisticRegression } // Train model - val lr = new LogisticRegressionWithLBFGS + val lr = new LogisticRegressionWithLBFGS() + .setIntercept(paramMap(fitIntercept)) lr.optimizer .setRegParam(paramMap(regParam)) .setNumIterations(paramMap(maxIter)) @@ -96,8 +98,6 @@ class LogisticRegressionModel private[ml] ( extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] with LogisticRegressionParams { - setThreshold(0.5) - /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) @@ -119,7 +119,7 @@ class LogisticRegressionModel private[ml] ( // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) // Output selected columns only. // This is a bit complicated since it tries to avoid repeated computation. @@ -180,8 +180,7 @@ class LogisticRegressionModel private[ml] ( * The behavior of this can be adjusted using [[threshold]]. */ override protected def predict(features: Vector): Double = { - println(s"LR.predict with threshold: ${paramMap(threshold)}") - if (score(features) > paramMap(threshold)) 1 else 0 + if (score(features) > getThreshold) 1 else 0 } override protected def predictProbabilities(features: Vector): Vector = { @@ -196,7 +195,7 @@ class LogisticRegressionModel private[ml] ( override protected def copy(): LogisticRegressionModel = { val m = new LogisticRegressionModel(parent, fittingParamMap, weights, intercept) - Params.inheritValues(this.paramMap, this, m) + Params.inheritValues(this.extractParamMap(), this, m) m } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index bd8caac855981..10404548ccfde 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -18,13 +18,14 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} -import org.apache.spark.ml.param.{HasProbabilityCol, ParamMap, Params} +import org.apache.spark.ml.param.{ParamMap, Params} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, StructType} - /** * Params for probabilistic classification. 
*/ @@ -37,8 +38,8 @@ private[classification] trait ProbabilisticClassifierParams fitting: Boolean, featuresDataType: DataType): StructType = { val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType) - val map = this.paramMap ++ paramMap - addOutputColumn(parentSchema, map(probabilityCol), new VectorUDT) + val map = extractParamMap(paramMap) + SchemaUtils.appendColumn(parentSchema, map(probabilityCol), new VectorUDT) } } @@ -102,7 +103,7 @@ private[spark] abstract class ProbabilisticClassificationModel[ // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) // Prepare model val tmpModel = if (paramMap.size != 0) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 2360f4479f1c2..c865eb9fe092d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -20,12 +20,13 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.Evaluator import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType - /** * :: AlphaComponent :: * @@ -40,10 +41,10 @@ class BinaryClassificationEvaluator extends Evaluator with Params * @group param */ val metricName: Param[String] = new Param(this, "metricName", - "metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC")) + "metric name in evaluation (areaUnderROC|areaUnderPR)") /** @group getParam */ - def getMetricName: String = get(metricName) + def getMetricName: String = getOrDefault(metricName) /** @group setParam */ def setMetricName(value: String): this.type = set(metricName, value) @@ -54,12 +55,14 @@ class BinaryClassificationEvaluator extends Evaluator with Params /** @group setParam */ def setLabelCol(value: String): this.type = set(labelCol, value) + setDefault(metricName -> "areaUnderROC") + override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val schema = dataset.schema - checkInputColumn(schema, map(rawPredictionCol), new VectorUDT) - checkInputColumn(schema, map(labelCol), DoubleType) + SchemaUtils.checkColumnType(schema, map(rawPredictionCol), new VectorUDT) + SchemaUtils.checkColumnType(schema, map(labelCol), DoubleType) // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2. 
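The hunks above all follow the same migration pattern: a `Param` is declared without an inline default, the owning class registers the default with `setDefault`, getters read through `getOrDefault`, and `evaluate`/`transform` resolve values via `extractParamMap`, which layers defaults under user-supplied values under the caller's map. Below is a minimal sketch of that pattern; the trait name, param name, and package placement are illustrative assumptions (a subpackage of `org.apache.spark.ml` is used only so the `private[ml]` plumbing that `Params` builds on stays accessible), not code from this patch.

```
// Illustrative only: a hypothetical params trait using the default-value API
// introduced by this patch (setDefault / getOrDefault / extractParamMap).
package org.apache.spark.ml.example

import org.apache.spark.ml.param.{IntParam, ParamMap, Params}

trait HasNumBuckets extends Params {

  /** Param declared without an inline default (the old Some(...) argument is gone). */
  final val numBuckets: IntParam = new IntParam(this, "numBuckets", "number of buckets")

  /** Returns the user-supplied value if one was set, otherwise the default below. */
  final def getNumBuckets: Int = getOrDefault(numBuckets)

  setDefault(numBuckets -> 10)

  /** Resolution order: default (10) < value set on this instance < caller-supplied map. */
  protected def resolveNumBuckets(extra: ParamMap): Int = extractParamMap(extra)(numBuckets)
}
```

The generated `Has*` traits produced by `SharedParamsCodeGen` later in this patch follow the same shape.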
val scoreAndLabels = dataset.select(map(rawPredictionCol), map(labelCol)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index fc4e12773c46d..b20f2fc49a8f6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -35,14 +35,16 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { * number of features * @group param */ - val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18)) + val numFeatures = new IntParam(this, "numFeatures", "number of features") /** @group getParam */ - def getNumFeatures: Int = get(numFeatures) + def getNumFeatures: Int = getOrDefault(numFeatures) /** @group setParam */ def setNumFeatures(value: Int): this.type = set(numFeatures, value) + setDefault(numFeatures -> (1 << 18)) + override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = { val hashingTF = new feature.HashingTF(paramMap(numFeatures)) hashingTF.transform diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala new file mode 100644 index 0000000000000..decaeb0da6246 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.{DoubleParam, ParamMap} +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{VectorUDT, Vector} +import org.apache.spark.sql.types.DataType + +/** + * :: AlphaComponent :: + * Normalize a vector to have unit norm using the given p-norm. + */ +@AlphaComponent +class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] { + + /** + * Normalization in L^p^ space, p = 2 by default. 
+ * @group param + */ + val p = new DoubleParam(this, "p", "the p norm value") + + /** @group getParam */ + def getP: Double = getOrDefault(p) + + /** @group setParam */ + def setP(value: Double): this.type = set(p, value) + + setDefault(p -> 2.0) + + override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = { + val normalizer = new feature.Normalizer(paramMap(p)) + normalizer.transform + } + + override protected def outputDataType: DataType = new VectorUDT() +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 1142aa4f8e73d..1b102619b3524 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ @@ -47,7 +48,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler().fit(input) val model = new StandardScalerModel(this, map, scaler) @@ -56,7 +57,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val inputType = schema(map(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${map(inputCol)} must be a vector column") @@ -86,13 +87,13 @@ class StandardScalerModel private[ml] ( override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val scale = udf((v: Vector) => { scaler.transform(v) } : Vector) dataset.withColumn(map(outputCol), scale(col(map(inputCol)))) } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val inputType = schema(map(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${map(inputCol)} must be a vector column") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala new file mode 100644 index 0000000000000..4d960df357fe9 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkException +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.attribute.NominalAttribute +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.util.collection.OpenHashMap + +/** + * Base trait for [[StringIndexer]] and [[StringIndexerModel]]. + */ +private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { + + /** Validates and transforms the input schema. */ + protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = extractParamMap(paramMap) + SchemaUtils.checkColumnType(schema, map(inputCol), StringType) + val inputFields = schema.fields + val outputColName = map(outputCol) + require(inputFields.forall(_.name != outputColName), + s"Output column $outputColName already exists.") + val attr = NominalAttribute.defaultAttr.withName(map(outputCol)) + val outputFields = inputFields :+ attr.toStructField() + StructType(outputFields) + } +} + +/** + * :: AlphaComponent :: + * A label indexer that maps a string column of labels to an ML column of label indices. + * The indices are in [0, numLabels), ordered by label frequencies. + * So the most frequent label gets index 0. + */ +@AlphaComponent +class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + // TODO: handle unseen labels + + override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = { + val map = extractParamMap(paramMap) + val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue() + val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray + val model = new StringIndexerModel(this, map, labels) + Params.inheritValues(map, this, model) + model + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} + +/** + * :: AlphaComponent :: + * Model fitted by [[StringIndexer]]. 
+ */ +@AlphaComponent +class StringIndexerModel private[ml] ( + override val parent: StringIndexer, + override val fittingParamMap: ParamMap, + labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { + + private val labelToIndex: OpenHashMap[String, Double] = { + val n = labels.length + val map = new OpenHashMap[String, Double](n) + var i = 0 + while (i < n) { + map.update(labels(i), i) + i += 1 + } + map + } + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + val map = extractParamMap(paramMap) + val indexer = udf { label: String => + if (labelToIndex.contains(label)) { + labelToIndex(label) + } else { + // TODO: handle unseen labels + throw new SparkException(s"Unseen label: $label.") + } + } + val outputColName = map(outputCol) + val metadata = NominalAttribute.defaultAttr + .withName(outputColName).withValues(labels).toStructField().metadata + dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata)) + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 68401e36950bd..376a004858b4c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -56,39 +56,39 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize * param for minimum token length, default is one to avoid returning empty strings * @group param */ - val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length", Some(1)) + val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length") /** @group setParam */ def setMinTokenLength(value: Int): this.type = set(minTokenLength, value) /** @group getParam */ - def getMinTokenLength: Int = get(minTokenLength) + def getMinTokenLength: Int = getOrDefault(minTokenLength) /** * param sets regex as splitting on gaps (true) or matching tokens (false) * @group param */ - val gaps: BooleanParam = new BooleanParam( - this, "gaps", "Set regex to match gaps or tokens", Some(false)) + val gaps: BooleanParam = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens") /** @group setParam */ def setGaps(value: Boolean): this.type = set(gaps, value) /** @group getParam */ - def getGaps: Boolean = get(gaps) + def getGaps: Boolean = getOrDefault(gaps) /** * param sets regex pattern used by tokenizer * @group param */ - val pattern: Param[String] = new Param( - this, "pattern", "regex pattern used for tokenizing", Some("\\p{L}+|[^\\p{L}\\s]+")) + val pattern: Param[String] = new Param(this, "pattern", "regex pattern used for tokenizing") /** @group setParam */ def setPattern(value: String): this.type = set(pattern, value) /** @group getParam */ - def getPattern: String = get(pattern) + def getPattern: String = getOrDefault(pattern) + + setDefault(minTokenLength -> 1, gaps -> false, pattern -> "\\p{L}+|[^\\p{L}\\s]+") override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { str => val re = paramMap(pattern).r diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala new file mode 100644 index 0000000000000..e567e069e7c0b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.SparkException +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared._ +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.sql.{Column, DataFrame, Row} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, CreateStruct} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * :: AlphaComponent :: + * A feature transformer that merges multiple columns into a vector column. 
+ */ +@AlphaComponent +class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { + + /** @group setParam */ + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + val map = extractParamMap(paramMap) + val assembleFunc = udf { r: Row => + VectorAssembler.assemble(r.toSeq: _*) + } + val schema = dataset.schema + val inputColNames = map(inputCols) + val args = inputColNames.map { c => + schema(c).dataType match { + case DoubleType => UnresolvedAttribute(c) + case t if t.isInstanceOf[VectorUDT] => UnresolvedAttribute(c) + case _: NativeType => Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")() + } + } + dataset.select(col("*"), assembleFunc(new Column(CreateStruct(args))).as(map(outputCol))) + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = extractParamMap(paramMap) + val inputColNames = map(inputCols) + val outputColName = map(outputCol) + val inputDataTypes = inputColNames.map(name => schema(name).dataType) + inputDataTypes.foreach { + case _: NativeType => + case t if t.isInstanceOf[VectorUDT] => + case other => + throw new IllegalArgumentException(s"Data type $other is not supported.") + } + if (schema.fieldNames.contains(outputColName)) { + throw new IllegalArgumentException(s"Output column $outputColName already exists.") + } + StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false)) + } +} + +@AlphaComponent +object VectorAssembler { + + private[feature] def assemble(vv: Any*): Vector = { + val indices = ArrayBuilder.make[Int] + val values = ArrayBuilder.make[Double] + var cur = 0 + vv.foreach { + case v: Double => + if (v != 0.0) { + indices += cur + values += v + } + cur += 1 + case vec: Vector => + vec.foreachActive { case (i, v) => + if (v != 0.0) { + indices += cur + i + values += v + } + } + cur += vec.size + case null => + // TODO: output Double.NaN? + throw new SparkException("Values to assemble cannot be null.") + case o => + throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") + } + Vectors.sparse(cur, indices.result(), values.result()) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala new file mode 100644 index 0000000000000..452faa06e2021 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -0,0 +1,396 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.attribute.{BinaryAttribute, NumericAttribute, NominalAttribute, + Attribute, AttributeGroup} +import org.apache.spark.ml.param.{IntParam, ParamMap, Params} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT} +import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.functions.callUDF +import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.util.collection.OpenHashSet + + +/** Private trait for params for VectorIndexer and VectorIndexerModel */ +private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol { + + /** + * Threshold for the number of values a categorical feature can take. + * If a feature is found to have > maxCategories values, then it is declared continuous. + * + * (default = 20) + */ + val maxCategories = new IntParam(this, "maxCategories", + "Threshold for the number of values a categorical feature can take." + + " If a feature is found to have > maxCategories values, then it is declared continuous.") + + /** @group getParam */ + def getMaxCategories: Int = getOrDefault(maxCategories) + + setDefault(maxCategories -> 20) +} + +/** + * :: AlphaComponent :: + * + * Class for indexing categorical feature columns in a dataset of [[Vector]]. + * + * This has 2 usage modes: + * - Automatically identify categorical features (default behavior) + * - This helps process a dataset of unknown vectors into a dataset with some continuous + * features and some categorical features. The choice between continuous and categorical + * is based upon a maxCategories parameter. + * - Set maxCategories to the maximum number of categories any categorical feature should have. + * - E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 values {1.0, 3.0, 5.0}. + * If maxCategories = 2, then feature 0 will be declared categorical and use indices {0, 1}, + * and feature 1 will be declared continuous. + * - Index all features, if all features are categorical + * - If maxCategories is set to be very large, then this will build an index of unique + * values for all features. + * - Warning: This can cause problems if features are continuous since this will collect ALL + * unique values to the driver. + * - E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 values {1.0, 3.0, 5.0}. + * If maxCategories >= 3, then both features will be declared categorical. + * + * This returns a model which can transform categorical features to use 0-based indices. + * + * Index stability: + * - This is not guaranteed to choose the same category index across multiple runs. + * - If a categorical feature includes value 0, then this is guaranteed to map value 0 to index 0. + * This maintains vector sparsity. + * - More stability may be added in the future. + * + * TODO: Future extensions: The following functionality is planned for the future: + * - Preserve metadata in transform; if a feature's metadata is already present, do not recompute. + * - Specify certain features to not index, either via a parameter or via existing metadata. + * - Add warning if a categorical feature has only 1 category. + * - Add option for allowing unknown categories. 
+ */ +@AlphaComponent +class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerParams { + + /** @group setParam */ + def setMaxCategories(value: Int): this.type = { + require(value > 1, + s"VectorIndexer given maxCategories = $value, but requires maxCategories > 1.") + set(maxCategories, value) + } + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def fit(dataset: DataFrame, paramMap: ParamMap): VectorIndexerModel = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = extractParamMap(paramMap) + val firstRow = dataset.select(map(inputCol)).take(1) + require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.") + val numFeatures = firstRow(0).getAs[Vector](0).size + val vectorDataset = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } + val maxCats = map(maxCategories) + val categoryStats: VectorIndexer.CategoryStats = vectorDataset.mapPartitions { iter => + val localCatStats = new VectorIndexer.CategoryStats(numFeatures, maxCats) + iter.foreach(localCatStats.addVector) + Iterator(localCatStats) + }.reduce((stats1, stats2) => stats1.merge(stats2)) + val model = new VectorIndexerModel(this, map, numFeatures, categoryStats.getCategoryMaps) + Params.inheritValues(map, this, model) + model + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + // We do not transfer feature metadata since we do not know what types of features we will + // produce in transform(). + val map = extractParamMap(paramMap) + val dataType = new VectorUDT + require(map.contains(inputCol), s"VectorIndexer requires input column parameter: $inputCol") + require(map.contains(outputCol), s"VectorIndexer requires output column parameter: $outputCol") + SchemaUtils.checkColumnType(schema, map(inputCol), dataType) + SchemaUtils.appendColumn(schema, map(outputCol), dataType) + } +} + +private object VectorIndexer { + + /** + * Helper class for tracking unique values for each feature. + * + * TODO: Track which features are known to be continuous already; do not update counts for them. + * + * @param numFeatures This class fails if it encounters a Vector whose length is not numFeatures. + * @param maxCategories This class caps the number of unique values collected at maxCategories. + */ + class CategoryStats(private val numFeatures: Int, private val maxCategories: Int) + extends Serializable { + + /** featureValueSets[feature index] = set of unique values */ + private val featureValueSets = + Array.fill[OpenHashSet[Double]](numFeatures)(new OpenHashSet[Double]()) + + /** Merge with another instance, modifying this instance. */ + def merge(other: CategoryStats): CategoryStats = { + featureValueSets.zip(other.featureValueSets).foreach { case (thisValSet, otherValSet) => + otherValSet.iterator.foreach { x => + // Once we have found > maxCategories values, we know the feature is continuous + // and do not need to collect more values for it. 
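To make the maxCategories behavior described in the scaladoc above concrete, here is a hypothetical fit/transform run; the data, column names, and the ambient `sqlContext` are assumptions for illustration and do not appear in the patch.

```
// Hypothetical usage of the VectorIndexer added above. With maxCategories = 2,
// feature 0 (values {-1.0, 0.0}) is treated as categorical and re-indexed,
// while feature 1 (values {1.0, 3.0, 5.0}) is left continuous.
import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.mllib.linalg.Vectors

val data = sqlContext.createDataFrame(Seq(
  Tuple1(Vectors.dense(-1.0, 1.0)),
  Tuple1(Vectors.dense(0.0, 3.0)),
  Tuple1(Vectors.dense(0.0, 5.0)))).toDF("features")

val indexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexed")
  .setMaxCategories(2)

val indexerModel = indexer.fit(data)   // builds per-feature category maps
val indexed = indexerModel.transform(data)
```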
+ if (thisValSet.size <= maxCategories) thisValSet.add(x) + } + } + this + } + + /** Add a new vector to this index, updating sets of unique feature values */ + def addVector(v: Vector): Unit = { + require(v.size == numFeatures, s"VectorIndexer expected $numFeatures features but" + + s" found vector of size ${v.size}.") + v match { + case dv: DenseVector => addDenseVector(dv) + case sv: SparseVector => addSparseVector(sv) + } + } + + /** + * Based on stats collected, decide which features are categorical, + * and choose indices for categories. + * + * Sparsity: This tries to maintain sparsity by treating value 0.0 specially. + * If a categorical feature takes value 0.0, then value 0.0 is given index 0. + * + * @return Feature value index. Keys are categorical feature indices (column indices). + * Values are mappings from original features values to 0-based category indices. + */ + def getCategoryMaps: Map[Int, Map[Double, Int]] = { + // Filter out features which are declared continuous. + featureValueSets.zipWithIndex.filter(_._1.size <= maxCategories).map { + case (featureValues: OpenHashSet[Double], featureIndex: Int) => + var sortedFeatureValues = featureValues.iterator.filter(_ != 0.0).toArray.sorted + val zeroExists = sortedFeatureValues.length + 1 == featureValues.size + if (zeroExists) { + sortedFeatureValues = 0.0 +: sortedFeatureValues + } + val categoryMap: Map[Double, Int] = sortedFeatureValues.zipWithIndex.toMap + (featureIndex, categoryMap) + }.toMap + } + + private def addDenseVector(dv: DenseVector): Unit = { + var i = 0 + while (i < dv.size) { + if (featureValueSets(i).size <= maxCategories) { + featureValueSets(i).add(dv(i)) + } + i += 1 + } + } + + private def addSparseVector(sv: SparseVector): Unit = { + // TODO: This might be able to handle 0's more efficiently. + var vecIndex = 0 // index into vector + var k = 0 // index into non-zero elements + while (vecIndex < sv.size) { + val featureValue = if (k < sv.indices.length && vecIndex == sv.indices(k)) { + k += 1 + sv.values(k - 1) + } else { + 0.0 + } + if (featureValueSets(vecIndex).size <= maxCategories) { + featureValueSets(vecIndex).add(featureValue) + } + vecIndex += 1 + } + } + } +} + +/** + * :: AlphaComponent :: + * + * Transform categorical features to use 0-based indices instead of their original values. + * - Categorical features are mapped to indices. + * - Continuous features (columns) are left unchanged. + * This also appends metadata to the output column, marking features as Numeric (continuous), + * Nominal (categorical), or Binary (either continuous or categorical). + * + * This maintains vector sparsity. + * + * @param numFeatures Number of features, i.e., length of Vectors which this transforms + * @param categoryMaps Feature value index. Keys are categorical feature indices (column indices). + * Values are maps from original features values to 0-based category indices. + * If a feature is not in this map, it is treated as continuous. + */ +@AlphaComponent +class VectorIndexerModel private[ml] ( + override val parent: VectorIndexer, + override val fittingParamMap: ParamMap, + val numFeatures: Int, + val categoryMaps: Map[Int, Map[Double, Int]]) + extends Model[VectorIndexerModel] with VectorIndexerParams { + + /** + * Pre-computed feature attributes, with some missing info. + * In transform(), set attribute name and other info, if available. 
+ */ + private val partialFeatureAttributes: Array[Attribute] = { + val attrs = new Array[Attribute](numFeatures) + var categoricalFeatureCount = 0 // validity check for numFeatures, categoryMaps + var featureIndex = 0 + while (featureIndex < numFeatures) { + if (categoryMaps.contains(featureIndex)) { + // categorical feature + val featureValues: Array[String] = + categoryMaps(featureIndex).toArray.sortBy(_._1).map(_._1).map(_.toString) + if (featureValues.length == 2) { + attrs(featureIndex) = new BinaryAttribute(index = Some(featureIndex), + values = Some(featureValues)) + } else { + attrs(featureIndex) = new NominalAttribute(index = Some(featureIndex), + isOrdinal = Some(false), values = Some(featureValues)) + } + categoricalFeatureCount += 1 + } else { + // continuous feature + attrs(featureIndex) = new NumericAttribute(index = Some(featureIndex)) + } + featureIndex += 1 + } + require(categoricalFeatureCount == categoryMaps.size, "VectorIndexerModel given categoryMaps" + + s" with keys outside expected range [0,...,numFeatures), where numFeatures=$numFeatures") + attrs + } + + // TODO: Check more carefully about whether this whole class will be included in a closure. + + private val transformFunc: Vector => Vector = { + val sortedCategoricalFeatureIndices = categoryMaps.keys.toArray.sorted + val localVectorMap = categoryMaps + val f: Vector => Vector = { + case dv: DenseVector => + val tmpv = dv.copy + localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) => + tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex)) + } + tmpv + case sv: SparseVector => + // We use the fact that categorical value 0 is always mapped to index 0. + val tmpv = sv.copy + var catFeatureIdx = 0 // index into sortedCategoricalFeatureIndices + var k = 0 // index into non-zero elements of sparse vector + while (catFeatureIdx < sortedCategoricalFeatureIndices.length && k < tmpv.indices.length) { + val featureIndex = sortedCategoricalFeatureIndices(catFeatureIdx) + if (featureIndex < tmpv.indices(k)) { + catFeatureIdx += 1 + } else if (featureIndex > tmpv.indices(k)) { + k += 1 + } else { + tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k)) + catFeatureIdx += 1 + k += 1 + } + } + tmpv + } + f + } + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = extractParamMap(paramMap) + val newField = prepOutputField(dataset.schema, map) + val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol))) + // For now, just check the first row of inputCol for vector length. 
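Continuing that hypothetical sketch: because a categorical feature containing 0.0 always maps 0.0 to index 0 (the invariant the sparse branch of `transformFunc` relies on), sparse vectors stay sparse and only non-zero categorical values are rewritten. The expected rows below are inferred from those mapping rules, not output copied from the patch.

```
// Expected contents of the "indexed" column from the sketch above:
// feature 0 is remapped (0.0 -> 0, -1.0 -> 1); feature 1 is passed through.
indexed.select("indexed").collect()
// expected: [1.0, 1.0], [0.0, 3.0], [0.0, 5.0]
```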
+ val firstRow = dataset.select(map(inputCol)).take(1) + if (firstRow.length != 0) { + val actualNumFeatures = firstRow(0).getAs[Vector](0).size + require(numFeatures == actualNumFeatures, "VectorIndexerModel expected vector of length" + + s" $numFeatures but found length $actualNumFeatures") + } + dataset.withColumn(map(outputCol), newCol.as(map(outputCol), newField.metadata)) + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = extractParamMap(paramMap) + val dataType = new VectorUDT + require(map.contains(inputCol), + s"VectorIndexerModel requires input column parameter: $inputCol") + require(map.contains(outputCol), + s"VectorIndexerModel requires output column parameter: $outputCol") + SchemaUtils.checkColumnType(schema, map(inputCol), dataType) + + val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol))) + val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) { + Some(origAttrGroup.attributes.get.length) + } else { + origAttrGroup.numAttributes + } + require(origNumFeatures.forall(_ == numFeatures), "VectorIndexerModel expected" + + s" $numFeatures features, but input column ${map(inputCol)} had metadata specifying" + + s" ${origAttrGroup.numAttributes.get} features.") + + val newField = prepOutputField(schema, map) + val outputFields = schema.fields :+ newField + StructType(outputFields) + } + + /** + * Prepare the output column field, including per-feature metadata. + * @param schema Input schema + * @param map Parameter map (with this class' embedded parameter map folded in) + * @return Output column field + */ + private def prepOutputField(schema: StructType, map: ParamMap): StructField = { + val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol))) + val featureAttributes: Array[Attribute] = if (origAttrGroup.attributes.nonEmpty) { + // Convert original attributes to modified attributes + val origAttrs: Array[Attribute] = origAttrGroup.attributes.get + origAttrs.zip(partialFeatureAttributes).map { + case (origAttr: Attribute, featAttr: BinaryAttribute) => + if (origAttr.name.nonEmpty) { + featAttr.withName(origAttr.name.get) + } else { + featAttr + } + case (origAttr: Attribute, featAttr: NominalAttribute) => + if (origAttr.name.nonEmpty) { + featAttr.withName(origAttr.name.get) + } else { + featAttr + } + case (origAttr: Attribute, featAttr: NumericAttribute) => + origAttr.withIndex(featAttr.index.get) + } + } else { + partialFeatureAttributes + } + val newAttributeGroup = new AttributeGroup(map(outputCol), featureAttributes) + newAttributeGroup.toStructField(schema(map(inputCol)).metadata) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala index dfb89cc8d4af3..195333a5cc47f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala @@ -18,8 +18,10 @@ package org.apache.spark.ml.impl.estimator import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{VectorUDT, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -53,14 +55,14 @@ private[spark] trait PredictorParams extends Params paramMap: ParamMap, 
fitting: Boolean, featuresDataType: DataType): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector - checkInputColumn(schema, map(featuresCol), featuresDataType) + SchemaUtils.checkColumnType(schema, map(featuresCol), featuresDataType) if (fitting) { // TODO: Allow other numeric types - checkInputColumn(schema, map(labelCol), DoubleType) + SchemaUtils.checkColumnType(schema, map(labelCol), DoubleType) } - addOutputColumn(schema, map(predictionCol), DoubleType) + SchemaUtils.appendColumn(schema, map(predictionCol), DoubleType) } } @@ -98,7 +100,7 @@ private[spark] abstract class Predictor[ // This handles a few items such as schema validation. // Developers only need to implement train(). transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val model = train(dataset, map) Params.inheritValues(map, this, model) // copy params to model model @@ -141,7 +143,7 @@ private[spark] abstract class Predictor[ * and put it in an RDD with strong types. */ protected def extractLabeledPoints(dataset: DataFrame, paramMap: ParamMap): RDD[LabeledPoint] = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) dataset.select(map(labelCol), map(featuresCol)) .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) @@ -201,7 +203,7 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) // Prepare model val tmpModel = if (paramMap.size != 0) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 17ece897a6c55..849c60433c777 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -17,15 +17,14 @@ package org.apache.spark.ml.param +import java.lang.reflect.Modifier +import java.util.NoSuchElementException + import scala.annotation.varargs import scala.collection.mutable -import java.lang.reflect.Modifier - import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.Identifiable -import org.apache.spark.sql.types.{DataType, StructField, StructType} - /** * :: AlphaComponent :: @@ -38,12 +37,7 @@ import org.apache.spark.sql.types.{DataType, StructField, StructType} * @tparam T param value type */ @AlphaComponent -class Param[T] ( - val parent: Params, - val name: String, - val doc: String, - val defaultValue: Option[T] = None) - extends Serializable { +class Param[T] (val parent: Params, val name: String, val doc: String) extends Serializable { /** * Creates a param pair with the given value (for Java). @@ -55,58 +49,55 @@ class Param[T] ( */ def ->(value: T): ParamPair[T] = ParamPair(this, value) + /** + * Converts this param's name, doc, and optionally its default value and the user-supplied + * value in its parent to string. 
+ */ override def toString: String = { - if (defaultValue.isDefined) { - s"$name: $doc (default: ${defaultValue.get})" + val valueStr = if (parent.isDefined(this)) { + val defaultValueStr = parent.getDefault(this).map("default: " + _) + val currentValueStr = parent.get(this).map("current: " + _) + (defaultValueStr ++ currentValueStr).mkString("(", ", ", ")") } else { - s"$name: $doc" + "(undefined)" } + s"$name: $doc $valueStr" } } // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... /** Specialized version of [[Param[Double]]] for Java. */ -class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double]) - extends Param[Double](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class DoubleParam(parent: Params, name: String, doc: String) + extends Param[Double](parent, name, doc) { override def w(value: Double): ParamPair[Double] = super.w(value) } /** Specialized version of [[Param[Int]]] for Java. */ -class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int]) - extends Param[Int](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class IntParam(parent: Params, name: String, doc: String) + extends Param[Int](parent, name, doc) { override def w(value: Int): ParamPair[Int] = super.w(value) } /** Specialized version of [[Param[Float]]] for Java. */ -class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float]) - extends Param[Float](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class FloatParam(parent: Params, name: String, doc: String) + extends Param[Float](parent, name, doc) { override def w(value: Float): ParamPair[Float] = super.w(value) } /** Specialized version of [[Param[Long]]] for Java. */ -class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long]) - extends Param[Long](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class LongParam(parent: Params, name: String, doc: String) + extends Param[Long](parent, name, doc) { override def w(value: Long): ParamPair[Long] = super.w(value) } /** Specialized version of [[Param[Boolean]]] for Java. */ -class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean]) - extends Param[Boolean](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class BooleanParam(parent: Params, name: String, doc: String) + extends Param[Boolean](parent, name, doc) { override def w(value: Boolean): ParamPair[Boolean] = super.w(value) } @@ -124,8 +115,11 @@ case class ParamPair[T](param: Param[T], value: T) @AlphaComponent trait Params extends Identifiable with Serializable { - /** Returns all params. */ - def params: Array[Param[_]] = { + /** + * Returns all params sorted by their names. The default implementation uses Java reflection to + * list all public methods that have no arguments and return [[Param]]. + */ + lazy val params: Array[Param[_]] = { val methods = this.getClass.getMethods methods.filter { m => Modifier.isPublic(m.getModifiers) && @@ -153,25 +147,29 @@ trait Params extends Identifiable with Serializable { def explainParams(): String = params.mkString("\n") /** Checks whether a param is explicitly set. 
*/ - def isSet(param: Param[_]): Boolean = { - require(param.parent.eq(this)) + final def isSet(param: Param[_]): Boolean = { + shouldOwn(param) paramMap.contains(param) } + /** Checks whether a param is explicitly set or has a default value. */ + final def isDefined(param: Param[_]): Boolean = { + shouldOwn(param) + defaultParamMap.contains(param) || paramMap.contains(param) + } + /** Gets a param by its name. */ - private[ml] def getParam(paramName: String): Param[Any] = { - val m = this.getClass.getMethod(paramName) - assert(Modifier.isPublic(m.getModifiers) && - classOf[Param[_]].isAssignableFrom(m.getReturnType) && - m.getParameterTypes.isEmpty) - m.invoke(this).asInstanceOf[Param[Any]] + def getParam(paramName: String): Param[Any] = { + params.find(_.name == paramName).getOrElse { + throw new NoSuchElementException(s"Param $paramName does not exist.") + }.asInstanceOf[Param[Any]] } /** * Sets a parameter in the embedded param map. */ - protected def set[T](param: Param[T], value: T): this.type = { - require(param.parent.eq(this)) + protected final def set[T](param: Param[T], value: T): this.type = { + shouldOwn(param) paramMap.put(param.asInstanceOf[Param[Any]], value) this } @@ -179,44 +177,102 @@ trait Params extends Identifiable with Serializable { /** * Sets a parameter (by name) in the embedded param map. */ - private[ml] def set(param: String, value: Any): this.type = { + protected final def set(param: String, value: Any): this.type = { set(getParam(param), value) } /** - * Gets the value of a parameter in the embedded param map. + * Optionally returns the user-supplied value of a param. */ - protected def get[T](param: Param[T]): T = { - require(param.parent.eq(this)) - paramMap(param) + final def get[T](param: Param[T]): Option[T] = { + shouldOwn(param) + paramMap.get(param) } /** - * Internal param map. + * Clears the user-supplied value for the input param. */ - protected val paramMap: ParamMap = ParamMap.empty + protected final def clear(param: Param[_]): this.type = { + shouldOwn(param) + paramMap.remove(param) + this + } /** - * Check whether the given schema contains an input column. - * @param colName Parameter name for the input column. - * @param dataType SQL DataType of the input column. + * Gets the value of a param in the embedded param map or its default value. Throws an exception + * if neither is set. */ - protected def checkInputColumn(schema: StructType, colName: String, dataType: DataType): Unit = { - val actualDataType = schema(colName).dataType - require(actualDataType.equals(dataType), - s"Input column $colName must be of type $dataType" + - s" but was actually $actualDataType. Column param description: ${getParam(colName)}") + final def getOrDefault[T](param: Param[T]): T = { + shouldOwn(param) + get(param).orElse(getDefault(param)).get } - protected def addOutputColumn( - schema: StructType, - colName: String, - dataType: DataType): StructType = { - if (colName.length == 0) return schema - val fieldNames = schema.fieldNames - require(!fieldNames.contains(colName), s"Prediction column $colName already exists.") - val outputFields = schema.fields ++ Seq(StructField(colName, dataType, nullable = false)) - StructType(outputFields) + /** + * Sets a default value for a param. + * @param param param to set the default value. Make sure that this param is initialized before + * this method gets called. 
+ * @param value the default value + */ + protected final def setDefault[T](param: Param[T], value: T): this.type = { + shouldOwn(param) + defaultParamMap.put(param, value) + this + } + + /** + * Sets default values for a list of params. + * @param paramPairs a list of param pairs that specify params and their default values to set + * respectively. Make sure that the params are initialized before this method + * gets called. + */ + protected final def setDefault(paramPairs: ParamPair[_]*): this.type = { + paramPairs.foreach { p => + setDefault(p.param.asInstanceOf[Param[Any]], p.value) + } + this + } + + /** + * Gets the default value of a parameter. + */ + final def getDefault[T](param: Param[T]): Option[T] = { + shouldOwn(param) + defaultParamMap.get(param) + } + + /** + * Tests whether the input param has a default value set. + */ + final def hasDefault[T](param: Param[T]): Boolean = { + shouldOwn(param) + defaultParamMap.contains(param) + } + + /** + * Extracts the embedded default param values and user-supplied values, and then merges them with + * extra values from input into a flat param map, where the latter value is used if there exist + * conflicts, i.e., with ordering: default param values < user-supplied values < extraParamMap. + */ + protected final def extractParamMap(extraParamMap: ParamMap): ParamMap = { + defaultParamMap ++ paramMap ++ extraParamMap + } + + /** + * [[extractParamMap]] with no extra values. + */ + protected final def extractParamMap(): ParamMap = { + extractParamMap(ParamMap.empty) + } + + /** Internal param map for user-supplied values. */ + private val paramMap: ParamMap = ParamMap.empty + + /** Internal param map for default values. */ + private val defaultParamMap: ParamMap = ParamMap.empty + + /** Validates that the input param belongs to this instance. */ + private def shouldOwn(param: Param[_]): Unit = { + require(param.parent.eq(this), s"Param $param does not belong to $this.") } } @@ -253,12 +309,13 @@ private[spark] object Params { * A param to value map. */ @AlphaComponent -class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable { +final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) + extends Serializable { /** * Creates an empty param map. */ - def this() = this(mutable.Map.empty[Param[Any], Any]) + def this() = this(mutable.Map.empty) /** * Puts a (param, value) pair (overwrites if the input param exists). @@ -280,12 +337,17 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten } /** - * Optionally returns the value associated with a param or its default. + * Optionally returns the value associated with a param. */ def get[T](param: Param[T]): Option[T] = { - map.get(param.asInstanceOf[Param[Any]]) - .orElse(param.defaultValue) - .asInstanceOf[Option[T]] + map.get(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]] + } + + /** + * Returns the value associated with a param or a default value. + */ + def getOrElse[T](param: Param[T], default: T): T = { + get(param).getOrElse(default) } /** @@ -293,10 +355,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten * Raises a NoSuchElementException if there is no value associated with the input param. 
*/ def apply[T](param: Param[T]): T = { - val value = get(param) - if (value.isDefined) { - value.get - } else { + get(param).getOrElse { throw new NoSuchElementException(s"Cannot find param ${param.name}.") } } @@ -308,6 +367,13 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten map.contains(param.asInstanceOf[Param[Any]]) } + /** + * Removes a key from this map and returns its value associated previously as an option. + */ + def remove[T](param: Param[T]): Option[T] = { + map.remove(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]] + } + /** * Filters this param map for the given parent. */ @@ -317,7 +383,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten } /** - * Make a copy of this param map. + * Creates a copy of this param map. */ def copy: ParamMap = new ParamMap(map.clone()) @@ -329,7 +395,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten /** * Returns a new param map that contains parameters in this map and the given map, - * where the latter overwrites this if there exists conflicts. + * where the latter overwrites this if there exist conflicts. */ def ++(other: ParamMap): ParamMap = { // TODO: Provide a better method name for Java users. @@ -355,7 +421,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten } /** - * Number of param pairs in this set. + * Number of param pairs in this map. */ def size: Int = map.size } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala new file mode 100644 index 0000000000000..95d7e64790c79 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param.shared + +import java.io.PrintWriter + +import scala.reflect.ClassTag + +/** + * Code generator for shared params (sharedParams.scala). Run under the Spark folder with + * {{{ + * build/sbt "mllib/runMain org.apache.spark.ml.param.shared.SharedParamsCodeGen" + * }}} + */ +private[shared] object SharedParamsCodeGen { + + def main(args: Array[String]): Unit = { + val params = Seq( + ParamDesc[Double]("regParam", "regularization parameter"), + ParamDesc[Int]("maxIter", "max number of iterations"), + ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")), + ParamDesc[String]("labelCol", "label column name", Some("\"label\"")), + ParamDesc[String]("predictionCol", "prediction column name", Some("\"prediction\"")), + ParamDesc[String]("rawPredictionCol", "raw prediction (a.k.a. 
confidence) column name", + Some("\"rawPrediction\"")), + ParamDesc[String]("probabilityCol", + "column name for predicted class conditional probabilities", Some("\"probability\"")), + ParamDesc[Double]("threshold", "threshold in binary classification prediction"), + ParamDesc[String]("inputCol", "input column name"), + ParamDesc[Array[String]]("inputCols", "input column names"), + ParamDesc[String]("outputCol", "output column name"), + ParamDesc[Int]("checkpointInterval", "checkpoint interval"), + ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true"))) + + val code = genSharedParams(params) + val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" + val writer = new PrintWriter(file) + writer.write(code) + writer.close() + } + + /** Description of a param. */ + private case class ParamDesc[T: ClassTag]( + name: String, + doc: String, + defaultValueStr: Option[String] = None) { + + require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.") + require(doc.nonEmpty) // TODO: more rigorous on doc + + def paramTypeName: String = { + val c = implicitly[ClassTag[T]].runtimeClass + c match { + case _ if c == classOf[Int] => "IntParam" + case _ if c == classOf[Long] => "LongParam" + case _ if c == classOf[Float] => "FloatParam" + case _ if c == classOf[Double] => "DoubleParam" + case _ if c == classOf[Boolean] => "BooleanParam" + case _ => s"Param[${getTypeString(c)}]" + } + } + + def valueTypeName: String = { + val c = implicitly[ClassTag[T]].runtimeClass + getTypeString(c) + } + + private def getTypeString(c: Class[_]): String = { + c match { + case _ if c == classOf[Int] => "Int" + case _ if c == classOf[Long] => "Long" + case _ if c == classOf[Float] => "Float" + case _ if c == classOf[Double] => "Double" + case _ if c == classOf[Boolean] => "Boolean" + case _ if c == classOf[String] => "String" + case _ if c.isArray => s"Array[${getTypeString(c.getComponentType)}]" + } + } + } + + /** Generates the HasParam trait code for the input param. */ + private def genHasParamTrait(param: ParamDesc[_]): String = { + val name = param.name + val Name = name(0).toUpper +: name.substring(1) + val Param = param.paramTypeName + val T = param.valueTypeName + val doc = param.doc + val defaultValue = param.defaultValueStr + val defaultValueDoc = defaultValue.map { v => + s" (default: $v)" + }.getOrElse("") + val setDefault = defaultValue.map { v => + s""" + | setDefault($name, $v) + |""".stripMargin + }.getOrElse("") + + s""" + |/** + | * :: DeveloperApi :: + | * Trait for shared param $name$defaultValueDoc. + | */ + |@DeveloperApi + |trait Has$Name extends Params { + | + | /** + | * Param for $doc. + | * @group param + | */ + | final val $name: $Param = new $Param(this, "$name", "$doc") + |$setDefault + | /** @group getParam */ + | final def get$Name: $T = getOrDefault($name) + |} + |""".stripMargin + } + + /** Generates Scala source code for the input params with header. */ + private def genSharedParams(params: Seq[ParamDesc[_]]): String = { + val header = + """/* + | * Licensed to the Apache Software Foundation (ASF) under one or more + | * contributor license agreements. See the NOTICE file distributed with + | * this work for additional information regarding copyright ownership. + | * The ASF licenses this file to You under the Apache License, Version 2.0 + | * (the "License"); you may not use this file except in compliance with + | * the License. 
You may obtain a copy of the License at + | * + | * http://www.apache.org/licenses/LICENSE-2.0 + | * + | * Unless required by applicable law or agreed to in writing, software + | * distributed under the License is distributed on an "AS IS" BASIS, + | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + | * See the License for the specific language governing permissions and + | * limitations under the License. + | */ + | + |package org.apache.spark.ml.param.shared + | + |import org.apache.spark.annotation.DeveloperApi + |import org.apache.spark.ml.param._ + | + |// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. + | + |// scalastyle:off + |""".stripMargin + + val footer = "// scalastyle:on\n" + + val traits = params.map(genHasParamTrait).mkString + + header + traits + footer + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala new file mode 100644 index 0000000000000..72b08bf276483 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param.shared + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.param._ + +// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. + +// scalastyle:off + +/** + * :: DeveloperApi :: + * Trait for shared param regParam. + */ +@DeveloperApi +trait HasRegParam extends Params { + + /** + * Param for regularization parameter. + * @group param + */ + final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") + + /** @group getParam */ + final def getRegParam: Double = getOrDefault(regParam) +} + +/** + * :: DeveloperApi :: + * Trait for shared param maxIter. + */ +@DeveloperApi +trait HasMaxIter extends Params { + + /** + * Param for max number of iterations. + * @group param + */ + final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") + + /** @group getParam */ + final def getMaxIter: Int = getOrDefault(maxIter) +} + +/** + * :: DeveloperApi :: + * Trait for shared param featuresCol (default: "features"). + */ +@DeveloperApi +trait HasFeaturesCol extends Params { + + /** + * Param for features column name. + * @group param + */ + final val featuresCol: Param[String] = new Param[String](this, "featuresCol", "features column name") + + setDefault(featuresCol, "features") + + /** @group getParam */ + final def getFeaturesCol: String = getOrDefault(featuresCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param labelCol (default: "label"). 
+ */ +@DeveloperApi +trait HasLabelCol extends Params { + + /** + * Param for label column name. + * @group param + */ + final val labelCol: Param[String] = new Param[String](this, "labelCol", "label column name") + + setDefault(labelCol, "label") + + /** @group getParam */ + final def getLabelCol: String = getOrDefault(labelCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param predictionCol (default: "prediction"). + */ +@DeveloperApi +trait HasPredictionCol extends Params { + + /** + * Param for prediction column name. + * @group param + */ + final val predictionCol: Param[String] = new Param[String](this, "predictionCol", "prediction column name") + + setDefault(predictionCol, "prediction") + + /** @group getParam */ + final def getPredictionCol: String = getOrDefault(predictionCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param rawPredictionCol (default: "rawPrediction"). + */ +@DeveloperApi +trait HasRawPredictionCol extends Params { + + /** + * Param for raw prediction (a.k.a. confidence) column name. + * @group param + */ + final val rawPredictionCol: Param[String] = new Param[String](this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name") + + setDefault(rawPredictionCol, "rawPrediction") + + /** @group getParam */ + final def getRawPredictionCol: String = getOrDefault(rawPredictionCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param probabilityCol (default: "probability"). + */ +@DeveloperApi +trait HasProbabilityCol extends Params { + + /** + * Param for column name for predicted class conditional probabilities. + * @group param + */ + final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "column name for predicted class conditional probabilities") + + setDefault(probabilityCol, "probability") + + /** @group getParam */ + final def getProbabilityCol: String = getOrDefault(probabilityCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param threshold. + */ +@DeveloperApi +trait HasThreshold extends Params { + + /** + * Param for threshold in binary classification prediction. + * @group param + */ + final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction") + + /** @group getParam */ + final def getThreshold: Double = getOrDefault(threshold) +} + +/** + * :: DeveloperApi :: + * Trait for shared param inputCol. + */ +@DeveloperApi +trait HasInputCol extends Params { + + /** + * Param for input column name. + * @group param + */ + final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name") + + /** @group getParam */ + final def getInputCol: String = getOrDefault(inputCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param inputCols. + */ +@DeveloperApi +trait HasInputCols extends Params { + + /** + * Param for input column names. + * @group param + */ + final val inputCols: Param[Array[String]] = new Param[Array[String]](this, "inputCols", "input column names") + + /** @group getParam */ + final def getInputCols: Array[String] = getOrDefault(inputCols) +} + +/** + * :: DeveloperApi :: + * Trait for shared param outputCol. + */ +@DeveloperApi +trait HasOutputCol extends Params { + + /** + * Param for output column name. 
+ * @group param + */ + final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name") + + /** @group getParam */ + final def getOutputCol: String = getOrDefault(outputCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param checkpointInterval. + */ +@DeveloperApi +trait HasCheckpointInterval extends Params { + + /** + * Param for checkpoint interval. + * @group param + */ + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval") + + /** @group getParam */ + final def getCheckpointInterval: Int = getOrDefault(checkpointInterval) +} + +/** + * :: DeveloperApi :: + * Trait for shared param fitIntercept (default: true). + */ +@DeveloperApi +trait HasFitIntercept extends Params { + + /** + * Param for whether to fit an intercept term. + * @group param + */ + final val fitIntercept: BooleanParam = new BooleanParam(this, "fitIntercept", "whether to fit an intercept term") + + setDefault(fitIntercept, true) + + /** @group getParam */ + final def getFitIntercept: Boolean = getOrDefault(fitIntercept) +} +// scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala deleted file mode 100644 index 5d660d1e151a7..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala +++ /dev/null @@ -1,151 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.param - -/* NOTE TO DEVELOPERS: - * If you mix these parameter traits into your algorithm, please add a setter method as well - * so that users may use a builder pattern: - * val myLearner = new MyLearner().setParam1(x).setParam2(y)... 
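The builder-pattern convention described in the note above carries over to the generated shared-param traits. A rough sketch (hypothetical `MyLearner`, not part of this patch) of mixing them into a new component:

```
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasMaxIter, HasRegParam}

// Hypothetical component mixing in the generated shared-param traits.
trait MyLearnerParams extends Params with HasMaxIter with HasRegParam

class MyLearner extends MyLearnerParams {

  setDefault(maxIter -> 100, regParam -> 0.1)

  /** @group setParam */
  def setMaxIter(value: Int): this.type = set(maxIter, value)

  /** @group setParam */
  def setRegParam(value: Double): this.type = set(regParam, value)
}

// Builder-style configuration, as in the note above:
// val myLearner = new MyLearner().setMaxIter(50).setRegParam(0.01)
```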
- */ - -private[ml] trait HasRegParam extends Params { - /** - * param for regularization parameter - * @group param - */ - val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") - - /** @group getParam */ - def getRegParam: Double = get(regParam) -} - -private[ml] trait HasMaxIter extends Params { - /** - * param for max number of iterations - * @group param - */ - val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") - - /** @group getParam */ - def getMaxIter: Int = get(maxIter) -} - -private[ml] trait HasFeaturesCol extends Params { - /** - * param for features column name - * @group param - */ - val featuresCol: Param[String] = - new Param(this, "featuresCol", "features column name", Some("features")) - - /** @group getParam */ - def getFeaturesCol: String = get(featuresCol) -} - -private[ml] trait HasLabelCol extends Params { - /** - * param for label column name - * @group param - */ - val labelCol: Param[String] = new Param(this, "labelCol", "label column name", Some("label")) - - /** @group getParam */ - def getLabelCol: String = get(labelCol) -} - -private[ml] trait HasPredictionCol extends Params { - /** - * param for prediction column name - * @group param - */ - val predictionCol: Param[String] = - new Param(this, "predictionCol", "prediction column name", Some("prediction")) - - /** @group getParam */ - def getPredictionCol: String = get(predictionCol) -} - -private[ml] trait HasRawPredictionCol extends Params { - /** - * param for raw prediction column name - * @group param - */ - val rawPredictionCol: Param[String] = - new Param(this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name", - Some("rawPrediction")) - - /** @group getParam */ - def getRawPredictionCol: String = get(rawPredictionCol) -} - -private[ml] trait HasProbabilityCol extends Params { - /** - * param for predicted class conditional probabilities column name - * @group param - */ - val probabilityCol: Param[String] = - new Param(this, "probabilityCol", "column name for predicted class conditional probabilities", - Some("probability")) - - /** @group getParam */ - def getProbabilityCol: String = get(probabilityCol) -} - -private[ml] trait HasThreshold extends Params { - /** - * param for threshold in (binary) prediction - * @group param - */ - val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction") - - /** @group getParam */ - def getThreshold: Double = get(threshold) -} - -private[ml] trait HasInputCol extends Params { - /** - * param for input column name - * @group param - */ - val inputCol: Param[String] = new Param(this, "inputCol", "input column name") - - /** @group getParam */ - def getInputCol: String = get(inputCol) -} - -private[ml] trait HasOutputCol extends Params { - /** - * param for output column name - * @group param - */ - val outputCol: Param[String] = new Param(this, "outputCol", "output column name") - - /** @group getParam */ - def getOutputCol: String = get(outputCol) -} - -private[ml] trait HasCheckpointInterval extends Params { - /** - * param for checkpoint interval - * @group param - */ - val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval") - - /** @group getParam */ - def getCheckpointInterval: Int = get(checkpointInterval) -} diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 514b4ef98dc5b..bd793beba35b6 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -34,6 +34,7 @@ import org.apache.spark.{Logging, Partitioner} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame @@ -54,86 +55,88 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * Param for rank of the matrix factorization. * @group param */ - val rank = new IntParam(this, "rank", "rank of the factorization", Some(10)) + val rank = new IntParam(this, "rank", "rank of the factorization") /** @group getParam */ - def getRank: Int = get(rank) + def getRank: Int = getOrDefault(rank) /** * Param for number of user blocks. * @group param */ - val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", Some(10)) + val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks") /** @group getParam */ - def getNumUserBlocks: Int = get(numUserBlocks) + def getNumUserBlocks: Int = getOrDefault(numUserBlocks) /** * Param for number of item blocks. * @group param */ val numItemBlocks = - new IntParam(this, "numItemBlocks", "number of item blocks", Some(10)) + new IntParam(this, "numItemBlocks", "number of item blocks") /** @group getParam */ - def getNumItemBlocks: Int = get(numItemBlocks) + def getNumItemBlocks: Int = getOrDefault(numItemBlocks) /** * Param to decide whether to use implicit preference. * @group param */ - val implicitPrefs = - new BooleanParam(this, "implicitPrefs", "whether to use implicit preference", Some(false)) + val implicitPrefs = new BooleanParam(this, "implicitPrefs", "whether to use implicit preference") /** @group getParam */ - def getImplicitPrefs: Boolean = get(implicitPrefs) + def getImplicitPrefs: Boolean = getOrDefault(implicitPrefs) /** * Param for the alpha parameter in the implicit preference formulation. * @group param */ - val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", Some(1.0)) + val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference") /** @group getParam */ - def getAlpha: Double = get(alpha) + def getAlpha: Double = getOrDefault(alpha) /** * Param for the column name for user ids. * @group param */ - val userCol = new Param[String](this, "userCol", "column name for user ids", Some("user")) + val userCol = new Param[String](this, "userCol", "column name for user ids") /** @group getParam */ - def getUserCol: String = get(userCol) + def getUserCol: String = getOrDefault(userCol) /** * Param for the column name for item ids. * @group param */ - val itemCol = - new Param[String](this, "itemCol", "column name for item ids", Some("item")) + val itemCol = new Param[String](this, "itemCol", "column name for item ids") /** @group getParam */ - def getItemCol: String = get(itemCol) + def getItemCol: String = getOrDefault(itemCol) /** * Param for the column name for ratings. * @group param */ - val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating")) + val ratingCol = new Param[String](this, "ratingCol", "column name for ratings") /** @group getParam */ - def getRatingCol: String = get(ratingCol) + def getRatingCol: String = getOrDefault(ratingCol) /** * Param for whether to apply nonnegativity constraints. 
* @group param */ val nonnegative = new BooleanParam( - this, "nonnegative", "whether to use nonnegative constraint for least squares", Some(false)) + this, "nonnegative", "whether to use nonnegative constraint for least squares") /** @group getParam */ - val getNonnegative: Boolean = get(nonnegative) + def getNonnegative: Boolean = getOrDefault(nonnegative) + + setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10, + implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item", + ratingCol -> "rating", nonnegative -> false) /** * Validates and transforms the input schema. @@ -142,7 +145,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * @return output schema */ protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) assert(schema(map(userCol)).dataType == IntegerType) assert(schema(map(itemCol)).dataType== IntegerType) val ratingType = schema(map(ratingCol)).dataType @@ -171,7 +174,7 @@ class ALSModel private[ml] ( override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { import dataset.sqlContext.implicits._ - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val users = userFactors.toDF("id", "features") val items = itemFactors.toDF("id", "features") @@ -283,7 +286,7 @@ class ALS extends Estimator[ALSModel] with ALSParams { setCheckpointInterval(10) override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val ratings = dataset .select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType)) .map { row => @@ -320,7 +323,7 @@ object ALS extends Logging { /** Trait for least squares solvers applied to the normal equation. */ private[recommendation] trait LeastSquaresNESolver extends Serializable { - /** Solves a least squares problem (possibly with other constraints). */ + /** Solves a least squares problem with regularization (possibly with other constraints). */ def solve(ne: NormalEquation, lambda: Double): Array[Float] } @@ -332,20 +335,19 @@ object ALS extends Logging { /** * Solves a least squares problem with L2 regularization: * - * min norm(A x - b)^2^ + lambda * n * norm(x)^2^ + * min norm(A x - b)^2^ + lambda * norm(x)^2^ * * @param ne a [[NormalEquation]] instance that contains AtA, Atb, and n (number of instances) - * @param lambda regularization constant, which will be scaled by n + * @param lambda regularization constant * @return the solution x */ override def solve(ne: NormalEquation, lambda: Double): Array[Float] = { val k = ne.k // Add scaled lambda to the diagonals of AtA. - val scaledlambda = lambda * ne.n var i = 0 var j = 2 while (i < ne.triK) { - ne.ata(i) += scaledlambda + ne.ata(i) += lambda i += j j += 1 } @@ -391,7 +393,7 @@ object ALS extends Logging { override def solve(ne: NormalEquation, lambda: Double): Array[Float] = { val rank = ne.k initialize(rank) - fillAtA(ne.ata, lambda * ne.n) + fillAtA(ne.ata, lambda) val x = NNLS.solve(ata, ne.atb, workspace) ne.reset() x.map(x => x.toFloat) @@ -420,7 +422,15 @@ object ALS extends Logging { } } - /** Representing a normal equation (ALS' subproblem). */ + /** + * Representing a normal equation to solve the following weighted least squares problem: + * + * minimize \sum,,i,, c,,i,, (a,,i,,^T^ x - b,,i,,)^2^ + lambda * x^T^ x. 
+ * + * Its normal equation is given by + * + * \sum,,i,, c,,i,, (a,,i,, a,,i,,^T^ x - b,,i,, a,,i,,) + lambda * x = 0. + */ private[recommendation] class NormalEquation(val k: Int) extends Serializable { /** Number of entries in the upper triangular part of a k-by-k matrix. */ @@ -429,8 +439,6 @@ object ALS extends Logging { val ata = new Array[Double](triK) /** A^T^ * b */ val atb = new Array[Double](k) - /** Number of observations. */ - var n = 0 private val da = new Array[Double](k) private val upper = "U" @@ -444,28 +452,13 @@ object ALS extends Logging { } /** Adds an observation. */ - def add(a: Array[Float], b: Float): this.type = { - require(a.length == k) - copyToDouble(a) - blas.dspr(upper, k, 1.0, da, 1, ata) - blas.daxpy(k, b.toDouble, da, 1, atb, 1) - n += 1 - this - } - - /** - * Adds an observation with implicit feedback. Note that this does not increment the counter. - */ - def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = { + def add(a: Array[Float], b: Double, c: Double = 1.0): this.type = { + require(c >= 0.0) require(a.length == k) - // Extension to the original paper to handle b < 0. confidence is a function of |b| instead - // so that it is never negative. - val confidence = 1.0 + alpha * math.abs(b) copyToDouble(a) - blas.dspr(upper, k, confidence - 1.0, da, 1, ata) - // For b <= 0, the corresponding preference is 0. So the term below is only added for b > 0. - if (b > 0) { - blas.daxpy(k, confidence, da, 1, atb, 1) + blas.dspr(upper, k, c, da, 1, ata) + if (b != 0.0) { + blas.daxpy(k, c * b, da, 1, atb, 1) } this } @@ -475,7 +468,6 @@ object ALS extends Logging { require(other.k == k) blas.daxpy(ata.length, 1.0, other.ata, 1, ata, 1) blas.daxpy(atb.length, 1.0, other.atb, 1, atb, 1) - n += other.n this } @@ -483,7 +475,6 @@ object ALS extends Logging { def reset(): Unit = { ju.Arrays.fill(ata, 0.0) ju.Arrays.fill(atb, 0.0) - n = 0 } } @@ -1114,6 +1105,7 @@ object ALS extends Logging { ls.merge(YtY.get) } var i = srcPtrs(j) + var numExplicits = 0 while (i < srcPtrs(j + 1)) { val encoded = srcEncodedIndices(i) val blockId = srcEncoder.blockId(encoded) @@ -1121,13 +1113,23 @@ object ALS extends Logging { val srcFactor = sortedSrcFactors(blockId)(localIndex) val rating = ratings(i) if (implicitPrefs) { - ls.addImplicit(srcFactor, rating, alpha) + // Extension to the original paper to handle b < 0. confidence is a function of |b| + // instead so that it is never negative. c1 is confidence - 1.0. + val c1 = alpha * math.abs(rating) + // For rating <= 0, the corresponding preference is 0. So the term below is only added + // for rating > 0. Because YtY is already added, we need to adjust the scaling here. + if (rating > 0) { + numExplicits += 1 + ls.add(srcFactor, (c1 + 1.0) / c1, c1) + } } else { ls.add(srcFactor, rating) + numExplicits += 1 } i += 1 } - dstFactors(j) = solver.solve(ls, regParam) + // Weight lambda by the number of explicit ratings based on the ALS-WR paper. 
+ dstFactors(j) = solver.solve(ls, numExplicits * regParam) j += 1 } dstFactors @@ -1141,7 +1143,7 @@ object ALS extends Logging { private def computeYtY(factorBlocks: RDD[(Int, FactorBlock)], rank: Int): NormalEquation = { factorBlocks.values.aggregate(new NormalEquation(rank))( seqOp = (ne, factors) => { - factors.foreach(ne.add(_, 0.0f)) + factors.foreach(ne.add(_, 0.0)) ne }, combOp = (ne1, ne2) => ne1.merge(ne2)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 65f6627a0c351..26ca7459c4fdf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -18,7 +18,8 @@ package org.apache.spark.ml.regression import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.param.{Params, ParamMap, HasMaxIter, HasRegParam} +import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{BLAS, Vector} import org.apache.spark.mllib.regression.LinearRegressionWithSGD import org.apache.spark.sql.DataFrame @@ -41,8 +42,7 @@ private[regression] trait LinearRegressionParams extends RegressorParams class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel] with LinearRegressionParams { - setRegParam(0.1) - setMaxIter(100) + setDefault(regParam -> 0.1, maxIter -> 100) /** @group setParam */ def setRegParam(value: Double): this.type = set(regParam, value) @@ -93,7 +93,7 @@ class LinearRegressionModel private[ml] ( override protected def copy(): LinearRegressionModel = { val m = new LinearRegressionModel(parent, fittingParamMap, weights, intercept) - Params.inheritValues(this.paramMap, this, m) + Params.inheritValues(extractParamMap(), this, m) m } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 2eb1dac56f1e9..4bb4ed813c006 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.types.StructType * Params for [[CrossValidator]] and [[CrossValidatorModel]]. 
*/ private[ml] trait CrossValidatorParams extends Params { + /** * param for the estimator to be cross-validated * @group param @@ -38,7 +39,7 @@ private[ml] trait CrossValidatorParams extends Params { val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection") /** @group getParam */ - def getEstimator: Estimator[_] = get(estimator) + def getEstimator: Estimator[_] = getOrDefault(estimator) /** * param for estimator param maps @@ -48,7 +49,7 @@ private[ml] trait CrossValidatorParams extends Params { new Param(this, "estimatorParamMaps", "param maps for the estimator") /** @group getParam */ - def getEstimatorParamMaps: Array[ParamMap] = get(estimatorParamMaps) + def getEstimatorParamMaps: Array[ParamMap] = getOrDefault(estimatorParamMaps) /** * param for the evaluator for selection @@ -57,17 +58,18 @@ private[ml] trait CrossValidatorParams extends Params { val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection") /** @group getParam */ - def getEvaluator: Evaluator = get(evaluator) + def getEvaluator: Evaluator = getOrDefault(evaluator) /** * param for number of folds for cross validation * @group param */ - val numFolds: IntParam = - new IntParam(this, "numFolds", "number of folds for cross validation", Some(3)) + val numFolds: IntParam = new IntParam(this, "numFolds", "number of folds for cross validation") /** @group getParam */ - def getNumFolds: Int = get(numFolds) + def getNumFolds: Int = getOrDefault(numFolds) + + setDefault(numFolds -> 3) } /** @@ -92,7 +94,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP def setNumFolds(value: Int): this.type = set(numFolds, value) override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val schema = dataset.schema transformSchema(dataset.schema, paramMap, logging = true) val sqlCtx = dataset.sqlContext @@ -130,7 +132,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) map(estimator).transformSchema(schema, paramMap) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala new file mode 100644 index 0000000000000..0383bf0b382b7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.util + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.types.{DataType, StructField, StructType} + +/** + * :: DeveloperApi :: + * Utils for handling schemas. + */ +@DeveloperApi +object SchemaUtils { + + // TODO: Move the utility methods to SQL. + + /** + * Check whether the given schema contains a column of the required data type. + * @param colName column name + * @param dataType required column data type + */ + def checkColumnType(schema: StructType, colName: String, dataType: DataType): Unit = { + val actualDataType = schema(colName).dataType + require(actualDataType.equals(dataType), + s"Column $colName must be of type $dataType but was actually $actualDataType.") + } + + /** + * Appends a new column to the input schema. This fails if the given output column already exists. + * @param schema input schema + * @param colName new column name. If this column name is an empty string "", this method returns + * the input schema unchanged. This allows users to disable output columns. + * @param dataType new column data type + * @return new schema with the input column appended + */ + def appendColumn( + schema: StructType, + colName: String, + dataType: DataType): StructType = { + if (colName.isEmpty) return schema + val fieldNames = schema.fieldNames + require(!fieldNames.contains(colName), s"Column $colName already exists.") + val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false) + StructType(outputFields) + } +} diff --git a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala similarity index 62% rename from core/src/main/scala/org/apache/spark/TaskContextHelper.scala rename to mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala index 4636c4600a01a..ee933f4cfcafd 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala @@ -15,15 +15,19 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.mllib.api.python + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel} +import org.apache.spark.rdd.RDD /** - * This class exists to restrict the visibility of TaskContext setters. + * A Wrapper of FPGrowthModel to provide helper method for Python */ -private [spark] object TaskContextHelper { - - def setTaskContext(tc: TaskContext): Unit = TaskContext.setTaskContext(tc) +private[python] class FPGrowthModelWrapper(model: FPGrowthModel[Any]) + extends FPGrowthModel(model.freqItemsets) { - def unset(): Unit = TaskContext.unset() - + def getFreqItemsets: RDD[Array[Any]] = { + SerDe.fromTuple2RDD(model.freqItemsets.map(x => (x.javaItems, x.freq))) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala new file mode 100644 index 0000000000000..ecd3b16598438 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.recommendation.{MatrixFactorizationModel, Rating} +import org.apache.spark.rdd.RDD + +/** + * A Wrapper of MatrixFactorizationModel to provide helper method for Python. + */ +private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorizationModel) + extends MatrixFactorizationModel(model.rank, model.userFeatures, model.productFeatures) { + + def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] = + predict(SerDe.asTupleRDD(userAndProducts.rdd)) + + def getUserFeatures: RDD[Array[Any]] = { + SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]]) + } + + def getProductFeatures: RDD[Array[Any]] = { + SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index e39156734794c..ab15f0f36a14b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -34,6 +34,7 @@ import org.apache.spark.api.python.SerDeUtil import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ import org.apache.spark.mllib.feature._ +import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.random.{RandomRDDs => RG} @@ -58,7 +59,6 @@ import org.apache.spark.util.Utils */ private[python] class PythonMLLibAPI extends Serializable { - /** * Loads and serializes labeled points saved with `RDD#saveAsTextFile`. 
* @param jsc Java SparkContext @@ -78,7 +78,13 @@ private[python] class PythonMLLibAPI extends Serializable { initialWeights: Vector): JList[Object] = { try { val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights) - List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava + if (model.isInstanceOf[LogisticRegressionModel]) { + val lrModel = model.asInstanceOf[LogisticRegressionModel] + List(lrModel.weights, lrModel.intercept, lrModel.numFeatures, lrModel.numClasses) + .map(_.asInstanceOf[Object]).asJava + } else { + List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava + } } finally { data.rdd.unpersist(blocking = false) } @@ -191,9 +197,11 @@ private[python] class PythonMLLibAPI extends Serializable { miniBatchFraction: Double, initialWeights: Vector, regType: String, - intercept: Boolean): JList[Object] = { + intercept: Boolean, + validateData: Boolean): JList[Object] = { val SVMAlg = new SVMWithSGD() SVMAlg.setIntercept(intercept) + .setValidateData(validateData) SVMAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -217,9 +225,11 @@ private[python] class PythonMLLibAPI extends Serializable { initialWeights: Vector, regParam: Double, regType: String, - intercept: Boolean): JList[Object] = { + intercept: Boolean, + validateData: Boolean): JList[Object] = { val LogRegAlg = new LogisticRegressionWithSGD() LogRegAlg.setIntercept(intercept) + .setValidateData(validateData) LogRegAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -243,9 +253,13 @@ private[python] class PythonMLLibAPI extends Serializable { regType: String, intercept: Boolean, corrections: Int, - tolerance: Double): JList[Object] = { + tolerance: Double, + validateData: Boolean, + numClasses: Int): JList[Object] = { val LogRegAlg = new LogisticRegressionWithLBFGS() LogRegAlg.setIntercept(intercept) + .setValidateData(validateData) + .setNumClasses(numClasses) LogRegAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -345,26 +359,7 @@ private[python] class PythonMLLibAPI extends Serializable { val model = new GaussianMixtureModel(weight, gaussians) model.predictSoft(data) } - - /** - * A Wrapper of MatrixFactorizationModel to provide helpfer method for Python - */ - private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorizationModel) - extends MatrixFactorizationModel(model.rank, model.userFeatures, model.productFeatures) { - - def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] = - predict(SerDe.asTupleRDD(userAndProducts.rdd)) - - def getUserFeatures: RDD[Array[Any]] = { - SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]]) - } - - def getProductFeatures: RDD[Array[Any]] = { - SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]]) - } - - } - + /** * Java stub for Python mllib ALS.train(). This stub returns a handle * to the Java object instead of the content of the Java object. Extra care @@ -424,6 +419,24 @@ private[python] class PythonMLLibAPI extends Serializable { new MatrixFactorizationModelWrapper(model) } + /** + * Java stub for Python mllib FPGrowth.train(). This stub returns a handle + * to the Java object instead of the content of the Java object. Extra care + * needs to be taken in the Python code to ensure it gets freed on exit; see + * the Py4J documentation. 
+ */ + def trainFPGrowthModel( + data: JavaRDD[java.lang.Iterable[Any]], + minSupport: Double, + numPartitions: Int): FPGrowthModel[Any] = { + val fpg = new FPGrowth() + .setMinSupport(minSupport) + .setNumPartitions(numPartitions) + + val model = fpg.run(data.rdd.map(_.asScala.toArray)) + new FPGrowthModelWrapper(model) + } + /** * Java stub for Normalizer.transform() */ @@ -437,9 +450,9 @@ private[python] class PythonMLLibAPI extends Serializable { def normalizeVector(p: Double, rdd: JavaRDD[Vector]): JavaRDD[Vector] = { new Normalizer(p).transform(rdd) } - + /** - * Java stub for IDF.fit(). This stub returns a + * Java stub for StandardScaler.fit(). This stub returns a * handle to the Java object instead of the content of the Java object. * Extra care needs to be taken in the Python code to ensure it gets freed on * exit; see the Py4J documentation. @@ -480,13 +493,15 @@ private[python] class PythonMLLibAPI extends Serializable { learningRate: Double, numPartitions: Int, numIterations: Int, - seed: Long): Word2VecModelWrapper = { + seed: Long, + minCount: Int): Word2VecModelWrapper = { val word2vec = new Word2Vec() .setVectorSize(vectorSize) .setLearningRate(learningRate) .setNumPartitions(numPartitions) .setNumIterations(numIterations) .setSeed(seed) + .setMinCount(minCount) try { val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)) new Word2VecModelWrapper(model) @@ -520,6 +535,10 @@ private[python] class PythonMLLibAPI extends Serializable { val words = result.map(_._1) List(words, similarity).map(_.asInstanceOf[Object]).asJava } + + def getVectors: JMap[String, JList[Float]] = { + model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava + } } /** @@ -1117,7 +1136,10 @@ private[spark] object SerDe extends Serializable { iter.flatMap { row => val obj = unpickle.loads(row) if (batched) { - obj.asInstanceOf[JArrayList[_]].asScala + obj match { + case list: JArrayList[_] => list.asScala + case arr: Array[_] => arr + } } else { Seq(obj) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index e7c3599ff619c..057b628c6a586 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -62,6 +62,15 @@ class LogisticRegressionModel ( s" but was given weights of length ${weights.size}") } + private val dataWithBiasSize: Int = weights.size / (numClasses - 1) + + private val weightsArray: Array[Double] = weights match { + case dv: DenseVector => dv.values + case _ => + throw new IllegalArgumentException( + s"weights only supports dense vector but got type ${weights.getClass}.") + } + /** * Constructs a [[LogisticRegressionModel]] with weights and intercept for binary classification. */ @@ -74,6 +83,7 @@ class LogisticRegressionModel ( * Sets the threshold that separates positive predictions from negative predictions * in Binary Logistic Regression. An example with prediction score greater than or equal to * this threshold is identified as an positive, and negative otherwise. The default value is 0.5. + * It is only used for binary classification. */ @Experimental def setThreshold(threshold: Double): this.type = { @@ -84,6 +94,7 @@ class LogisticRegressionModel ( /** * :: Experimental :: * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. 
+ * It is only used for binary classification. */ @Experimental def getThreshold: Option[Double] = threshold @@ -91,6 +102,7 @@ class LogisticRegressionModel ( /** * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. + * It is only used for binary classification. */ @Experimental def clearThreshold(): this.type = { @@ -106,7 +118,6 @@ class LogisticRegressionModel ( // If dataMatrix and weightMatrix have the same dimension, it's binary logistic regression. if (numClasses == 2) { - require(numFeatures == weightMatrix.size) val margin = dot(weightMatrix, dataMatrix) + intercept val score = 1.0 / (1.0 + math.exp(-margin)) threshold match { @@ -114,30 +125,9 @@ class LogisticRegressionModel ( case None => score } } else { - val dataWithBiasSize = weightMatrix.size / (numClasses - 1) - - val weightsArray = weightMatrix match { - case dv: DenseVector => dv.values - case _ => - throw new IllegalArgumentException( - s"weights only supports dense vector but got type ${weightMatrix.getClass}.") - } - - val margins = (0 until numClasses - 1).map { i => - var margin = 0.0 - dataMatrix.foreachActive { (index, value) => - if (value != 0.0) margin += value * weightsArray((i * dataWithBiasSize) + index) - } - // Intercept is required to be added into margin. - if (dataMatrix.size + 1 == dataWithBiasSize) { - margin += weightsArray((i * dataWithBiasSize) + dataMatrix.size) - } - margin - } - /** - * Find the one with maximum margins. If the maxMargin is negative, then the prediction - * result will be the first class. + * Compute and find the one with maximum margins. If the maxMargin is negative, then the + * prediction result will be the first class. * * PS, if you want to compute the probabilities for each outcome instead of the outcome * with maximum probability, remember to subtract the maxMargin from margins if maxMargin @@ -145,13 +135,20 @@ class LogisticRegressionModel ( */ var bestClass = 0 var maxMargin = 0.0 - var i = 0 - while(i < margins.size) { - if (margins(i) > maxMargin) { - maxMargin = margins(i) + val withBias = dataMatrix.size + 1 == dataWithBiasSize + (0 until numClasses - 1).foreach { i => + var margin = 0.0 + dataMatrix.foreachActive { (index, value) => + if (value != 0.0) margin += value * weightsArray((i * dataWithBiasSize) + index) + } + // Intercept is required to be added into margin. 
+ if (withBias) { + margin += weightsArray((i * dataWithBiasSize) + dataMatrix.size) + } + if (margin > maxMargin) { + maxMargin = margin bestClass = i + 1 } - i += 1 } bestClass.toDouble } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index d60e82c410979..c9b3ff0172e2e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -21,9 +21,12 @@ import java.lang.{Iterable => JIterable} import scala.collection.JavaConverters._ -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis} +import breeze.numerics.{exp => brzExp, log => brzLog} + import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.json4s.{DefaultFormats, JValue} import org.apache.spark.{Logging, SparkContext, SparkException} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} @@ -32,6 +35,7 @@ import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext} + /** * Model for Naive Bayes Classifiers. * @@ -39,11 +43,17 @@ import org.apache.spark.sql.{DataFrame, SQLContext} * @param pi log of class priors, whose dimension is C, number of labels * @param theta log of class conditional probabilities, whose dimension is C-by-D, * where D is number of features + * @param modelType The type of NB model to fit can be "Multinomial" or "Bernoulli" */ class NaiveBayesModel private[mllib] ( val labels: Array[Double], val pi: Array[Double], - val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Saveable { + val theta: Array[Array[Double]], + val modelType: String) + extends ClassificationModel with Serializable with Saveable { + + private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) = + this(labels, pi, theta, "Multinomial") /** A Java-friendly constructor that takes three Iterable parameters. */ private[mllib] def this( @@ -53,19 +63,19 @@ class NaiveBayesModel private[mllib] ( this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray)) private val brzPi = new BDV[Double](pi) - private val brzTheta = new BDM[Double](theta.length, theta(0).length) - - { - // Need to put an extra pair of braces to prevent Scala treating `i` as a member. - var i = 0 - while (i < theta.length) { - var j = 0 - while (j < theta(i).length) { - brzTheta(i, j) = theta(i)(j) - j += 1 - } - i += 1 - } + private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t + + // Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. + // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra + // application of this condition (in predict function). + private val (brzNegTheta, brzNegThetaSum) = modelType match { + case "Multinomial" => (None, None) + case "Bernoulli" => + val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x)) + (Option(negTheta), Option(brzSum(negTheta, Axis._1))) + case _ => + // This should never happen. 
+ throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") } override def predict(testData: RDD[Vector]): RDD[Double] = { @@ -77,22 +87,78 @@ class NaiveBayesModel private[mllib] ( } override def predict(testData: Vector): Double = { - labels(brzArgmax(brzPi + brzTheta * testData.toBreeze)) + modelType match { + case "Multinomial" => + labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) ) + case "Bernoulli" => + labels (brzArgmax (brzPi + + (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get)) + case _ => + // This should never happen. + throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") + } } override def save(sc: SparkContext, path: String): Unit = { - val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta) - NaiveBayesModel.SaveLoadV1_0.save(sc, path, data) + val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType) + NaiveBayesModel.SaveLoadV2_0.save(sc, path, data) } - override protected def formatVersion: String = "1.0" + override protected def formatVersion: String = "2.0" } object NaiveBayesModel extends Loader[NaiveBayesModel] { import org.apache.spark.mllib.util.Loader._ - private object SaveLoadV1_0 { + private[mllib] object SaveLoadV2_0 { + + def thisFormatVersion: String = "2.0" + + /** Hard-code class name string in case it changes in the future */ + def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel" + + /** Model data for model import/export */ + case class Data( + labels: Array[Double], + pi: Array[Double], + theta: Array[Array[Double]], + modelType: String) + + def save(sc: SparkContext, path: String, data: Data): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) + + // Create Parquet data. + val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() + dataRDD.saveAsParquetFile(dataPath(path)) + } + + def load(sc: SparkContext, path: String): NaiveBayesModel = { + val sqlContext = new SQLContext(sc) + // Load Parquet data. + val dataRDD = sqlContext.parquetFile(dataPath(path)) + // Check schema explicitly since erasure makes it hard to use match-case for checking. 
+ checkSchema[Data](dataRDD.schema) + val dataArray = dataRDD.select("labels", "pi", "theta", "modelType").take(1) + assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") + val data = dataArray(0) + val labels = data.getAs[Seq[Double]](0).toArray + val pi = data.getAs[Seq[Double]](1).toArray + val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray + val modelType = data.getString(3) + new NaiveBayesModel(labels, pi, theta, modelType) + } + + } + + private[mllib] object SaveLoadV1_0 { def thisFormatVersion: String = "1.0" @@ -100,7 +166,10 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel" /** Model data for model import/export */ - case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) + case class Data( + labels: Array[Double], + pi: Array[Double], + theta: Array[Array[Double]]) def save(sc: SparkContext, path: String, data: Data): Unit = { val sqlContext = new SQLContext(sc) @@ -136,26 +205,32 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { override def load(sc: SparkContext, path: String): NaiveBayesModel = { val (loadedClassName, version, metadata) = loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName - (loadedClassName, version) match { + val classNameV2_0 = SaveLoadV2_0.thisClassName + val (model, numFeatures, numClasses) = (loadedClassName, version) match { case (className, "1.0") if className == classNameV1_0 => val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) val model = SaveLoadV1_0.load(sc, path) - assert(model.pi.size == numClasses, - s"NaiveBayesModel.load expected $numClasses classes," + - s" but class priors vector pi had ${model.pi.size} elements") - assert(model.theta.size == numClasses, - s"NaiveBayesModel.load expected $numClasses classes," + - s" but class conditionals array theta had ${model.theta.size} elements") - assert(model.theta.forall(_.size == numFeatures), - s"NaiveBayesModel.load expected $numFeatures features," + - s" but class conditionals array theta had elements of size:" + - s" ${model.theta.map(_.size).mkString(",")}") - model + (model, numFeatures, numClasses) + case (className, "2.0") if className == classNameV2_0 => + val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) + val model = SaveLoadV2_0.load(sc, path) + (model, numFeatures, numClasses) case _ => throw new Exception( s"NaiveBayesModel.load did not recognize model with (className, format version):" + s"($loadedClassName, $version). Supported:\n" + s" ($classNameV1_0, 1.0)") } + assert(model.pi.size == numClasses, + s"NaiveBayesModel.load expected $numClasses classes," + + s" but class priors vector pi had ${model.pi.size} elements") + assert(model.theta.size == numClasses, + s"NaiveBayesModel.load expected $numClasses classes," + + s" but class conditionals array theta had ${model.theta.size} elements") + assert(model.theta.forall(_.size == numFeatures), + s"NaiveBayesModel.load expected $numFeatures features," + + s" but class conditionals array theta had elements of size:" + + s" ${model.theta.map(_.size).mkString(",")}") + model } } @@ -167,9 +242,14 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { * document classification. By making every vector a 0-1 vector, it can also be used as * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative. 
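For reference, the Bernoulli scoring used above is the standard rearrangement of the Bernoulli log-likelihood, with theta_{cj} = P(feature j present | class c) and x_j in {0, 1}:

```
\log P(y = c \mid x) \propto \log \pi_c + \sum_j \left[ x_j \log\theta_{cj} + (1 - x_j)\log(1 - \theta_{cj}) \right]
                     = \log \pi_c + \sum_j x_j \left( \log\theta_{cj} - \log(1 - \theta_{cj}) \right) + \sum_j \log(1 - \theta_{cj})
```

Since `theta` already stores log(theta), the model precomputes `brzNegTheta = log(1 - exp(theta))` and its row sums `brzNegThetaSum`, so Bernoulli prediction reduces to a single matrix-vector product, `brzPi + (brzTheta - brzNegTheta) * x + brzNegThetaSum`, followed by an argmax over classes.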
*/ -class NaiveBayes private (private var lambda: Double) extends Serializable with Logging { - def this() = this(1.0) +class NaiveBayes private ( + private var lambda: Double, + private var modelType: String) extends Serializable with Logging { + + def this(lambda: Double) = this(lambda, "Multinomial") + + def this() = this(1.0, "Multinomial") /** Set the smoothing parameter. Default: 1.0. */ def setLambda(lambda: Double): NaiveBayes = { @@ -177,9 +257,24 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with this } - /** Get the smoothing parameter. Default: 1.0. */ + /** Get the smoothing parameter. */ def getLambda: Double = lambda + /** + * Set the model type using a string (case-sensitive). + * Supported options: "Multinomial" and "Bernoulli". + * (default: Multinomial) + */ + def setModelType(modelType:String): NaiveBayes = { + require(NaiveBayes.supportedModelTypes.contains(modelType), + s"NaiveBayes was created with an unknown ModelType: $modelType") + this.modelType = modelType + this + } + + /** Get the model type. */ + def getModelType: String = this.modelType + /** * Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries. * @@ -213,21 +308,30 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) => (c1._1 + c2._1, c1._2 += c2._2) ).collect() + val numLabels = aggregated.length var numDocuments = 0L aggregated.foreach { case (_, (n, _)) => numDocuments += n } val numFeatures = aggregated.head match { case (_, (_, v)) => v.size } + val labels = new Array[Double](numLabels) val pi = new Array[Double](numLabels) val theta = Array.fill(numLabels)(new Array[Double](numFeatures)) + val piLogDenom = math.log(numDocuments + numLabels * lambda) var i = 0 aggregated.foreach { case (label, (n, sumTermFreqs)) => labels(i) = label - val thetaLogDenom = math.log(brzSum(sumTermFreqs) + numFeatures * lambda) pi(i) = math.log(n + lambda) - piLogDenom + val thetaLogDenom = modelType match { + case "Multinomial" => math.log(brzSum(sumTermFreqs) + numFeatures * lambda) + case "Bernoulli" => math.log(n + 2.0 * lambda) + case _ => + // This should never happen. + throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType") + } var j = 0 while (j < numFeatures) { theta(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom @@ -236,7 +340,7 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with i += 1 } - new NaiveBayesModel(labels, pi, theta) + new NaiveBayesModel(labels, pi, theta, modelType) } } @@ -244,13 +348,16 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with * Top-level methods for calling naive Bayes. */ object NaiveBayes { + + /* Set of modelTypes that NaiveBayes supports */ + private[mllib] val supportedModelTypes = Set("Multinomial", "Bernoulli") + /** * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. * - * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of - * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for - * document classification. By making every vector a 0-1 vector, it can also be used as - * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). + * This is the default Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all + * kinds of discrete data. 
For example, by converting documents into TF-IDF vectors, it + * can be used for document classification. * * This version of the method uses a default smoothing parameter of 1.0. * * @@ -264,16 +371,40 @@ object NaiveBayes { /** * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. * - * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of - * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for - * document classification. By making every vector a 0-1 vector, it can also be used as - * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). + * This is the default Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all + * kinds of discrete data. For example, by converting documents into TF-IDF vectors, it + * can be used for document classification. * * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency * vector or a count vector. * @param lambda The smoothing parameter */ def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = { - new NaiveBayes(lambda).run(input) + new NaiveBayes(lambda, "Multinomial").run(input) + } + + /** + * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. + * + * The model type can be set to either Multinomial NB ([[http://tinyurl.com/lsdw6p]]) + * or Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The Multinomial NB can handle + * discrete count data and can be selected by setting the model type to "Multinomial". + * For example, it can be used with word counts or TF-IDF vectors of documents. + * The Bernoulli model fits presence or absence (0-1) counts. By making every vector a + * 0-1 vector and setting the model type to "Bernoulli", the model fits and predicts as + * Bernoulli NB. + * + * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency * vector or a count vector. + * @param lambda The smoothing parameter + * + * @param modelType The type of NB model to fit, as a case-sensitive string: can be + * "Multinomial" or "Bernoulli" + */ + def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = { + require(supportedModelTypes.contains(modelType), + s"NaiveBayes was created with an unknown ModelType: $modelType") + new NaiveBayes(lambda, modelType).run(input) } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala index b89f38cf5aba4..7d33df3221fbf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala @@ -63,6 +63,8 @@ class StreamingLogisticRegressionWithSGD private[mllib] ( protected val algorithm = new LogisticRegressionWithSGD( stepSize, numIterations, regParam, miniBatchFraction) + protected var model: Option[LogisticRegressionModel] = None + /** Set the step size for gradient descent. Default: 0.1.
*/ def setStepSize(stepSize: Double): this.type = { this.algorithm.optimizer.setStepSize(stepSize) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 180023922a9b0..aa53e88d59856 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -17,15 +17,20 @@ package org.apache.spark.mllib.clustering -import org.apache.spark.{Logging, SparkException} +import org.json4s.JsonDSL._ +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.{Logging, SparkContext, SparkException} /** * :: Experimental :: @@ -38,7 +43,60 @@ import org.apache.spark.util.random.XORShiftRandom @Experimental class PowerIterationClusteringModel( val k: Int, - val assignments: RDD[PowerIterationClustering.Assignment]) extends Serializable + val assignments: RDD[PowerIterationClustering.Assignment]) extends Saveable with Serializable { + + override def save(sc: SparkContext, path: String): Unit = { + PowerIterationClusteringModel.SaveLoadV1_0.save(sc, this, path) + } + + override protected def formatVersion: String = "1.0" +} + +object PowerIterationClusteringModel extends Loader[PowerIterationClusteringModel] { + override def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { + PowerIterationClusteringModel.SaveLoadV1_0.load(sc, path) + } + + private[clustering] + object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + private[clustering] + val thisClassName = "org.apache.spark.mllib.clustering.PowerIterationClusteringModel" + + def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val dataRDD = model.assignments.toDF() + dataRDD.saveAsParquetFile(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { + implicit val formats = DefaultFormats + val sqlContext = new SQLContext(sc) + + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + + val k = (metadata \ "k").extract[Int] + val assignments = sqlContext.parquetFile(Loader.dataPath(path)) + Loader.checkSchema[PowerIterationClustering.Assignment](assignments.schema) + + val assignmentsRDD = assignments.map { + case Row(id: Long, cluster: Int) => PowerIterationClustering.Assignment(id, cluster) + } + + new PowerIterationClusteringModel(k, assignmentsRDD) + } + } +} /** * :: Experimental :: @@ -135,7 +193,7 @@ class PowerIterationClustering private[clustering] ( val v = powerIter(w, maxIterations) val assignments = kMeans(v, k).mapPartitions({ 
iter => iter.map { case (id, cluster) => - new Assignment(id, cluster) + Assignment(id, cluster) } }, preservesPartitioning = true) new PowerIterationClusteringModel(k, assignments) @@ -152,7 +210,7 @@ object PowerIterationClustering extends Logging { * @param cluster assigned cluster id */ @Experimental - class Assignment(val id: Long, val cluster: Int) extends Serializable + case class Assignment(id: Long, cluster: Int) /** * Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W). diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 59a79e5c6a4ac..b2d9053f70145 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -25,14 +25,21 @@ import scala.collection.mutable.ArrayBuilder import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.Logging +import org.apache.spark.SparkContext import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.sql.{SQLContext, Row} /** * Entry in vocabulary @@ -422,7 +429,7 @@ class Word2Vec extends Serializable with Logging { */ @Experimental class Word2VecModel private[mllib] ( - private val model: Map[String, Array[Float]]) extends Serializable { + private val model: Map[String, Array[Float]]) extends Serializable with Saveable { private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = { require(v1.length == v2.length, "Vectors should have the same length") @@ -432,7 +439,13 @@ class Word2VecModel private[mllib] ( if (norm1 == 0 || norm2 == 0) return 0.0 blas.sdot(n, v1, 1, v2,1) / norm1 / norm2 } - + + override protected def formatVersion = "1.0" + + def save(sc: SparkContext, path: String): Unit = { + Word2VecModel.SaveLoadV1_0.save(sc, path, model) + } + /** * Transforms a word to its vector representation * @param word a word @@ -475,7 +488,7 @@ class Word2VecModel private[mllib] ( .tail .toArray } - + /** * Returns a map of words to their vector representations. */ @@ -483,3 +496,71 @@ class Word2VecModel private[mllib] ( model } } + +@Experimental +object Word2VecModel extends Loader[Word2VecModel] { + + private object SaveLoadV1_0 { + + val formatVersionV1_0 = "1.0" + + val classNameV1_0 = "org.apache.spark.mllib.feature.Word2VecModel" + + case class Data(word: String, vector: Array[Float]) + + def load(sc: SparkContext, path: String): Word2VecModel = { + val dataPath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val dataFrame = sqlContext.parquetFile(dataPath) + + val dataArray = dataFrame.select("word", "vector").collect() + + // Check schema explicitly since erasure makes it hard to use match-case for checking. 
+ Loader.checkSchema[Data](dataFrame.schema) + + val word2VecMap = dataArray.map(i => (i.getString(0), i.getSeq[Float](1).toArray)).toMap + new Word2VecModel(word2VecMap) + } + + def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = { + + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val vectorSize = model.values.head.size + val numWords = model.size + val metadata = compact(render + (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ + ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val dataArray = model.toSeq.map { case (w, v) => Data(w, v) } + sc.parallelize(dataArray.toSeq, 1).toDF().saveAsParquetFile(Loader.dataPath(path)) + } + } + + override def load(sc: SparkContext, path: String): Word2VecModel = { + + val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path) + implicit val formats = DefaultFormats + val expectedVectorSize = (metadata \ "vectorSize").extract[Int] + val expectedNumWords = (metadata \ "numWords").extract[Int] + val classNameV1_0 = SaveLoadV1_0.classNameV1_0 + (loadedClassName, loadedVersion) match { + case (className, "1.0") if className == classNameV1_0 => + val model = SaveLoadV1_0.load(sc, path) + val vectorSize = model.getVectors.values.head.size + val numWords = model.getVectors.size + require(expectedVectorSize == vectorSize, + s"Word2VecModel requires each word to be mapped to a vector of size " + + s"$expectedVectorSize, got vector of size $vectorSize") + require(expectedNumWords == numWords, + s"Word2VecModel requires $expectedNumWords words, but got $numWords") + model + case _ => throw new Exception( + s"Word2VecModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $loadedVersion). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index d1a174063caba..3fa5e068d16d4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -87,6 +87,9 @@ sealed trait Matrix extends Serializable { /** A human readable representation of the matrix */ override def toString: String = toBreeze.toString() + /** A human readable representation of the matrix with maximum lines and width */ + def toString(maxLines: Int, maxLineWidth: Int): String = toBreeze.toString(maxLines, maxLineWidth) + /** Map the values of this matrix using a function. Generates a new matrix. Performs the * function on only the backing array. For example, an operation such as addition or * subtraction will only be performed on the non-zero values in a `SparseMatrix`. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 328dbe2ce11fa..4ef171f4f0419 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -227,7 +227,7 @@ object Vectors { * @param elements vector elements in (index, value) pairs.
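+ * For example (an illustrative sketch), the call below builds a vector of size 3 with entries
+ * at indices 0 and 2 and an implicit zero at index 1:
+ * {{{
+ *   val sv = Vectors.sparse(3, Seq((0, 1.0), (2, -2.0)))
+ * }}}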
*/ def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = { - require(size > 0) + require(size > 0, "The size of the requested sparse vector must be greater than 0.") val (indices, values) = elements.sortBy(_._1).unzip var prev = -1 @@ -235,7 +235,8 @@ object Vectors { require(prev < i, s"Found duplicate indices: $i.") prev = i } - require(prev < size) + require(prev < size, s"You may not write an element to index $prev because the declared " + + s"size of your vector is $size") new SparseVector(size, indices.toArray, values.toArray) } @@ -309,7 +310,8 @@ object Vectors { * @return norm in L^p^ space. */ def norm(vector: Vector, p: Double): Double = { - require(p >= 1.0) + require(p >= 1.0, "To compute the p-norm of the vector, we require that you specify a p>=1. " + + s"You specified p=$p.") val values = vector match { case DenseVector(vs) => vs case SparseVector(n, ids, vs) => vs @@ -360,7 +362,8 @@ object Vectors { * @return squared distance between two Vectors. */ def sqdist(v1: Vector, v2: Vector): Double = { - require(v1.size == v2.size, "vector dimension mismatch") + require(v1.size == v2.size, s"Vector dimensions do not match: Dim(v1)=${v1.size} and Dim(v2)" + + s"=${v2.size}.") var squaredDistance = 0.0 (v1, v2) match { case (v1: SparseVector, v2: SparseVector) => @@ -518,7 +521,9 @@ class SparseVector( val indices: Array[Int], val values: Array[Double]) extends Vector { - require(indices.length == values.length) + require(indices.length == values.length, "Sparse vectors require that the dimension of the" + + s" indices match the dimension of the values. You provided ${indices.size} indices and " + + s" ${values.size} values.") override def toString: String = "(%s,%s,%s)".format(size, indices.mkString("[", ",", "]"), values.mkString("[", ",", "]")) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 961111507f2c2..9a89a6f3a515f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -531,7 +531,6 @@ class RowMatrix( val rand = new XORShiftRandom(indx) val scaled = new Array[Double](p.size) iter.flatMap { row => - val buf = new ListBuffer[((Int, Int), Double)]() row match { case SparseVector(size, indices, values) => val nnz = indices.size @@ -540,8 +539,9 @@ class RowMatrix( scaled(k) = values(k) / q(indices(k)) k += 1 } - k = 0 - while (k < nnz) { + + Iterator.tabulate (nnz) { k => + val buf = new ListBuffer[((Int, Int), Double)]() val i = indices(k) val iVal = scaled(k) if (iVal != 0 && rand.nextDouble() < p(i)) { @@ -555,8 +555,8 @@ class RowMatrix( l += 1 } } - k += 1 - } + buf + }.flatten case DenseVector(values) => val n = values.size var i = 0 @@ -564,8 +564,8 @@ class RowMatrix( scaled(i) = values(i) / q(i) i += 1 } - i = 0 - while (i < n) { + Iterator.tabulate (n) { i => + val buf = new ListBuffer[((Int, Int), Double)]() val iVal = scaled(i) if (iVal != 0 && rand.nextDouble() < p(i)) { var j = i + 1 @@ -577,10 +577,9 @@ class RowMatrix( j += 1 } } - i += 1 - } + buf + }.flatten } - buf } }.reduceByKey(_ + _).map { case ((i, j), sim) => MatrixEntry(i.toLong, j.toLong, sim) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index ce95c063db970..cea8f3f47307b 100644 --- 
a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -60,7 +60,7 @@ abstract class StreamingLinearAlgorithm[ A <: GeneralizedLinearAlgorithm[M]] extends Logging { /** The model to be updated and used for prediction. */ - protected var model: Option[M] = None + protected var model: Option[M] /** The algorithm to use for updating. */ protected val algorithm: A @@ -114,7 +114,7 @@ abstract class StreamingLinearAlgorithm[ if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting prediction.") } - data.map(model.get.predict) + data.map{x => model.get.predict(x)} } /** Java-friendly version of `predictOn`. */ @@ -132,7 +132,7 @@ abstract class StreamingLinearAlgorithm[ if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting prediction") } - data.mapValues(model.get.predict) + data.mapValues{x => model.get.predict(x)} } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index e5e6301127a28..a49153bf73c0d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -59,6 +59,8 @@ class StreamingLinearRegressionWithSGD private[mllib] ( val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction) + protected var model: Option[LinearRegressionModel] = None + /** Set the step size for gradient descent. Default: 0.1. */ def setStepSize(stepSize: Double): this.type = { this.algorithm.optimizer.setStepSize(stepSize) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index a9c93e181e3ce..c02c79f094b66 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -157,7 +157,6 @@ object GradientBoostedTrees extends Logging { validationInput: RDD[LabeledPoint], boostingStrategy: BoostingStrategy, validate: Boolean): GradientBoostedTreesModel = { - val timer = new TimeTracker() timer.start("total") timer.start("init") @@ -192,20 +191,29 @@ object GradientBoostedTrees extends Logging { // Initialize tree timer.start("building tree 0") val firstTreeModel = new DecisionTree(treeStrategy).run(data) + val firstTreeWeight = 1.0 baseLearners(0) = firstTreeModel - baseLearnerWeights(0) = 1.0 - val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0)) - logDebug("error of gbt = " + loss.computeError(startingModel, input)) + baseLearnerWeights(0) = firstTreeWeight + val startingModel = new GradientBoostedTreesModel( + Regression, Array(firstTreeModel), baseLearnerWeights.slice(0, 1)) + + var predError: RDD[(Double, Double)] = GradientBoostedTreesModel. 
+ computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss) + logDebug("error of gbt = " + predError.values.mean()) // Note: A model of type regression is used since we require raw prediction timer.stop("building tree 0") - var bestValidateError = if (validate) loss.computeError(startingModel, validationInput) else 0.0 + var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel. + computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss) + var bestValidateError = if (validate) validatePredError.values.mean() else 0.0 var bestM = 1 - // psuedo-residual for second iteration - data = input.map(point => LabeledPoint(loss.gradient(startingModel, point), - point.features)) + // pseudo-residual for second iteration + data = predError.zip(input).map { case ((pred, _), point) => + LabeledPoint(-loss.gradient(pred, point.label), point.features) + } + var m = 1 while (m < numIterations) { timer.start(s"building tree $m") @@ -222,15 +230,22 @@ object GradientBoostedTrees extends Logging { baseLearnerWeights(m) = learningRate // Note: A model of type regression is used since we require raw prediction val partialModel = new GradientBoostedTreesModel( - Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1)) - logDebug("error of gbt = " + loss.computeError(partialModel, input)) + Regression, baseLearners.slice(0, m + 1), + baseLearnerWeights.slice(0, m + 1)) + + predError = GradientBoostedTreesModel.updatePredictionError( + input, predError, baseLearnerWeights(m), baseLearners(m), loss) + logDebug("error of gbt = " + predError.values.mean()) if (validate) { // Stop training early if // 1. Reduction in error is less than the validationTol or // 2. If the error increases, that is if the model is overfit. // We want the model returned corresponding to the best validation error. - val currentValidateError = loss.computeError(partialModel, validationInput) + + validatePredError = GradientBoostedTreesModel.updatePredictionError( + validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss) + val currentValidateError = validatePredError.values.mean() if (bestValidateError - currentValidateError < validationTol) { return new GradientBoostedTreesModel( boostingStrategy.treeStrategy.algo, @@ -242,8 +257,9 @@ object GradientBoostedTrees extends Logging { } } // Update data with pseudo-residuals - data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point), - point.features)) + data = predError.zip(input).map { case ((pred, _), point) => + LabeledPoint(-loss.gradient(pred, point.label), point.features) + } m += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala index 793dd664c5d5a..6f570b4e09c79 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala @@ -37,14 +37,12 @@ object AbsoluteError extends Loss { * Method to calculate the gradients for the gradient boosting calculation for least * absolute error calculation. * The gradient with respect to F(x) is: sign(F(x) - y) - * @param model Ensemble model - * @param point Instance of the training dataset + * @param prediction Predicted label. + * @param label True label. 
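+ * For example (illustrative numbers): with label = 3.0 and prediction = 2.5, the residual
+ * F(x) - y = -0.5 is negative, so the returned gradient is -1.0.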
* @return Loss gradient */ - override def gradient( - model: TreeEnsembleModel, - point: LabeledPoint): Double = { - if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0 + override def gradient(prediction: Double, label: Double): Double = { + if (label - prediction < 0) 1.0 else -1.0 } override def computeError(prediction: Double, label: Double): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index 51b1aed167b66..24ee9f3d51293 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -39,15 +39,12 @@ object LogLoss extends Loss { * Method to calculate the loss gradients for the gradient boosting calculation for binary * classification * The gradient with respect to F(x) is: - 4 y / (1 + exp(2 y F(x))) - * @param model Ensemble model - * @param point Instance of the training dataset + * @param prediction Predicted label. + * @param label True label. * @return Loss gradient */ - override def gradient( - model: TreeEnsembleModel, - point: LabeledPoint): Double = { - val prediction = model.predict(point.features) - - 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction)) + override def gradient(prediction: Double, label: Double): Double = { + - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction)) } override def computeError(prediction: Double, label: Double): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index 357869ff6b333..d3b82b752fa0d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -31,13 +31,11 @@ trait Loss extends Serializable { /** * Method to calculate the gradients for the gradient boosting calculation. - * @param model Model of the weak learner. - * @param point Instance of the training dataset. + * @param prediction Predicted label. + * @param label True label. * @return Loss gradient. */ - def gradient( - model: TreeEnsembleModel, - point: LabeledPoint): Double + def gradient(prediction: Double, label: Double): Double /** * Method to calculate error of the base learner for the gradient boosting calculation. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala index b990707ca4525..58857ae15e93e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala @@ -37,14 +37,12 @@ object SquaredError extends Loss { * Method to calculate the gradients for the gradient boosting calculation for least * squares error calculation. * The gradient with respect to F(x) is: - 2 (y - F(x)) - * @param model Ensemble model - * @param point Instance of the training dataset + * @param prediction Predicted label. + * @param label True label.
* @return Loss gradient */ - override def gradient( - model: TreeEnsembleModel, - point: LabeledPoint): Double = { - 2.0 * (model.predict(point.features) - point.label) + override def gradient(prediction: Double, label: Double): Double = { + 2.0 * (prediction - label) } override def computeError(prediction: Double, label: Double): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 1950254b2aa6d..fef3d2acb202a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -130,35 +130,28 @@ class GradientBoostedTreesModel( val numIterations = trees.length val evaluationArray = Array.fill(numIterations)(0.0) + val localTreeWeights = treeWeights + + var predictionAndError = GradientBoostedTreesModel.computeInitialPredictionAndError( + remappedData, localTreeWeights(0), trees(0), loss) - var predictionAndError: RDD[(Double, Double)] = remappedData.map { i => - val pred = treeWeights(0) * trees(0).predict(i.features) - val error = loss.computeError(pred, i.label) - (pred, error) - } evaluationArray(0) = predictionAndError.values.mean() - // Avoid the model being copied across numIterations. val broadcastTrees = sc.broadcast(trees) - val broadcastWeights = sc.broadcast(treeWeights) - (1 until numIterations).map { nTree => predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter => val currentTree = broadcastTrees.value(nTree) - val currentTreeWeight = broadcastWeights.value(nTree) - iter.map { - case (point, (pred, error)) => { - val newPred = pred + currentTree.predict(point.features) * currentTreeWeight - val newError = loss.computeError(newPred, point.label) - (newPred, newError) - } + val currentTreeWeight = localTreeWeights(nTree) + iter.map { case (point, (pred, error)) => + val newPred = pred + currentTree.predict(point.features) * currentTreeWeight + val newError = loss.computeError(newPred, point.label) + (newPred, newError) } } evaluationArray(nTree) = predictionAndError.values.mean() } broadcastTrees.unpersist() - broadcastWeights.unpersist() evaluationArray } @@ -166,6 +159,58 @@ class GradientBoostedTreesModel( object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { + /** + * Compute the initial predictions and errors for a dataset for the first + * iteration of gradient boosting. + * @param data: training data. + * @param initTreeWeight: learning rate assigned to the first tree. + * @param initTree: first DecisionTreeModel. + * @param loss: evaluation metric. + * @return a RDD with each element being a zip of the prediction and error + * corresponding to every sample. + */ + def computeInitialPredictionAndError( + data: RDD[LabeledPoint], + initTreeWeight: Double, + initTree: DecisionTreeModel, + loss: Loss): RDD[(Double, Double)] = { + data.map { lp => + val pred = initTreeWeight * initTree.predict(lp.features) + val error = loss.computeError(pred, lp.label) + (pred, error) + } + } + + /** + * Update a zipped predictionError RDD + * (as obtained with computeInitialPredictionAndError) + * @param data: training data. + * @param predictionAndError: predictionError RDD + * @param treeWeight: Learning rate. + * @param tree: Tree using which the prediction and error should be updated. + * @param loss: evaluation metric. 
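+ * A typical call pattern, mirroring how it is used in GradientBoostedTrees.boost (a sketch;
+ * `input`, `predError`, `baseLearnerWeights`, `baseLearners`, and `loss` come from the
+ * boosting loop above):
+ * {{{
+ *   predError = GradientBoostedTreesModel.updatePredictionError(
+ *     input, predError, baseLearnerWeights(m), baseLearners(m), loss)
+ * }}}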
+ * @return a RDD with each element being a zip of the prediction and error + * corresponding to each sample. + */ + def updatePredictionError( + data: RDD[LabeledPoint], + predictionAndError: RDD[(Double, Double)], + treeWeight: Double, + tree: DecisionTreeModel, + loss: Loss): RDD[(Double, Double)] = { + + val newPredError = data.zip(predictionAndError).mapPartitions { iter => + iter.map { + case (lp, (pred, error)) => { + val newPred = pred + tree.predict(lp.features) * treeWeight + val newError = loss.computeError(newPred, lp.label) + (newPred, newError) + } + } + } + newPredError + } + override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = { val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java new file mode 100644 index 0000000000000..161100134c92d --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + + +public class JavaVectorIndexerSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaVectorIndexerSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void vectorIndexerAPI() { + // The tests are to check Java compatibility. 
+ List points = Lists.newArrayList( + new FeatureData(Vectors.dense(0.0, -2.0)), + new FeatureData(Vectors.dense(1.0, 3.0)), + new FeatureData(Vectors.dense(1.0, 4.0)) + ); + SQLContext sqlContext = new SQLContext(sc); + DataFrame data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class); + VectorIndexer indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(2); + VectorIndexerModel model = indexer.fit(data); + Assert.assertEquals(model.numFeatures(), 2); + Assert.assertEquals(model.categoryMaps().size(), 1); + DataFrame indexedData = model.transform(data); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java index 1c90522a0714a..71fb7f13c39c2 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java @@ -17,20 +17,22 @@ package org.apache.spark.mllib.classification; +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import java.io.Serializable; -import java.util.Arrays; -import java.util.List; public class JavaNaiveBayesSuite implements Serializable { private transient JavaSparkContext sc; @@ -102,4 +104,11 @@ public Vector call(LabeledPoint v) throws Exception { // Should be able to get the first prediction. 
predictions.first(); } + + @Test + public void testModelTypeSetters() { + NaiveBayes nb = new NaiveBayes() + .setModelType("Bernoulli") + .setModelType("Multinomial"); + } } diff --git a/mllib/src/test/resources/log4j.properties b/mllib/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/mllib/src/test/resources/log4j.properties +++ b/mllib/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala index 3fb6e2ec46468..0dcfe5a2002dc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala @@ -43,8 +43,8 @@ class AttributeGroupSuite extends FunSuite { intercept[NoSuchElementException] { group("abc") } - assert(group === AttributeGroup.fromMetadata(group.toMetadata, group.name)) - assert(group === AttributeGroup.fromStructField(group.toStructField())) + assert(group === AttributeGroup.fromMetadata(group.toMetadataImpl, group.name)) + assert(group === AttributeGroup.fromStructField(group.toStructField)) } test("attribute group without attributes") { @@ -53,8 +53,8 @@ class AttributeGroupSuite extends FunSuite { assert(group0.numAttributes === Some(10)) assert(group0.size === 10) assert(group0.attributes.isEmpty) - assert(group0 === AttributeGroup.fromMetadata(group0.toMetadata, group0.name)) - assert(group0 === AttributeGroup.fromStructField(group0.toStructField())) + assert(group0 === AttributeGroup.fromMetadata(group0.toMetadataImpl, group0.name)) + assert(group0 === AttributeGroup.fromStructField(group0.toStructField)) val group1 = new AttributeGroup("item") assert(group1.name === "item") diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index b3d1bfcfbee0f..35d8c2e16c6cd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -46,6 +46,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(lr.getPredictionCol == "prediction") assert(lr.getRawPredictionCol == "rawPrediction") assert(lr.getProbabilityCol == "probability") + assert(lr.getFitIntercept == true) val model = lr.fit(dataset) model.transform(dataset) .select("label", "probability", "prediction", "rawPrediction") @@ -55,6 +56,14 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(model.getPredictionCol == "prediction") assert(model.getRawPredictionCol == "rawPrediction") assert(model.getProbabilityCol == "probability") + assert(model.intercept !== 0.0) + } + + test("logistic regression doesn't fit intercept when fitIntercept is off") { + val lr = new LogisticRegression + lr.setFitIntercept(false) + val model = lr.fit(dataset) + assert(model.intercept === 0.0) } test("logistic regression with setters") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala new file mode 100644 index 0000000000000..9d09f24709e23 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + + +class NormalizerSuite extends FunSuite with MLlibTestSparkContext { + + @transient var data: Array[Vector] = _ + @transient var dataFrame: DataFrame = _ + @transient var normalizer: Normalizer = _ + @transient var l1Normalized: Array[Vector] = _ + @transient var l2Normalized: Array[Vector] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.6, -1.1, -3.0), + Vectors.sparse(3, Seq((1, 0.91), (2, 3.2))), + Vectors.sparse(3, Seq((0, 5.7), (1, 0.72), (2, 2.7))), + Vectors.sparse(3, Seq()) + ) + l1Normalized = Array( + Vectors.sparse(3, Seq((0, -0.465116279), (1, 0.53488372))), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.12765957, -0.23404255, -0.63829787), + Vectors.sparse(3, Seq((1, 0.22141119), (2, 0.7785888))), + Vectors.dense(0.625, 0.07894737, 0.29605263), + Vectors.sparse(3, Seq()) + ) + l2Normalized = Array( + Vectors.sparse(3, Seq((0, -0.65617871), (1, 0.75460552))), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.184549876, -0.3383414, -0.922749378), + Vectors.sparse(3, Seq((1, 0.27352993), (2, 0.96186349))), + Vectors.dense(0.897906166, 0.113419726, 0.42532397), + Vectors.sparse(3, Seq()) + ) + + val sqlContext = new SQLContext(sc) + dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData)) + normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normalized_features") + } + + def collectResult(result: DataFrame): Array[Vector] = { + result.select("normalized_features").collect().map { + case Row(features: Vector) => features + } + } + + def assertTypeOfVector(lhs: Array[Vector], rhs: Array[Vector]): Unit = { + assert((lhs, rhs).zipped.forall { + case (v1: DenseVector, v2: DenseVector) => true + case (v1: SparseVector, v2: SparseVector) => true + case _ => false + }, "The vector type should be preserved after normalization.") + } + + def assertValues(lhs: Array[Vector], rhs: Array[Vector]): Unit = { + assert((lhs, rhs).zipped.forall { (vector1, vector2) => + vector1 ~== vector2 absTol 1E-5 + }, "The vector value is not correct after 
normalization.") + } + + test("Normalization with default parameter") { + val result = collectResult(normalizer.transform(dataFrame)) + + assertTypeOfVector(data, result) + + assertValues(result, l2Normalized) + } + + test("Normalization with setter") { + normalizer.setP(1) + + val result = collectResult(normalizer.transform(dataFrame)) + + assertTypeOfVector(data, result) + + assertValues(result, l1Normalized) + } +} + +private object NormalizerSuite { + case class FeatureData(features: Vector) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala new file mode 100644 index 0000000000000..00b5d094d82f1 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.SQLContext + +class StringIndexerSuite extends FunSuite with MLlibTestSparkContext { + private var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + test("StringIndexer") { + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + val transformed = indexer.transform(df) + val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("a", "c", "b")) + val output = transformed.select("id", "labelIndex").map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 0, b -> 2, c -> 1 + val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) + assert(output === expected) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index bf862b912d326..d186ead8f542f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -25,10 +25,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row, SQLContext} @BeanInfo -case class TokenizerTestData(rawText: String, wantedTokens: Seq[String]) { - /** Constructor used in [[org.apache.spark.ml.feature.JavaTokenizerSuite]] */ - def this(rawText: String, wantedTokens: Array[String]) 
= this(rawText, wantedTokens.toSeq) -} +case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext { import org.apache.spark.ml.feature.RegexTokenizerSuite._ @@ -46,14 +43,14 @@ class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext { .setOutputCol("tokens") val dataset0 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization", ".")), - TokenizerTestData("Te,st. punct", Seq("Te", ",", "st", ".", "punct")) + TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization", ".")), + TokenizerTestData("Te,st. punct", Array("Te", ",", "st", ".", "punct")) )) testRegexTokenizer(tokenizer, dataset0) val dataset1 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization")), - TokenizerTestData("Te,st. punct", Seq("punct")) + TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization")), + TokenizerTestData("Te,st. punct", Array("punct")) )) tokenizer.setMinTokenLength(3) @@ -64,8 +61,8 @@ class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext { .setGaps(true) .setMinTokenLength(0) val dataset2 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization.")), - TokenizerTestData("Te,st. punct", Seq("Te,st.", "", "punct")) + TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization.")), + TokenizerTestData("Te,st. punct", Array("Te,st.", "", "punct")) )) testRegexTokenizer(tokenizer, dataset2) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala new file mode 100644 index 0000000000000..57d0278e03639 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.SparkException +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{Row, SQLContext} + +class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext { + + @transient var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + test("assemble") { + import org.apache.spark.ml.feature.VectorAssembler.assemble + assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty)) + assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0))) + val dv = Vectors.dense(2.0, 0.0) + assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0))) + val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0)) + assert(assemble(0.0, dv, 1.0, sv) === + Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0))) + for (v <- Seq(1, "a", null)) { + intercept[SparkException](assemble(v)) + intercept[SparkException](assemble(1.0, v)) + } + } + + test("VectorAssembler") { + val df = sqlContext.createDataFrame(Seq( + (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L) + )).toDF("id", "x", "y", "name", "z", "n") + val assembler = new VectorAssembler() + .setInputCols(Array("x", "y", "z", "n")) + .setOutputCol("features") + assembler.transform(df).select("features").collect().foreach { + case Row(v: Vector) => + assert(v === Vectors.sparse(6, Array(1, 2, 4, 5), Array(1.0, 2.0, 3.0, 10.0))) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala new file mode 100644 index 0000000000000..81ef831c42e55 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import scala.beans.{BeanInfo, BeanProperty} + +import org.scalatest.FunSuite + +import org.apache.spark.SparkException +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.util.TestingUtils +import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, SQLContext} + + +class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext { + + import VectorIndexerSuite.FeatureData + + @transient var sqlContext: SQLContext = _ + + // identical, of length 3 + @transient var densePoints1: DataFrame = _ + @transient var sparsePoints1: DataFrame = _ + @transient var point1maxes: Array[Double] = _ + + // identical, of length 2 + @transient var densePoints2: DataFrame = _ + @transient var sparsePoints2: DataFrame = _ + + // different lengths + @transient var badPoints: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + val densePoints1Seq = Seq( + Vectors.dense(1.0, 2.0, 0.0), + Vectors.dense(0.0, 1.0, 2.0), + Vectors.dense(0.0, 0.0, -1.0), + Vectors.dense(1.0, 3.0, 2.0)) + val sparsePoints1Seq = Seq( + Vectors.sparse(3, Array(0, 1), Array(1.0, 2.0)), + Vectors.sparse(3, Array(1, 2), Array(1.0, 2.0)), + Vectors.sparse(3, Array(2), Array(-1.0)), + Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 3.0, 2.0))) + point1maxes = Array(1.0, 3.0, 2.0) + + val densePoints2Seq = Seq( + Vectors.dense(1.0, 1.0, 0.0, 1.0), + Vectors.dense(0.0, 1.0, 1.0, 1.0), + Vectors.dense(-1.0, 1.0, 2.0, 0.0)) + val sparsePoints2Seq = Seq( + Vectors.sparse(4, Array(0, 1, 3), Array(1.0, 1.0, 1.0)), + Vectors.sparse(4, Array(1, 2, 3), Array(1.0, 1.0, 1.0)), + Vectors.sparse(4, Array(0, 1, 2), Array(-1.0, 1.0, 2.0))) + + val badPointsSeq = Seq( + Vectors.sparse(2, Array(0, 1), Array(1.0, 1.0)), + Vectors.sparse(3, Array(2), Array(-1.0))) + + // Sanity checks for assumptions made in tests + assert(densePoints1Seq.head.size == sparsePoints1Seq.head.size) + assert(densePoints2Seq.head.size == sparsePoints2Seq.head.size) + assert(densePoints1Seq.head.size != densePoints2Seq.head.size) + def checkPair(dvSeq: Seq[Vector], svSeq: Seq[Vector]): Unit = { + assert(dvSeq.zip(svSeq).forall { case (dv, sv) => dv.toArray === sv.toArray }, + "typo in unit test") + } + checkPair(densePoints1Seq, sparsePoints1Seq) + checkPair(densePoints2Seq, sparsePoints2Seq) + + sqlContext = new SQLContext(sc) + densePoints1 = sqlContext.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData)) + sparsePoints1 = sqlContext.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData)) + densePoints2 = sqlContext.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData)) + sparsePoints2 = sqlContext.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData)) + badPoints = sqlContext.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData)) + } + + private def getIndexer: VectorIndexer = + new VectorIndexer().setInputCol("features").setOutputCol("indexed") + + test("Cannot fit an empty DataFrame") { + val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData)) + val vectorIndexer = getIndexer + intercept[IllegalArgumentException] { + vectorIndexer.fit(rdd) + } + } + + test("Throws error when given RDDs with different size vectors") { + val vectorIndexer = getIndexer + val model = vectorIndexer.fit(densePoints1) // vectors of length 3 + 
model.transform(densePoints1) // should work + model.transform(sparsePoints1) // should work + intercept[IllegalArgumentException] { + model.transform(densePoints2) + println("Did not throw an error when transform() was called on vectors of a different length") + } + intercept[SparkException] { + vectorIndexer.fit(badPoints) + println("Did not throw an error when fitting vectors of different lengths in the same RDD.") + } + } + + test("Same result with dense and sparse vectors") { + def testDenseSparse(densePoints: DataFrame, sparsePoints: DataFrame): Unit = { + val denseVectorIndexer = getIndexer.setMaxCategories(2) + val sparseVectorIndexer = getIndexer.setMaxCategories(2) + val denseModel = denseVectorIndexer.fit(densePoints) + val sparseModel = sparseVectorIndexer.fit(sparsePoints) + val denseMap = denseModel.categoryMaps + val sparseMap = sparseModel.categoryMaps + assert(denseMap.keys.toSet == sparseMap.keys.toSet, + "Categorical features chosen from dense vs. sparse vectors did not match.") + assert(denseMap == sparseMap, + "Categorical feature value indexes chosen from dense vs. sparse vectors did not match.") + } + testDenseSparse(densePoints1, sparsePoints1) + testDenseSparse(densePoints2, sparsePoints2) + } + + test("Builds valid categorical feature value index, transforms correctly, checks metadata") { + def checkCategoryMaps( + data: DataFrame, + maxCategories: Int, + categoricalFeatures: Set[Int]): Unit = { + val collectedData = data.collect().map(_.getAs[Vector](0)) + val errMsg = s"checkCategoryMaps failed for input with maxCategories=$maxCategories," + + s" categoricalFeatures=${categoricalFeatures.mkString(", ")}" + try { + val vectorIndexer = getIndexer.setMaxCategories(maxCategories) + val model = vectorIndexer.fit(data) + val categoryMaps = model.categoryMaps + // Chose correct categorical features + assert(categoryMaps.keys.toSet === categoricalFeatures) + val transformed = model.transform(data).select("indexed") + val indexedRDD: RDD[Vector] = transformed.map(_.getAs[Vector](0)) + val featureAttrs = AttributeGroup.fromStructField(transformed.schema("indexed")) + assert(featureAttrs.name === "indexed") + assert(featureAttrs.attributes.get.length === model.numFeatures) + categoricalFeatures.foreach { feature: Int => + val origValueSet = collectedData.map(_(feature)).toSet + val targetValueIndexSet = Range(0, origValueSet.size).toSet + val catMap = categoryMaps(feature) + assert(catMap.keys.toSet === origValueSet) // Correct categories + assert(catMap.values.toSet === targetValueIndexSet) // Correct category indices + if (origValueSet.contains(0.0)) { + assert(catMap(0.0) === 0) // value 0 gets index 0 + } + // Check transformed data + assert(indexedRDD.map(_(feature)).collect().toSet === targetValueIndexSet) + // Check metadata + val featureAttr = featureAttrs(feature) + assert(featureAttr.index.get === feature) + featureAttr match { + case attr: BinaryAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + case attr: NominalAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + assert(attr.isOrdinal.get === false) + case _ => + throw new RuntimeException(errMsg + s". Categorical feature $feature failed" + + s" metadata check. Found feature attribute: $featureAttr.") + } + } + // Check numerical feature metadata.
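+ // A quick sketch grounded in densePoints1 above: feature 0 takes only the values {0.0, 1.0}
+ // and feature 2 takes {-1.0, 0.0, 2.0}, so maxCategories = 2 should select only feature 0 as
+ // categorical (with 0.0 mapped to index 0), leaving the remaining features numeric; the loop
+ // below verifies that those numeric features carry NumericAttribute metadata.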
+ Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature)) + .foreach { feature: Int => + val featureAttr = featureAttrs(feature) + featureAttr match { + case attr: NumericAttribute => + assert(featureAttr.index.get === feature) + case _ => + throw new RuntimeException(errMsg + s". Numerical feature $feature failed" + + s" metadata check. Found feature attribute: $featureAttr.") + } + } + } catch { + case e: org.scalatest.exceptions.TestFailedException => + println(errMsg) + throw e + } + } + checkCategoryMaps(densePoints1, maxCategories = 2, categoricalFeatures = Set(0)) + checkCategoryMaps(densePoints1, maxCategories = 3, categoricalFeatures = Set(0, 2)) + checkCategoryMaps(densePoints2, maxCategories = 2, categoricalFeatures = Set(1, 3)) + } + + test("Maintain sparsity for sparse vectors") { + def checkSparsity(data: DataFrame, maxCategories: Int): Unit = { + val points = data.collect().map(_.getAs[Vector](0)) + val vectorIndexer = getIndexer.setMaxCategories(maxCategories) + val model = vectorIndexer.fit(data) + val indexedPoints = model.transform(data).select("indexed").map(_.getAs[Vector](0)).collect() + points.zip(indexedPoints).foreach { + case (orig: SparseVector, indexed: SparseVector) => + assert(orig.indices.length == indexed.indices.length) + case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen + } + } + checkSparsity(sparsePoints1, maxCategories = 2) + checkSparsity(sparsePoints2, maxCategories = 2) + } + + test("Preserve metadata") { + // For continuous features, preserve name and stats. + val featureAttributes: Array[Attribute] = point1maxes.zipWithIndex.map { case (maxVal, i) => + NumericAttribute.defaultAttr.withName(i.toString).withMax(maxVal) + } + val attrGroup = new AttributeGroup("features", featureAttributes) + val densePoints1WithMeta = + densePoints1.select(densePoints1("features").as("features", attrGroup.toMetadata)) + val vectorIndexer = getIndexer.setMaxCategories(2) + val model = vectorIndexer.fit(densePoints1WithMeta) + // Check that ML metadata are preserved. + val indexedPoints = model.transform(densePoints1WithMeta) + val transAttributes: Array[Attribute] = + AttributeGroup.fromStructField(indexedPoints.schema("indexed")).attributes.get + featureAttributes.zip(transAttributes).foreach { case (orig, trans) => + assert(orig.name === trans.name) + (orig, trans) match { + case (orig: NumericAttribute, trans: NumericAttribute) => + assert(orig.max.nonEmpty && orig.max === trans.max) + case _ => + // do nothing + // TODO: Once input features marked as categorical are handled correctly, check that here. + } + } + // Check that non-ML metadata are preserved. 
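+ // (testPreserveMetadata, defined in org.apache.spark.ml.util.TestingUtils later in this patch,
+ // attaches a fake metadata key to the input column and verifies that the key survives transform.)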
+ TestingUtils.testPreserveMetadata(densePoints1WithMeta, model, "features", "indexed") + } +} + +private[feature] object VectorIndexerSuite { + @BeanInfo + case class FeatureData(@BeanProperty features: Vector) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 1ce2987612378..88ea679eeaad5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -21,19 +21,25 @@ import org.scalatest.FunSuite class ParamsSuite extends FunSuite { - val solver = new TestParams() - import solver.{inputCol, maxIter} - test("param") { + val solver = new TestParams() + import solver.{maxIter, inputCol} + assert(maxIter.name === "maxIter") assert(maxIter.doc === "max number of iterations") - assert(maxIter.defaultValue.get === 100) assert(maxIter.parent.eq(solver)) - assert(maxIter.toString === "maxIter: max number of iterations (default: 100)") - assert(inputCol.defaultValue === None) + assert(maxIter.toString === "maxIter: max number of iterations (default: 10)") + + solver.setMaxIter(5) + assert(maxIter.toString === "maxIter: max number of iterations (default: 10, current: 5)") + + assert(inputCol.toString === "inputCol: input column name (undefined)") } test("param pair") { + val solver = new TestParams() + import solver.maxIter + val pair0 = maxIter -> 5 val pair1 = maxIter.w(5) val pair2 = ParamPair(maxIter, 5) @@ -44,10 +50,12 @@ class ParamsSuite extends FunSuite { } test("param map") { + val solver = new TestParams() + import solver.{maxIter, inputCol} + val map0 = ParamMap.empty assert(!map0.contains(maxIter)) - assert(map0(maxIter) === maxIter.defaultValue.get) map0.put(maxIter, 10) assert(map0.contains(maxIter)) assert(map0(maxIter) === 10) @@ -78,23 +86,39 @@ class ParamsSuite extends FunSuite { } test("params") { + val solver = new TestParams() + import solver.{maxIter, inputCol} + val params = solver.params - assert(params.size === 2) + assert(params.length === 2) assert(params(0).eq(inputCol), "params must be ordered by name") assert(params(1).eq(maxIter)) + + assert(!solver.isSet(maxIter)) + assert(solver.isDefined(maxIter)) + assert(solver.getMaxIter === 10) + solver.setMaxIter(100) + assert(solver.isSet(maxIter)) + assert(solver.getMaxIter === 100) + assert(!solver.isSet(inputCol)) + assert(!solver.isDefined(inputCol)) + intercept[NoSuchElementException](solver.getInputCol) + assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n")) + assert(solver.getParam("inputCol").eq(inputCol)) assert(solver.getParam("maxIter").eq(maxIter)) - intercept[NoSuchMethodException] { + intercept[NoSuchElementException] { solver.getParam("abc") } - assert(!solver.isSet(inputCol)) + intercept[IllegalArgumentException] { solver.validate() } solver.validate(ParamMap(inputCol -> "input")) solver.setInputCol("input") assert(solver.isSet(inputCol)) + assert(solver.isDefined(inputCol)) assert(solver.getInputCol === "input") solver.validate() intercept[IllegalArgumentException] { @@ -104,5 +128,8 @@ class ParamsSuite extends FunSuite { intercept[IllegalArgumentException] { solver.validate() } + + solver.clearMaxIter() + assert(!solver.isSet(maxIter)) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index 1a65883d78a71..641b64b42a5e7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ 
b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -17,20 +17,21 @@ package org.apache.spark.ml.param +import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter} + /** A subclass of Params for testing. */ -class TestParams extends Params { +class TestParams extends Params with HasMaxIter with HasInputCol { - val maxIter = new IntParam(this, "maxIter", "max number of iterations", Some(100)) def setMaxIter(value: Int): this.type = { set(maxIter, value); this } - def getMaxIter: Int = get(maxIter) - - val inputCol = new Param[String](this, "inputCol", "input column name") def setInputCol(value: String): this.type = { set(inputCol, value); this } - def getInputCol: String = get(inputCol) - override def validate(paramMap: ParamMap) = { - val m = this.paramMap ++ paramMap + setDefault(maxIter -> 10) + + override def validate(paramMap: ParamMap): Unit = { + val m = extractParamMap(paramMap) require(m(maxIter) >= 0) require(m.contains(inputCol)) } + + def clearMaxIter(): this.type = clear(maxIter) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 0bb06e9e8ac9c..fc7349330cf86 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -22,6 +22,7 @@ import java.util.Random import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.scalatest.FunSuite @@ -68,39 +69,42 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { } } - test("normal equation construction with explict feedback") { + test("normal equation construction") { val k = 2 val ne0 = new NormalEquation(k) - .add(Array(1.0f, 2.0f), 3.0f) - .add(Array(4.0f, 5.0f), 6.0f) + .add(Array(1.0f, 2.0f), 3.0) + .add(Array(4.0f, 5.0f), 6.0, 2.0) // weighted assert(ne0.k === k) assert(ne0.triK === k * (k + 1) / 2) - assert(ne0.n === 2) // NumPy code that computes the expected values: // A = np.matrix("1 2; 4 5") // b = np.matrix("3; 6") - // ata = A.transpose() * A - // atb = A.transpose() * b - assert(Vectors.dense(ne0.ata) ~== Vectors.dense(17.0, 22.0, 29.0) relTol 1e-8) - assert(Vectors.dense(ne0.atb) ~== Vectors.dense(27.0, 36.0) relTol 1e-8) + // C = np.matrix(np.diag([1, 2])) + // ata = A.transpose() * C * A + // atb = A.transpose() * C * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(33.0, 42.0, 54.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(51.0, 66.0) relTol 1e-8) val ne1 = new NormalEquation(2) - .add(Array(7.0f, 8.0f), 9.0f) + .add(Array(7.0f, 8.0f), 9.0) ne0.merge(ne1) - assert(ne0.n === 3) // NumPy code that computes the expected values: // A = np.matrix("1 2; 4 5; 7 8") // b = np.matrix("3; 6; 9") - // ata = A.transpose() * A - // atb = A.transpose() * b - assert(Vectors.dense(ne0.ata) ~== Vectors.dense(66.0, 78.0, 93.0) relTol 1e-8) - assert(Vectors.dense(ne0.atb) ~== Vectors.dense(90.0, 108.0) relTol 1e-8) + // C = np.matrix(np.diag([1, 2, 1])) + // ata = A.transpose() * C * A + // atb = A.transpose() * C * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(82.0, 98.0, 118.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(114.0, 138.0) relTol 1e-8) intercept[IllegalArgumentException] { - ne0.add(Array(1.0f), 2.0f) + ne0.add(Array(1.0f), 2.0) } intercept[IllegalArgumentException] { - ne0.add(Array(1.0f, 2.0f, 
3.0f), 4.0f) + ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0) + } + intercept[IllegalArgumentException] { + ne0.add(Array(1.0f, 2.0f), 0.0, -1.0) } intercept[IllegalArgumentException] { val ne2 = new NormalEquation(3) @@ -108,41 +112,16 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { } ne0.reset() - assert(ne0.n === 0) assert(ne0.ata.forall(_ == 0.0)) assert(ne0.atb.forall(_ == 0.0)) } - test("normal equation construction with implicit feedback") { - val k = 2 - val alpha = 0.5 - val ne0 = new NormalEquation(k) - .addImplicit(Array(-5.0f, -4.0f), -3.0f, alpha) - .addImplicit(Array(-2.0f, -1.0f), 0.0f, alpha) - .addImplicit(Array(1.0f, 2.0f), 3.0f, alpha) - assert(ne0.k === k) - assert(ne0.triK === k * (k + 1) / 2) - assert(ne0.n === 0) // addImplicit doesn't increase the count. - // NumPy code that computes the expected values: - // alpha = 0.5 - // A = np.matrix("-5 -4; -2 -1; 1 2") - // b = np.matrix("-3; 0; 3") - // b1 = b > 0 - // c = 1.0 + alpha * np.abs(b) - // C = np.diag(c.A1) - // I = np.eye(3) - // ata = A.transpose() * (C - I) * A - // atb = A.transpose() * C * b1 - assert(Vectors.dense(ne0.ata) ~== Vectors.dense(39.0, 33.0, 30.0) relTol 1e-8) - assert(Vectors.dense(ne0.atb) ~== Vectors.dense(2.5, 5.0) relTol 1e-8) - } - test("CholeskySolver") { val k = 2 val ne0 = new NormalEquation(k) - .add(Array(1.0f, 2.0f), 4.0f) - .add(Array(1.0f, 3.0f), 9.0f) - .add(Array(1.0f, 4.0f), 16.0f) + .add(Array(1.0f, 2.0f), 4.0) + .add(Array(1.0f, 3.0f), 9.0) + .add(Array(1.0f, 4.0f), 16.0) val ne1 = new NormalEquation(k) .merge(ne0) @@ -154,13 +133,12 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { // x0 = np.linalg.lstsq(A, b)[0] assert(Vectors.dense(x0) ~== Vectors.dense(-8.333333, 6.0) relTol 1e-6) - assert(ne0.n === 0) assert(ne0.ata.forall(_ == 0.0)) assert(ne0.atb.forall(_ == 0.0)) - val x1 = chol.solve(ne1, 0.5).map(_.toDouble) + val x1 = chol.solve(ne1, 1.5).map(_.toDouble) // NumPy code that computes the expected solution, where lambda is scaled by n: - // x0 = np.linalg.solve(A.transpose() * A + 0.5 * 3 * np.eye(2), A.transpose() * b) + // x0 = np.linalg.solve(A.transpose() * A + 1.5 * np.eye(2), A.transpose() * b) assert(Vectors.dense(x1) ~== Vectors.dense(-0.1155556, 3.28) relTol 1e-6) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala new file mode 100644 index 0000000000000..c44cb61b34171 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.util + +import org.apache.spark.ml.Transformer +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.MetadataBuilder +import org.scalatest.FunSuite + +private[ml] object TestingUtils extends FunSuite { + + /** + * Test whether unrelated metadata are preserved for this transformer. + * This attaches extra metadata to a column, transforms the column, and checks that the + * extra metadata has not changed. + * @param data Input dataset + * @param transformer Transformer to test + * @param inputCol Unique input column for Transformer. This must be the ONLY input column. + * @param outputCol Output column to test for metadata presence. + */ + def testPreserveMetadata( + data: DataFrame, + transformer: Transformer, + inputCol: String, + outputCol: String): Unit = { + // Create some fake metadata + val origMetadata = data.schema(inputCol).metadata + val metaKey = "__testPreserveMetadata__fake_key" + val metaValue = 12345 + assert(!origMetadata.contains(metaKey), + s"Unit test with testPreserveMetadata will fail since metadata key was present: $metaKey") + val newMetadata = + new MetadataBuilder().withMetadata(origMetadata).putLong(metaKey, metaValue).build() + // Add metadata to the inputCol + val withMetadata = data.select(data(inputCol).as(inputCol, newMetadata)) + // Transform, and ensure extra metadata was not affected + val transformed = transformer.transform(withMetadata) + val transMetadata = transformed.schema(outputCol).metadata + assert(transMetadata.contains(metaKey), + "Unit test with testPreserveMetadata failed; extra metadata key was not present.") + assert(transMetadata.getLong(metaKey) === metaValue, + "Unit test with testPreserveMetadata failed; extra metadata value was wrong." + + s" Expected $metaValue but found ${transMetadata.getLong(metaKey)}") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 5a27c7d2309c5..ea89b17b7c08f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -19,6 +19,9 @@ package org.apache.spark.mllib.classification import scala.util.Random +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis} +import breeze.stats.distributions.{Multinomial => BrzMultinomial} + import org.scalatest.FunSuite import org.apache.spark.SparkException @@ -41,37 +44,48 @@ object NaiveBayesSuite { // Generate input of the form Y = (theta * x).argmax() def generateNaiveBayesInput( - pi: Array[Double], // 1XC - theta: Array[Array[Double]], // CXD - nPoints: Int, - seed: Int): Seq[LabeledPoint] = { + pi: Array[Double], // 1XC + theta: Array[Array[Double]], // CXD + nPoints: Int, + seed: Int, + modelType: String = "Multinomial", + sample: Int = 10): Seq[LabeledPoint] = { val D = theta(0).length val rnd = new Random(seed) - val _pi = pi.map(math.pow(math.E, _)) val _theta = theta.map(row => row.map(math.pow(math.E, _))) for (i <- 0 until nPoints) yield { val y = calcLabel(rnd.nextDouble(), _pi) - val xi = Array.tabulate[Double](D) { j => - if (rnd.nextDouble() < _theta(y)(j)) 1 else 0 + val xi = modelType match { + case "Bernoulli" => Array.tabulate[Double](D) { j => + if (rnd.nextDouble() < _theta(y)(j)) 1 else 0 + } + case "Multinomial" => + val mult = BrzMultinomial(BDV(_theta(y))) + val emptyMap = (0
until D).map(x => (x, 0.0)).toMap + val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map { + case (index, reps) => (index, reps.size.toDouble) + } + counts.toArray.sortBy(_._1).map(_._2) + case _ => + // This should never happen. + throw new UnknownError(s"NaiveBayesSuite found unknown ModelType: $modelType") } LabeledPoint(y, Vectors.dense(xi)) } } - private val smallPi = Array(0.5, 0.3, 0.2).map(math.log) + /** Bernoulli NaiveBayes with binary labels, 3 features */ + private val binaryBernoulliModel = new NaiveBayesModel(labels = Array(0.0, 1.0), + pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), + "Bernoulli") - private val smallTheta = Array( - Array(0.91, 0.03, 0.03, 0.03), // label 0 - Array(0.03, 0.91, 0.03, 0.03), // label 1 - Array(0.03, 0.03, 0.91, 0.03) // label 2 - ).map(_.map(math.log)) - - /** Binary labels, 3 features */ - private val binaryModel = new NaiveBayesModel(labels = Array(0.0, 1.0), pi = Array(0.2, 0.8), - theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4))) + /** Multinomial NaiveBayes with binary labels, 3 features */ + private val binaryMultinomialModel = new NaiveBayesModel(labels = Array(0.0, 1.0), + pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), + "Multinomial") } class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { @@ -85,6 +99,24 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { assert(numOfPredictions < input.length / 5) } + def validateModelFit( + piData: Array[Double], + thetaData: Array[Array[Double]], + model: NaiveBayesModel): Unit = { + def closeFit(d1: Double, d2: Double, precision: Double): Boolean = { + (d1 - d2).abs <= precision + } + val modelIndex = (0 until piData.length).zip(model.labels.map(_.toInt)) + for (i <- modelIndex) { + assert(closeFit(math.exp(piData(i._2)), math.exp(model.pi(i._1)), 0.05)) + } + for (i <- modelIndex) { + for (j <- 0 until thetaData(i._2).length) { + assert(closeFit(math.exp(thetaData(i._2)(j)), math.exp(model.theta(i._1)(j)), 0.05)) + } + } + } + test("get, set params") { val nb = new NaiveBayes() nb.setLambda(2.0) @@ -93,19 +125,53 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { assert(nb.getLambda === 3.0) } - test("Naive Bayes") { - val nPoints = 10000 + test("Naive Bayes Multinomial") { + val nPoints = 1000 + val pi = Array(0.5, 0.1, 0.4).map(math.log) + val theta = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) + + val testData = NaiveBayesSuite.generateNaiveBayesInput( + pi, theta, nPoints, 42, "Multinomial") + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + val model = NaiveBayes.train(testRDD, 1.0, "Multinomial") + validateModelFit(pi, theta, model) + + val validationData = NaiveBayesSuite.generateNaiveBayesInput( + pi, theta, nPoints, 17, "Multinomial") + val validationRDD = sc.parallelize(validationData, 2) + + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) - val pi = NaiveBayesSuite.smallPi - val theta = NaiveBayesSuite.smallTheta + // Test prediction on Array. 
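+ // (Each point's features are passed to model.predict individually, mirroring the RDD-based
+ // check above.)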
+ validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } - val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42) + test("Naive Bayes Bernoulli") { + val nPoints = 10000 + val pi = Array(0.5, 0.3, 0.2).map(math.log) + val theta = Array( + Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0 + Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1 + Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2 + ).map(_.map(math.log)) + + val testData = NaiveBayesSuite.generateNaiveBayesInput( + pi, theta, nPoints, 45, "Bernoulli") val testRDD = sc.parallelize(testData, 2) testRDD.cache() - val model = NaiveBayes.train(testRDD) + val model = NaiveBayes.train(testRDD, 1.0, "Bernoulli") + validateModelFit(pi, theta, model) - val validationData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 17) + val validationData = NaiveBayesSuite.generateNaiveBayesInput( + pi, theta, nPoints, 20, "Bernoulli") val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. @@ -142,19 +208,41 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { } } - test("model save/load") { - val model = NaiveBayesSuite.binaryModel + test("model save/load: 2.0 to 2.0") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + Seq(NaiveBayesSuite.binaryBernoulliModel, NaiveBayesSuite.binaryMultinomialModel).map { + model => + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = NaiveBayesModel.load(sc, path) + assert(model.labels === sameModel.labels) + assert(model.pi === sameModel.pi) + assert(model.theta === sameModel.theta) + assert(model.modelType === sameModel.modelType) + } finally { + Utils.deleteRecursively(tempDir) + } + } + } + + test("model save/load: 1.0 to 2.0") { + val model = NaiveBayesSuite.binaryMultinomialModel val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString - // Save model, load it back, and compare. + // Save model as version 1.0, load it back, and compare. try { - model.save(sc, path) + val data = NaiveBayesModel.SaveLoadV1_0.Data(model.labels, model.pi, model.theta) + NaiveBayesModel.SaveLoadV1_0.save(sc, path, data) val sameModel = NaiveBayesModel.load(sc, path) assert(model.labels === sameModel.labels) assert(model.pi === sameModel.pi) assert(model.theta === sameModel.theta) + assert(model.modelType === "Multinomial") } finally { Utils.deleteRecursively(tempDir) } @@ -172,8 +260,8 @@ class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext { LabeledPoint(random.nextInt(2), Vectors.dense(Array.fill(n)(random.nextDouble()))) } } - // If we serialize data directly in the task closure, the size of the serialized task would be - // greater than 1MB and hence Spark would throw an error. + // If we serialize data directly in the task closure, the size of the serialized task + // would be greater than 1MB and hence Spark would throw an error. 
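+ // (NaiveBayes.train is given the RDD itself, so the generated examples stay distributed rather
+ // than being serialized into the task closure.)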
val model = NaiveBayes.train(examples) val predictions = model.predict(examples.map(_.features)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index 8b3e6e5ce9249..5683b55e8500a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.streaming.TestSuiteBase class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase { // use longer wait time to ensure job completion - override def maxWaitTimeMillis = 30000 + override def maxWaitTimeMillis: Int = 30000 // Test if we can accurately learn B for Y = logistic(BX) on streaming data test("parameter accuracy") { @@ -132,4 +132,31 @@ class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase { assert(errors.forall(x => x <= 0.4)) } + // Test training combined with prediction + test("training and prediction") { + // create model initialized with zero weights + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.dense(-0.1)) + .setStepSize(0.01) + .setNumIterations(10) + + // generate sequence of simulated data for testing + val numBatches = 10 + val nPoints = 100 + val testInput = (0 until numBatches).map { i => + LogisticRegressionSuite.generateLogisticInput(0.0, 5.0, nPoints, 42 * (i + 1)) + } + + // train and predict + val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + }) + + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + + // assert that prediction error improves, ensuring that the updated model is being used + val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList + assert(error.head > 0.8 & error.last < 0.2) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 7bf250eb5a383..0f2b26d462ad2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -199,9 +199,13 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { test("k-means|| initialization") { case class VectorWithCompare(x: Vector) extends Ordered[VectorWithCompare] { - @Override def compare(that: VectorWithCompare): Int = { - if(this.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x) > - that.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x)) -1 else 1 + override def compare(that: VectorWithCompare): Int = { + if (this.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x) > + that.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x)) { + -1 + } else { + 1 + } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 302d751eb8a94..15de10fd13a19 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.clustering import org.scalatest.FunSuite -import 
org.apache.spark.mllib.linalg.{DenseMatrix, Matrix, Vectors} +import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -141,7 +141,7 @@ private[clustering] object LDASuite { (terms.toArray, termWeights.toArray) } - def tinyCorpus = Array( + def tinyCorpus: Array[(Long, Vector)] = Array( Vectors.dense(1, 3, 0, 2, 8), Vectors.dense(0, 2, 1, 0, 4), Vectors.dense(2, 3, 12, 3, 1), diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala index 6315c03a700f1..6d6fe6fe46bab 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala @@ -18,12 +18,15 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable +import scala.util.Random import org.scalatest.FunSuite +import org.apache.spark.SparkContext import org.apache.spark.graphx.{Edge, Graph} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext { @@ -110,4 +113,35 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext assert(x ~== u1(i.toInt) absTol 1e-14) } } + + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + val model = PowerIterationClusteringSuite.createModel(sc, 3, 10) + try { + model.save(sc, path) + val sameModel = PowerIterationClusteringModel.load(sc, path) + PowerIterationClusteringSuite.checkEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } +} + +object PowerIterationClusteringSuite extends FunSuite { + def createModel(sc: SparkContext, k: Int, nPoints: Int): PowerIterationClusteringModel = { + val assignments = sc.parallelize( + (0 until nPoints).map(p => PowerIterationClustering.Assignment(p, Random.nextInt(k)))) + new PowerIterationClusteringModel(k, assignments) + } + + def checkEqual(a: PowerIterationClusteringModel, b: PowerIterationClusteringModel): Unit = { + assert(a.k === b.k) + + val aAssignments = a.assignments.map(x => (x.id, x.cluster)) + val bAssignments = b.assignments.map(x => (x.id, x.cluster)) + val unequalElements = aAssignments.join(bAssignments).filter { + case (id, (c1, c2)) => c1 != c2 }.count() + assert(unequalElements === 0L) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index 850c9fce507cd..f90025d535e45 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.util.random.XORShiftRandom class StreamingKMeansSuite extends FunSuite with TestSuiteBase { - override def maxWaitTimeMillis = 30000 + override def maxWaitTimeMillis: Int = 30000 test("accuracy for single center and equivalence to grand average") { // set parameters @@ -59,7 +59,7 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase { // estimated center from streaming should exactly match the arithmetic mean of all data points // because 
the decay factor is set to 1.0 val grandMean = - input.flatten.map(x => x.toBreeze).reduce(_+_) / (numBatches * numPoints).toDouble + input.flatten.map(x => x.toBreeze).reduce(_ + _) / (numBatches * numPoints).toDouble assert(model.latestModel().clusterCenters(0) ~== Vectors.dense(grandMean.toArray) absTol 1E-5) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index 52278690dbd89..98a98a7599bcb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -21,6 +21,9 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils + class Word2VecSuite extends FunSuite with MLlibTestSparkContext { // TODO: add more tests @@ -51,4 +54,27 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext { assert(syms(0)._1 == "taiwan") assert(syms(1)._1 == "japan") } + + test("model load / save") { + + val word2VecMap = Map( + ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)), + ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)), + ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)), + ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f)) + ) + val model = new Word2VecModel(word2VecMap) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + try { + model.save(sc, path) + val sameModel = Word2VecModel.load(sc, path) + assert(sameModel.getVectors.mapValues(_.toSeq) === model.getVectors.mapValues(_.toSeq)) + } finally { + Utils.deleteRecursively(tempDir) + } + + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 0d2cec58e2c03..86119ec38101e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -439,4 +439,20 @@ class MatricesSuite extends FunSuite { assert(mUDT.typeName == "matrix") assert(mUDT.simpleString == "matrix") } + + test("toString") { + val empty = Matrices.ones(0, 0) + empty.toString(0, 0) + + val mat = Matrices.rand(5, 10, new Random()) + mat.toString(-1, -5) + mat.toString(0, 0) + mat.toString(Int.MinValue, Int.MinValue) + mat.toString(Int.MaxValue, Int.MaxValue) + var lines = mat.toString(6, 50).lines.toArray + assert(lines.size == 5 && lines.forall(_.size <= 50)) + + lines = mat.toString(5, 100).lines.toArray + assert(lines.size == 5 && lines.forall(_.size <= 100)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala index 6395188a0842a..63f2ea916d457 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala @@ -181,7 +181,8 @@ class RandomRDDsSuite extends FunSuite with MLlibTestSparkContext with Serializa val poisson = RandomRDDs.poissonVectorRDD(sc, poissonMean, rows, cols, parts, seed) testGeneratedVectorRDD(poisson, rows, cols, parts, poissonMean, math.sqrt(poissonMean), 0.1) - val exponential = RandomRDDs.exponentialVectorRDD(sc, exponentialMean, rows, cols, parts, seed) + val exponential = + RandomRDDs.exponentialVectorRDD(sc, exponentialMean, rows, cols, parts, seed) testGeneratedVectorRDD(exponential, rows, 
cols, parts, exponentialMean, exponentialMean, 0.1) val gamma = RandomRDDs.gammaVectorRDD(sc, gammaShape, gammaScale, rows, cols, parts, seed) @@ -197,7 +198,7 @@ private[random] class MockDistro extends RandomDataGenerator[Double] { // This allows us to check that each partition has a different seed override def nextValue(): Double = seed.toDouble - override def setSeed(seed: Long) = this.seed = seed + override def setSeed(seed: Long): Unit = this.seed = seed override def copy(): MockDistro = new MockDistro } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index 8775c0ca9df84..b3798940ddc38 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -203,6 +203,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext { * @param numProductBlocks number of product blocks to partition products into * @param negativeFactors whether the generated user/product factors can have negative entries */ + // scalastyle:off def testALS( users: Int, products: Int, @@ -216,6 +217,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext { numUserBlocks: Int = -1, numProductBlocks: Int = -1, negativeFactors: Boolean = true) { + // scalastyle:on + val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products, features, samplingRate, implicitPrefs, negativeWeights, negativeFactors) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index 43d61151e2471..d6c93cc0e49cd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -35,7 +35,7 @@ private object RidgeRegressionSuite { class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { - def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = { + def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]): Double = { predictions.zip(input).map { case (prediction, expected) => (prediction - expected.label) * (prediction - expected.label) }.reduceLeft(_ + _) / predictions.size diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index 70b43ddb7daf5..26604dbe6c1ef 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.streaming.TestSuiteBase class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { // use longer wait time to ensure job completion - override def maxWaitTimeMillis = 20000 + override def maxWaitTimeMillis: Int = 20000 // Assert that two values are equal within tolerance epsilon def assertEqual(v1: Double, v2: Double, epsilon: Double) { @@ -139,4 +139,32 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints) assert(errors.forall(x => x <= 0.1)) } + + // Test training combined with prediction + test("training and prediction") { + // create model 
initialized with zero weights + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.dense(0.0, 0.0)) + .setStepSize(0.2) + .setNumIterations(25) + + // generate sequence of simulated data for testing + val numBatches = 10 + val nPoints = 100 + val testInput = (0 until numBatches).map { i => + LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1)) + } + + // train and predict + val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + }) + + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + + // assert that prediction error improves, ensuring that the updated model is being used + val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList + assert((error.head - error.last) > 2) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala index e957fa5d25f4c..352193a67860c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala @@ -95,16 +95,16 @@ object TestingUtils { /** * Comparison using absolute tolerance. */ - def absTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(AbsoluteErrorComparison, - x, eps, ABS_TOL_MSG) + def absTol(eps: Double): CompareDoubleRightSide = + CompareDoubleRightSide(AbsoluteErrorComparison, x, eps, ABS_TOL_MSG) /** * Comparison using relative tolerance. */ - def relTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(RelativeErrorComparison, - x, eps, REL_TOL_MSG) + def relTol(eps: Double): CompareDoubleRightSide = + CompareDoubleRightSide(RelativeErrorComparison, x, eps, REL_TOL_MSG) - override def toString = x.toString + override def toString: String = x.toString } case class CompareVectorRightSide( @@ -166,7 +166,7 @@ object TestingUtils { x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps) }, x, eps, REL_TOL_MSG) - override def toString = x.toString + override def toString: String = x.toString } case class CompareMatrixRightSide( @@ -229,7 +229,7 @@ object TestingUtils { x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps) }, x, eps, REL_TOL_MSG) - override def toString = x.toString + override def toString: String = x.toString } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala index b0ecb33c28483..59e6c778806f4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala @@ -88,16 +88,20 @@ class TestingUtilsSuite extends FunSuite { assert(!(17.8 ~= 17.59 absTol 0.2)) // Comparisons of numbers very close to zero, and both side of zeros - assert(Double.MinPositiveValue ~== 4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) - assert(Double.MinPositiveValue !~== 6 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) - - assert(-Double.MinPositiveValue ~== 3 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) - assert(Double.MinPositiveValue !~== -4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + assert( + Double.MinPositiveValue ~== 4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + assert( + 
Double.MinPositiveValue !~== 6 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + + assert( + -Double.MinPositiveValue ~== 3 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + assert( + Double.MinPositiveValue !~== -4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) } test("Comparing vectors using relative error.") { - //Comparisons of two dense vectors + // Comparisons of two dense vectors assert(Vectors.dense(Array(3.1, 3.5)) ~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01) assert(Vectors.dense(Array(3.1, 3.5)) !~== Vectors.dense(Array(3.135, 3.534)) relTol 0.01) assert(Vectors.dense(Array(3.1, 3.5)) ~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01) @@ -130,7 +134,7 @@ class TestingUtilsSuite extends FunSuite { test("Comparing vectors using absolute error.") { - //Comparisons of two dense vectors + // Comparisons of two dense vectors assert(Vectors.dense(Array(3.1, 3.5, 0.0)) ~== Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6) diff --git a/network/common/pom.xml b/network/common/pom.xml index 7b51845206f4a..22c738bde6d42 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -80,6 +80,11 @@ mockito-all test + + org.slf4j + slf4j-log4j12 + test + diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 91d1e8a538a77..0f999f5dfe8d8 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -72,9 +72,11 @@ public void encode(ChannelHandlerContext ctx, Message in, List out) { in.encode(header); assert header.writableBytes() == 0; - out.add(header); if (body != null && bodyLength > 0) { - out.add(body); + out.add(new MessageWithHeader(header, body, bodyLength)); + } else { + out.add(header); } } + } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java new file mode 100644 index 0000000000000..d686a951467cf --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.IOException; +import java.nio.channels.WritableByteChannel; + +import com.google.common.base.Preconditions; +import io.netty.buffer.ByteBuf; +import io.netty.channel.FileRegion; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.ReferenceCountUtil; + +/** + * A wrapper message that holds two separate pieces (a header and a body). 
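+ * Implementing FileRegion lets the whole frame be written via transferTo, which avoids copying
+ * a FileRegion body into an intermediate buffer before it is sent; MessageEncoder above now
+ * emits this wrapper instead of adding the header and body to the output list separately.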
+ * + * The header must be a ByteBuf, while the body can be a ByteBuf or a FileRegion. + */ +class MessageWithHeader extends AbstractReferenceCounted implements FileRegion { + + private final ByteBuf header; + private final int headerLength; + private final Object body; + private final long bodyLength; + private long totalBytesTransferred; + + MessageWithHeader(ByteBuf header, Object body, long bodyLength) { + Preconditions.checkArgument(body instanceof ByteBuf || body instanceof FileRegion, + "Body must be a ByteBuf or a FileRegion."); + this.header = header; + this.headerLength = header.readableBytes(); + this.body = body; + this.bodyLength = bodyLength; + } + + @Override + public long count() { + return headerLength + bodyLength; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transfered() { + return totalBytesTransferred; + } + + /** + * This code is more complicated than you would think because we might require multiple + * transferTo invocations in order to transfer a single MessageWithHeader to avoid busy waiting. + * + * The contract is that the caller will ensure position is properly set to the total number + * of bytes transferred so far (i.e. value returned by transfered()). + */ + @Override + public long transferTo(final WritableByteChannel target, final long position) throws IOException { + Preconditions.checkArgument(position == totalBytesTransferred, "Invalid position."); + // Bytes written for header in this call. + long writtenHeader = 0; + if (header.readableBytes() > 0) { + writtenHeader = copyByteBuf(header, target); + totalBytesTransferred += writtenHeader; + if (header.readableBytes() > 0) { + return writtenHeader; + } + } + + // Bytes written for body in this call. + long writtenBody = 0; + if (body instanceof FileRegion) { + writtenBody = ((FileRegion) body).transferTo(target, totalBytesTransferred - headerLength); + } else if (body instanceof ByteBuf) { + writtenBody = copyByteBuf((ByteBuf) body, target); + } + totalBytesTransferred += writtenBody; + + return writtenHeader + writtenBody; + } + + @Override + protected void deallocate() { + header.release(); + ReferenceCountUtil.release(body); + } + + private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException { + int written = target.write(buf.nioBuffer()); + buf.skipBytes(written); + return written; + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 73da9b7346f4d..b6fbace509a0e 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -21,9 +21,13 @@ import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; +import java.util.concurrent.TimeUnit; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import com.google.common.base.Charsets; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; import io.netty.buffer.Unpooled; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -121,4 +125,66 @@ private static boolean isSymlink(File file) throws IOException { } return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); } + + private static ImmutableMap timeSuffixes = + ImmutableMap.builder() + .put("us", TimeUnit.MICROSECONDS) + .put("ms", TimeUnit.MILLISECONDS) + .put("s", TimeUnit.SECONDS) + .put("m", 
TimeUnit.MINUTES) + .put("min", TimeUnit.MINUTES) + .put("h", TimeUnit.HOURS) + .put("d", TimeUnit.DAYS) + .build(); + + /** + * Convert a passed time string (e.g. 50s, 100ms, or 250us) to a time count for + * internal use. If no suffix is provided, a direct conversion is attempted. + */ + private static long parseTimeString(String str, TimeUnit unit) { + String lower = str.toLowerCase().trim(); + + try { + String suffix; + long val; + Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower); + if (m.matches()) { + val = Long.parseLong(m.group(1)); + suffix = m.group(2); + } else { + throw new NumberFormatException("Failed to parse time string: " + str); + } + + // Check for invalid suffixes + if (suffix != null && !timeSuffixes.containsKey(suffix)) { + throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); + } + + // If suffix is valid use that, otherwise none was provided and use the default passed + return unit.convert(val, suffix != null ? timeSuffixes.get(suffix) : unit); + } catch (NumberFormatException e) { + String timeError = "Time must be specified as seconds (s), " + + "milliseconds (ms), microseconds (us), minutes (m or min), hours (h), or days (d). " + + "E.g. 50s, 100ms, or 250us."; + + throw new NumberFormatException(timeError + "\n" + e.getMessage()); + } + } + + /** + * Convert a time parameter such as 50s, 100ms, or 250us to milliseconds for internal use. If + * no suffix is provided, the passed number is assumed to be in ms. + */ + public static long timeStringAsMs(String str) { + return parseTimeString(str, TimeUnit.MILLISECONDS); + } + + /** + * Convert a time parameter such as 50s, 100ms, or 250us to seconds for internal use. If + * no suffix is provided, the passed number is assumed to be in seconds. + */ + public static long timeStringAsSec(String str) { + return parseTimeString(str, TimeUnit.SECONDS); + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 2eaf3b71d9a49..0aef7f1987315 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -37,8 +37,11 @@ public boolean preferDirectBufs() { /** Connect timeout in milliseconds. Default 120 secs. */ public int connectionTimeoutMs() { - int defaultTimeout = conf.getInt("spark.network.timeout", 120); - return conf.getInt("spark.shuffle.io.connectionTimeout", defaultTimeout) * 1000; + long defaultNetworkTimeoutS = JavaUtils.timeStringAsSec( + conf.get("spark.network.timeout", "120s")); + long defaultTimeoutMs = JavaUtils.timeStringAsSec( + conf.get("spark.shuffle.io.connectionTimeout", defaultNetworkTimeoutS + "s")) * 1000; + return (int) defaultTimeoutMs; } /** Number of concurrent connections between two nodes for fetching data. */ @@ -68,7 +71,9 @@ public int numConnectionsPerPeer() { public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ - public int saslRTTimeoutMs() { return conf.getInt("spark.shuffle.sasl.timeout", 30) * 1000; } + public int saslRTTimeoutMs() { + return (int) JavaUtils.timeStringAsSec(conf.get("spark.shuffle.sasl.timeout", "30s")) * 1000; + } /** * Max number of times we will try IO exceptions (such as connection timeouts) per request.
@@ -80,7 +85,9 @@ public int numConnectionsPerPeer() { * Time (in milliseconds) that we will wait in order to perform a retry after an IOException. * Only relevant if maxIORetries > 0. */ - public int ioRetryWaitTimeMs() { return conf.getInt("spark.shuffle.io.retryWait", 5) * 1000; } + public int ioRetryWaitTimeMs() { + return (int) JavaUtils.timeStringAsSec(conf.get("spark.shuffle.io.retryWait", "5s")) * 1000; + } /** * Minimum size of a block that we should start using memory map rather than reading in through diff --git a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala b/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java similarity index 55% rename from project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala rename to network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java index 3d43c35299555..b525ed69fc9fb 100644 --- a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala +++ b/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java @@ -15,25 +15,41 @@ * limitations under the License. */ +package org.apache.spark.network; -package org.apache.spark.scalastyle +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; -import java.util.regex.Pattern +public class ByteArrayWritableChannel implements WritableByteChannel { -import org.scalastyle.{PositionError, ScalariformChecker, ScalastyleError} + private final byte[] data; + private int offset; -import scalariform.lexer.Token -import scalariform.parser.CompilationUnit + public ByteArrayWritableChannel(int size) { + this.data = new byte[size]; + this.offset = 0; + } + + public byte[] getData() { + return data; + } -class NonASCIICharacterChecker extends ScalariformChecker { - val errorKey: String = "non.ascii.character.disallowed" + @Override + public int write(ByteBuffer src) { + int available = src.remaining(); + src.get(data, offset, available); + offset += available; + return available; + } + + @Override + public void close() { - override def verify(ast: CompilationUnit): List[ScalastyleError] = { - ast.tokens.filter(hasNonAsciiChars).map(x => PositionError(x.offset)).toList } - private def hasNonAsciiChars(x: Token) = - x.rawText.trim.nonEmpty && !Pattern.compile( """\p{ASCII}+""", Pattern.DOTALL) - .matcher(x.text.trim).matches() + @Override + public boolean isOpen() { + return true; + } } diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 43dc0cf8c7194..860dd6d9b3915 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -17,26 +17,34 @@ package org.apache.spark.network; +import java.util.List; + +import com.google.common.primitives.Ints; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.FileRegion; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.MessageToMessageEncoder; import org.junit.Test; import static org.junit.Assert.assertEquals; -import org.apache.spark.network.protocol.Message; -import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.ChunkFetchRequest; import org.apache.spark.network.protocol.ChunkFetchFailure; +import 
org.apache.spark.network.protocol.ChunkFetchRequest; import org.apache.spark.network.protocol.ChunkFetchSuccess; -import org.apache.spark.network.protocol.RpcRequest; -import org.apache.spark.network.protocol.RpcFailure; -import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.MessageDecoder; import org.apache.spark.network.protocol.MessageEncoder; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcRequest; +import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.util.NettyUtils; public class ProtocolSuite { private void testServerToClient(Message msg) { - EmbeddedChannel serverChannel = new EmbeddedChannel(new MessageEncoder()); + EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(), + new MessageEncoder()); serverChannel.writeOutbound(msg); EmbeddedChannel clientChannel = new EmbeddedChannel( @@ -51,7 +59,8 @@ private void testServerToClient(Message msg) { } private void testClientToServer(Message msg) { - EmbeddedChannel clientChannel = new EmbeddedChannel(new MessageEncoder()); + EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(), + new MessageEncoder()); clientChannel.writeOutbound(msg); EmbeddedChannel serverChannel = new EmbeddedChannel( @@ -83,4 +92,25 @@ public void responses() { testServerToClient(new RpcFailure(0, "this is an error")); testServerToClient(new RpcFailure(0, "")); } + + /** + * Handler to transform a FileRegion into a byte buffer. EmbeddedChannel doesn't actually transfer + * bytes, but messages, so this is needed so that the frame decoder on the receiving side can + * understand what MessageWithHeader actually contains. + */ + private static class FileRegionEncoder extends MessageToMessageEncoder { + + @Override + public void encode(ChannelHandlerContext ctx, FileRegion in, List out) + throws Exception { + + ByteArrayWritableChannel channel = new ByteArrayWritableChannel(Ints.checkedCast(in.count())); + while (in.transfered() < in.count()) { + in.transferTo(channel, in.transfered()); + } + out.add(Unpooled.wrappedBuffer(channel.getData())); + } + + } + } diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java new file mode 100644 index 0000000000000..ff985096d72d5 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.protocol; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.FileRegion; +import io.netty.util.AbstractReferenceCounted; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.ByteArrayWritableChannel; + +public class MessageWithHeaderSuite { + + @Test + public void testSingleWrite() throws Exception { + testFileRegionBody(8, 8); + } + + @Test + public void testShortWrite() throws Exception { + testFileRegionBody(8, 1); + } + + @Test + public void testByteBufBody() throws Exception { + ByteBuf header = Unpooled.copyLong(42); + ByteBuf body = Unpooled.copyLong(84); + MessageWithHeader msg = new MessageWithHeader(header, body, body.readableBytes()); + + ByteBuf result = doWrite(msg, 1); + assertEquals(msg.count(), result.readableBytes()); + assertEquals(42, result.readLong()); + assertEquals(84, result.readLong()); + } + + private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception { + ByteBuf header = Unpooled.copyLong(42); + int headerLength = header.readableBytes(); + TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall); + MessageWithHeader msg = new MessageWithHeader(header, region, region.count()); + + ByteBuf result = doWrite(msg, totalWrites / writesPerCall); + assertEquals(headerLength + region.count(), result.readableBytes()); + assertEquals(42, result.readLong()); + for (long i = 0; i < 8; i++) { + assertEquals(i, result.readLong()); + } + } + + private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception { + int writes = 0; + ByteArrayWritableChannel channel = new ByteArrayWritableChannel((int) msg.count()); + while (msg.transfered() < msg.count()) { + msg.transferTo(channel, msg.transfered()); + writes++; + } + assertTrue("Not enough writes!", minExpectedWrites <= writes); + return Unpooled.wrappedBuffer(channel.getData()); + } + + private static class TestFileRegion extends AbstractReferenceCounted implements FileRegion { + + private final int writeCount; + private final int writesPerCall; + private int written; + + TestFileRegion(int totalWrites, int writesPerCall) { + this.writeCount = totalWrites; + this.writesPerCall = writesPerCall; + } + + @Override + public long count() { + return 8 * writeCount; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transfered() { + return 8 * written; + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + for (int i = 0; i < writesPerCall; i++) { + ByteBuf buf = Unpooled.copyLong((position / 8) + i); + ByteBuffer nio = buf.nioBuffer(); + while (nio.remaining() > 0) { + target.write(nio); + } + buf.release(); + written++; + } + return 8 * writesPerCall; + } + + @Override + protected void deallocate() { + } + + } + +} diff --git a/network/common/src/test/resources/log4j.properties b/network/common/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..e8da774f7ca9e --- /dev/null +++ b/network/common/src/test/resources/log4j.properties @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=DEBUG, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Silence verbose logs from 3rd-party libraries. +log4j.logger.io.netty=INFO diff --git a/pom.xml b/pom.xml index b3cecd1893a06..bcc2f57f1af5d 100644 --- a/pom.xml +++ b/pom.xml @@ -141,13 +141,13 @@ 2.4.0 2.0.8 3.1.0 - 1.7.6 + 1.7.7 0.7.1 1.8.3 1.1.0 4.2.6 - 3.1.1 + 3.4.1 ${project.build.directory}/spark-test-classpath.txt 2.10.4 2.10 @@ -156,9 +156,11 @@ 3.6.3 1.8.8 2.4.4 - 1.1.1.6 + 1.1.1.7 1.1.2 + ${java.home} + ${test_classpath} + ${test.java.home} true @@ -1224,6 +1227,7 @@ launched by the tests have access to the correct test-time classpath. --> ${test_classpath} + ${test.java.home} true @@ -1265,6 +1269,7 @@ create-source-jar jar-no-fork + test-jar-no-fork @@ -1442,7 +1447,7 @@ org.scalastyle scalastyle-maven-plugin - 0.4.0 + 0.7.0 false true @@ -1451,13 +1456,12 @@ ${basedir}/src/main/scala ${basedir}/src/test/scala scalastyle-config.xml - scalastyle-output.xml + ${basedir}/target/scalastyle-output.xml ${project.build.sourceEncoding} ${project.reporting.outputEncoding} - package check @@ -1473,6 +1477,25 @@ org.scalatest scalatest-maven-plugin + + + org.apache.maven.plugins + maven-jar-plugin + + + prepare-test-jar + prepare-package + + test-jar + + + + log4j.properties + + + + + @@ -1696,6 +1719,16 @@ + + test-java-home + + env.JAVA_HOME + + + ${env.JAVA_HOME} + + + scala-2.11 @@ -1729,5 +1762,8 @@ parquet-provided + + sparkr + diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b9f40046e15a2..1564babefa62f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -50,10 +50,24 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleResultTypeProblem]( "org.apache.spark.broadcast.HttpBroadcastFactory.newBroadcast"), ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast") + "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.scheduler.OutputCommitCoordinator$OutputCommitCoordinatorActor") + ) ++ Seq( + // SPARK-4655 - Making Stage an Abstract class broke binary compatility even though + // the stage class is defined as private[spark] + ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.scheduler.Stage") ) ++ Seq( - // SPARK-6510 Add a Graph#minus method acting as Set#difference + // SPARK-6510 Add a Graph#minus method acting as Set#difference ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.minus") + ) ++ Seq( + // SPARK-6492 Fix deadlock in SparkContext.stop() + 
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.org$" + + "apache$spark$SparkContext$$SPARK_CONTEXT_CONSTRUCTOR_LOCK") + )++ Seq( + // SPARK-6693 add tostring with max lines and width for matrix + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.toString") ) case v if v.startsWith("1.3") => diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index ac37c605de4b6..09b4976d10c26 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -import java.io.File +import java.io._ import scala.util.Properties import scala.collection.JavaConversions._ @@ -119,7 +119,9 @@ object SparkBuild extends PomBuild { lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") lazy val sharedSettings = graphSettings ++ genjavadocSettings ++ Seq ( - javaHome := Properties.envOrNone("JAVA_HOME").map(file), + javaHome := sys.env.get("JAVA_HOME") + .orElse(sys.props.get("java.home").map { p => new File(p).getParentFile().getAbsolutePath() }) + .map(file), incOptions := incOptions.value.withNameHashing(true), retrieveManaged := true, retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", @@ -164,6 +166,9 @@ object SparkBuild extends PomBuild { /* Enable Assembly for all assembly projects */ assemblyProjects.foreach(enable(Assembly.settings)) + /* Package pyspark artifacts in the main assembly. */ + enable(PySparkAssembly.settings)(assembly) + /* Enable unidoc only for the root spark project */ enable(Unidoc.settings)(spark) @@ -314,6 +319,7 @@ object Hive { } object Assembly { + import sbtassembly.AssemblyUtils._ import sbtassembly.Plugin._ import AssemblyKeys._ @@ -345,6 +351,60 @@ object Assembly { ) } +object PySparkAssembly { + import sbtassembly.Plugin._ + import AssemblyKeys._ + + lazy val settings = Seq( + unmanagedJars in Compile += { BuildCommons.sparkHome / "python/lib/py4j-0.8.2.1-src.zip" }, + // Use a resource generator to copy all .py files from python/pyspark into a managed directory + // to be included in the assembly. We can't just add "python/" to the assembly's resource dir + // list since that will copy unneeded / unwanted files. 
+ resourceGenerators in Compile <+= resourceManaged in Compile map { outDir: File => + val dst = new File(outDir, "pyspark") + if (!dst.isDirectory()) { + require(dst.mkdirs()) + } + + val src = new File(BuildCommons.sparkHome, "python/pyspark") + copy(src, dst) + } + ) + + private def copy(src: File, dst: File): Seq[File] = { + src.listFiles().flatMap { f => + val child = new File(dst, f.getName()) + if (f.isDirectory()) { + child.mkdir() + copy(f, child) + } else if (f.getName().endsWith(".py")) { + var in: Option[FileInputStream] = None + var out: Option[FileOutputStream] = None + try { + in = Some(new FileInputStream(f)) + out = Some(new FileOutputStream(child)) + + val bytes = new Array[Byte](1024) + var read = 0 + while (read >= 0) { + read = in.get.read(bytes) + if (read > 0) { + out.get.write(bytes, 0, read) + } + } + + Some(child) + } finally { + in.foreach(_.close()) + out.foreach(_.close()) + } + } else { + None + } + } + } +} + object Unidoc { import BuildCommons._ @@ -360,15 +420,15 @@ object Unidoc { packages .map(_.filterNot(_.getName.contains("$"))) .map(_.filterNot(_.getCanonicalPath.contains("akka"))) - .map(_.filterNot(_.getCanonicalPath.contains("deploy"))) - .map(_.filterNot(_.getCanonicalPath.contains("network"))) - .map(_.filterNot(_.getCanonicalPath.contains("shuffle"))) - .map(_.filterNot(_.getCanonicalPath.contains("executor"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/deploy"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/network"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/shuffle"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/executor"))) .map(_.filterNot(_.getCanonicalPath.contains("python"))) - .map(_.filterNot(_.getCanonicalPath.contains("collection"))) - .map(_.filterNot(_.getCanonicalPath.contains("sql/catalyst"))) - .map(_.filterNot(_.getCanonicalPath.contains("sql/execution"))) - .map(_.filterNot(_.getCanonicalPath.contains("sql/hive/test"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/collection"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalyst"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/execution"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/hive/test"))) } lazy val settings = scalaJavaUnidocSettings ++ Seq ( @@ -426,8 +486,10 @@ object TestSettings { fork := true, // Setting SPARK_DIST_CLASSPATH is a simple way to make sure any child processes // launched by the tests have access to the correct test-time classpath. 
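The `PySparkAssembly` helper above walks `python/pyspark` and copies only the `.py` files into a managed resource directory so they end up inside the assembly jar without dragging in unrelated files. A rough Python equivalent of that filtering copy, for illustration only (the build itself uses the Scala code shown above, and the paths in the usage comment are placeholders):

```
import os
import shutil

def copy_py_files(src, dst):
    """Recursively copy only *.py files from src into dst, preserving layout."""
    copied = []
    if not os.path.isdir(dst):
        os.makedirs(dst)
    for name in os.listdir(src):
        src_path = os.path.join(src, name)
        dst_path = os.path.join(dst, name)
        if os.path.isdir(src_path):
            copied.extend(copy_py_files(src_path, dst_path))
        elif name.endswith(".py"):
            shutil.copyfile(src_path, dst_path)
            copied.append(dst_path)
    return copied

# e.g. copy_py_files("python/pyspark", "target/resource_managed/main/pyspark")
```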
- envVars in Test += ("SPARK_DIST_CLASSPATH" -> - (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":")), + envVars in Test ++= Map( + "SPARK_DIST_CLASSPATH" -> + (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":"), + "JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home"))), javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.port.maxRetries=100", diff --git a/project/plugins.sbt b/project/plugins.sbt index ee45b6a51905e..7096b0d3ee7de 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -19,7 +19,7 @@ addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0") addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.4") -addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.6.0") +addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.7.0") addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala index 8863f272da415..471d00bd8223f 100644 --- a/project/project/SparkPluginBuild.scala +++ b/project/project/SparkPluginBuild.scala @@ -24,20 +24,6 @@ import sbt.Keys._ * becomes available for scalastyle sbt plugin. */ object SparkPluginDef extends Build { - lazy val root = Project("plugins", file(".")) dependsOn(sparkStyle, sbtPomReader) - lazy val sparkStyle = Project("spark-style", file("spark-style"), settings = styleSettings) + lazy val root = Project("plugins", file(".")) dependsOn(sbtPomReader) lazy val sbtPomReader = uri("https://github.com/ScrapCodes/sbt-pom-reader.git#ignore_artifact_id") - - // There is actually no need to publish this artifact. - def styleSettings = Defaults.defaultSettings ++ Seq ( - name := "spark-style", - organization := "org.apache.spark", - scalaVersion := "2.10.4", - scalacOptions := Seq("-unchecked", "-deprecation"), - libraryDependencies ++= Dependencies.scalaStyle - ) - - object Dependencies { - val scalaStyle = Seq("org.scalastyle" %% "scalastyle" % "0.4.0") - } } diff --git a/python/docs/index.rst b/python/docs/index.rst index d150de9d5c502..f7eede9c3c82a 100644 --- a/python/docs/index.rst +++ b/python/docs/index.rst @@ -29,6 +29,14 @@ Core classes: A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. + :class:`pyspark.sql.SQLContext` + + Main entry point for DataFrame and SQL functionality. + + :class:`pyspark.sql.DataFrame` + + A distributed collection of data grouped into named columns. + Indices and tables ================== diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst index 15101470afc07..26ece4c2c389a 100644 --- a/python/docs/pyspark.mllib.rst +++ b/python/docs/pyspark.mllib.rst @@ -31,6 +31,13 @@ pyspark.mllib.feature module :undoc-members: :show-inheritance: +pyspark.mllib.fpm module +------------------------ + +.. automodule:: pyspark.mllib.fpm + :members: + :undoc-members: + pyspark.mllib.linalg module --------------------------- diff --git a/python/docs/pyspark.streaming.rst b/python/docs/pyspark.streaming.rst index 7890d9dcaac21..50822c93faba1 100644 --- a/python/docs/pyspark.streaming.rst +++ b/python/docs/pyspark.streaming.rst @@ -10,7 +10,7 @@ Module contents :show-inheritance: pyspark.streaming.kafka module ----------------------------- +------------------------------ .. 
automodule:: pyspark.streaming.kafka :members: :undoc-members: diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 0a16cbd8bff62..2a5e84a7dfdb4 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -29,11 +29,10 @@ def launch_gateway(): - SPARK_HOME = os.environ["SPARK_HOME"] - if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) else: + SPARK_HOME = os.environ["SPARK_HOME"] # Launch the Py4j gateway using Spark's run command so that we pick up the # proper classpath and settings from spark-env.sh on_windows = platform.system() == "Windows" diff --git a/python/pyspark/join.py b/python/pyspark/join.py index efc1ef9396412..c3491defb2b29 100644 --- a/python/pyspark/join.py +++ b/python/pyspark/join.py @@ -48,7 +48,7 @@ def dispatch(seq): vbuf.append(v) elif n == 2: wbuf.append(v) - return [(v, w) for v in vbuf for w in wbuf] + return ((v, w) for v in vbuf for w in wbuf) return _do_python_join(rdd, other, numPartitions, dispatch) @@ -62,7 +62,7 @@ def dispatch(seq): wbuf.append(v) if not vbuf: vbuf.append(None) - return [(v, w) for v in vbuf for w in wbuf] + return ((v, w) for v in vbuf for w in wbuf) return _do_python_join(rdd, other, numPartitions, dispatch) @@ -76,7 +76,7 @@ def dispatch(seq): wbuf.append(v) if not wbuf: wbuf.append(None) - return [(v, w) for v in vbuf for w in wbuf] + return ((v, w) for v in vbuf for w in wbuf) return _do_python_join(rdd, other, numPartitions, dispatch) @@ -104,8 +104,9 @@ def make_mapper(i): rdd_len = len(vrdds) def dispatch(seq): - bufs = [[] for i in range(rdd_len)] - for (n, v) in seq: + bufs = [[] for _ in range(rdd_len)] + for n, v in seq: bufs[n].append(v) - return tuple(map(ResultIterable, bufs)) + return tuple(ResultIterable(vs) for vs in bufs) + return union_vrdds.groupByKey(numPartitions).mapValues(dispatch) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 4ff7463498cce..d7bc09fd77adb 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -59,6 +59,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxIter=100, regParam=0.1) """ super(LogisticRegression, self).__init__() + self._setDefault(maxIter=100, regParam=0.1) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -71,7 +72,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre Sets params for logistic regression. 
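The `join.py` change above replaces the list comprehensions in the per-key `dispatch` helpers with generator expressions, so the cross product of values for a skewed key is produced lazily instead of being materialized all at once. A small standalone illustration of the difference (plain Python, independent of Spark):

```
# Tagged values for one key, as dispatch() receives them:
# (1, v) from the left RDD and (2, w) from the right RDD.
seq = [(1, "a"), (1, "b"), (2, "x"), (2, "y")]

vbuf, wbuf = [], []
for n, v in seq:
    (vbuf if n == 1 else wbuf).append(v)

# Old behaviour: the full cross product is built in memory up front.
pairs_eager = [(v, w) for v in vbuf for w in wbuf]

# New behaviour: pairs are yielded one at a time as the consumer iterates.
pairs_lazy = ((v, w) for v in vbuf for w in wbuf)

assert pairs_eager == list(pairs_lazy)  # same results, smaller peak memory
```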
""" kwargs = self.setParams._input_kwargs - return self._set_params(**kwargs) + return self._set(**kwargs) def _create_model(self, java_model): return LogisticRegressionModel(java_model) @@ -91,9 +92,9 @@ class LogisticRegressionModel(JavaModel): # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.feature tests") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) globs['sc'] = sc - globs['sqlCtx'] = sqlCtx + globs['sqlContext'] = sqlContext (failure_count, test_count) = doctest.testmod( globs=globs, optionflags=doctest.ELLIPSIS) sc.stop() diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 433b4fb5d22bf..263fe2a5bcc41 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -52,22 +52,22 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): _java_class = "org.apache.spark.ml.feature.Tokenizer" @keyword_only - def __init__(self, inputCol="input", outputCol="output"): + def __init__(self, inputCol=None, outputCol=None): """ - __init__(self, inputCol="input", outputCol="output") + __init__(self, inputCol=None, outputCol=None) """ super(Tokenizer, self).__init__() kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only - def setParams(self, inputCol="input", outputCol="output"): + def setParams(self, inputCol=None, outputCol=None): """ setParams(self, inputCol="input", outputCol="output") Sets params for this Tokenizer. """ kwargs = self.setParams._input_kwargs - return self._set_params(**kwargs) + return self._set(**kwargs) @inherit_doc @@ -91,22 +91,23 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): _java_class = "org.apache.spark.ml.feature.HashingTF" @keyword_only - def __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output"): + def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None): """ - __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output") + __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None) """ super(HashingTF, self).__init__() + self._setDefault(numFeatures=1 << 18) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only - def setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output"): + def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None): """ - setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output") + setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None) Sets params for this HashingTF. """ kwargs = self.setParams._input_kwargs - return self._set_params(**kwargs) + return self._set(**kwargs) if __name__ == "__main__": @@ -117,9 +118,9 @@ def setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output"): # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.feature tests") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) globs['sc'] = sc - globs['sqlCtx'] = sqlCtx + globs['sqlContext'] = sqlContext (failure_count, test_count) = doctest.testmod( globs=globs, optionflags=doctest.ELLIPSIS) sc.stop() diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index e3a53dd780c4c..5c62620562a84 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -25,23 +25,21 @@ class Param(object): """ - A param with self-contained documentation and optionally default value. 
+ A param with self-contained documentation. """ - def __init__(self, parent, name, doc, defaultValue=None): - if not isinstance(parent, Identifiable): - raise ValueError("Parent must be identifiable but got type %s." % type(parent).__name__) + def __init__(self, parent, name, doc): + if not isinstance(parent, Params): + raise ValueError("Parent must be a Params but got type %s." % type(parent).__name__) self.parent = parent self.name = str(name) self.doc = str(doc) - self.defaultValue = defaultValue def __str__(self): - return str(self.parent) + "-" + self.name + return str(self.parent) + "__" + self.name def __repr__(self): - return "Param(parent=%r, name=%r, doc=%r, defaultValue=%r)" % \ - (self.parent, self.name, self.doc, self.defaultValue) + return "Param(parent=%r, name=%r, doc=%r)" % (self.parent, self.name, self.doc) class Params(Identifiable): @@ -52,26 +50,128 @@ class Params(Identifiable): __metaclass__ = ABCMeta - def __init__(self): - super(Params, self).__init__() - #: embedded param map - self.paramMap = {} + #: internal param map for user-supplied values param map + paramMap = {} + + #: internal param map for default values + defaultParamMap = {} @property def params(self): """ - Returns all params. The default implementation uses - :py:func:`dir` to get all attributes of type + Returns all params ordered by name. The default implementation + uses :py:func:`dir` to get all attributes of type :py:class:`Param`. """ return filter(lambda attr: isinstance(attr, Param), [getattr(self, x) for x in dir(self) if x != "params"]) - def _merge_params(self, params): - paramMap = self.paramMap.copy() - paramMap.update(params) + def _explain(self, param): + """ + Explains a single param and returns its name, doc, and optional + default value and user-supplied value in a string. + """ + param = self._resolveParam(param) + values = [] + if self.isDefined(param): + if param in self.defaultParamMap: + values.append("default: %s" % self.defaultParamMap[param]) + if param in self.paramMap: + values.append("current: %s" % self.paramMap[param]) + else: + values.append("undefined") + valueStr = "(" + ", ".join(values) + ")" + return "%s: %s %s" % (param.name, param.doc, valueStr) + + def explainParams(self): + """ + Returns the documentation of all params with their optionally + default values and user-supplied values. + """ + return "\n".join([self._explain(param) for param in self.params]) + + def getParam(self, paramName): + """ + Gets a param by its name. + """ + param = getattr(self, paramName) + if isinstance(param, Param): + return param + else: + raise ValueError("Cannot find param with name %s." % paramName) + + def isSet(self, param): + """ + Checks whether a param is explicitly set by user. + """ + param = self._resolveParam(param) + return param in self.paramMap + + def hasDefault(self, param): + """ + Checks whether a param has a default value. + """ + param = self._resolveParam(param) + return param in self.defaultParamMap + + def isDefined(self, param): + """ + Checks whether a param is explicitly set by user or has a default value. + """ + return self.isSet(param) or self.hasDefault(param) + + def getOrDefault(self, param): + """ + Gets the value of a param in the user-supplied param map or its + default value. Raises an error if either is set. 
+ """ + if isinstance(param, Param): + if param in self.paramMap: + return self.paramMap[param] + else: + return self.defaultParamMap[param] + elif isinstance(param, str): + return self.getOrDefault(self.getParam(param)) + else: + raise KeyError("Cannot recognize %r as a param." % param) + + def extractParamMap(self, extraParamMap={}): + """ + Extracts the embedded default param values and user-supplied + values, and then merges them with extra values from input into + a flat param map, where the latter value is used if there exist + conflicts, i.e., with ordering: default param values < + user-supplied values < extraParamMap. + :param extraParamMap: extra param values + :return: merged param map + """ + paramMap = self.defaultParamMap.copy() + paramMap.update(self.paramMap) + paramMap.update(extraParamMap) return paramMap + def _shouldOwn(self, param): + """ + Validates that the input param belongs to this Params instance. + """ + if param.parent is not self: + raise ValueError("Param %r does not belong to %r." % (param, self)) + + def _resolveParam(self, param): + """ + Resolves a param and validates the ownership. + :param param: param name or the param instance, which must + belong to this Params instance + :return: resolved param instance + """ + if isinstance(param, Param): + self._shouldOwn(param) + return param + elif isinstance(param, str): + return self.getParam(param) + else: + raise ValueError("Cannot resolve %r as a param." % param) + @staticmethod def _dummy(): """ @@ -81,10 +181,18 @@ def _dummy(): dummy.uid = "undefined" return dummy - def _set_params(self, **kwargs): + def _set(self, **kwargs): """ - Sets params. + Sets user-supplied params. """ for param, value in kwargs.iteritems(): self.paramMap[getattr(self, param)] = value return self + + def _setDefault(self, **kwargs): + """ + Sets default params. + """ + for param, value in kwargs.iteritems(): + self.defaultParamMap[getattr(self, param)] = value + return self diff --git a/python/pyspark/ml/param/_gen_shared_params.py b/python/pyspark/ml/param/_shared_params_code_gen.py similarity index 70% rename from python/pyspark/ml/param/_gen_shared_params.py rename to python/pyspark/ml/param/_shared_params_code_gen.py index 5eb81106f116c..55f422497672f 100644 --- a/python/pyspark/ml/param/_gen_shared_params.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -32,29 +32,34 @@ # limitations under the License. #""" +# Code generator for shared params (shared.py). Run under this folder with: +# python _shared_params_code_gen.py > shared.py -def _gen_param_code(name, doc, defaultValue): + +def _gen_param_code(name, doc, defaultValueStr): """ Generates Python code for a shared param class. :param name: param name :param doc: param doc - :param defaultValue: string representation of the param + :param defaultValueStr: string representation of the default value :return: code string """ # TODO: How to correctly inherit instance attributes? template = '''class Has$Name(Params): """ - Params with $name. + Mixin for param $name: $doc. 
""" # a placeholder to make it appear in the generated doc - $name = Param(Params._dummy(), "$name", "$doc", $defaultValue) + $name = Param(Params._dummy(), "$name", "$doc") def __init__(self): super(Has$Name, self).__init__() #: param for $doc - self.$name = Param(self, "$name", "$doc", $defaultValue) + self.$name = Param(self, "$name", "$doc") + if $defaultValueStr is not None: + self._setDefault($name=$defaultValueStr) def set$Name(self, value): """ @@ -67,32 +72,29 @@ def get$Name(self): """ Gets the value of $name or its default value. """ - if self.$name in self.paramMap: - return self.paramMap[self.$name] - else: - return self.$name.defaultValue''' + return self.getOrDefault(self.$name)''' - upperCamelName = name[0].upper() + name[1:] + Name = name[0].upper() + name[1:] return template \ .replace("$name", name) \ - .replace("$Name", upperCamelName) \ + .replace("$Name", Name) \ .replace("$doc", doc) \ - .replace("$defaultValue", defaultValue) + .replace("$defaultValueStr", str(defaultValueStr)) if __name__ == "__main__": print header - print "\n# DO NOT MODIFY. The code is generated by _gen_shared_params.py.\n" + print "\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n" print "from pyspark.ml.param import Param, Params\n\n" shared = [ - ("maxIter", "max number of iterations", "100"), - ("regParam", "regularization constant", "0.1"), + ("maxIter", "max number of iterations", None), + ("regParam", "regularization constant", None), ("featuresCol", "features column name", "'features'"), ("labelCol", "label column name", "'label'"), ("predictionCol", "prediction column name", "'prediction'"), - ("inputCol", "input column name", "'input'"), - ("outputCol", "output column name", "'output'"), - ("numFeatures", "number of features", "1 << 18")] + ("inputCol", "input column name", None), + ("outputCol", "output column name", None), + ("numFeatures", "number of features", None)] code = [] - for name, doc, defaultValue in shared: - code.append(_gen_param_code(name, doc, defaultValue)) + for name, doc, defaultValueStr in shared: + code.append(_gen_param_code(name, doc, defaultValueStr)) print "\n\n\n".join(code) diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 586822f2de423..13b6749998ad0 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -15,23 +15,25 @@ # limitations under the License. # -# DO NOT MODIFY. The code is generated by _gen_shared_params.py. +# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py. from pyspark.ml.param import Param, Params class HasMaxIter(Params): """ - Params with maxIter. + Mixin for param maxIter: max number of iterations. """ # a placeholder to make it appear in the generated doc - maxIter = Param(Params._dummy(), "maxIter", "max number of iterations", 100) + maxIter = Param(Params._dummy(), "maxIter", "max number of iterations") def __init__(self): super(HasMaxIter, self).__init__() #: param for max number of iterations - self.maxIter = Param(self, "maxIter", "max number of iterations", 100) + self.maxIter = Param(self, "maxIter", "max number of iterations") + if None is not None: + self._setDefault(maxIter=None) def setMaxIter(self, value): """ @@ -44,24 +46,23 @@ def getMaxIter(self): """ Gets the value of maxIter or its default value. 
""" - if self.maxIter in self.paramMap: - return self.paramMap[self.maxIter] - else: - return self.maxIter.defaultValue + return self.getOrDefault(self.maxIter) class HasRegParam(Params): """ - Params with regParam. + Mixin for param regParam: regularization constant. """ # a placeholder to make it appear in the generated doc - regParam = Param(Params._dummy(), "regParam", "regularization constant", 0.1) + regParam = Param(Params._dummy(), "regParam", "regularization constant") def __init__(self): super(HasRegParam, self).__init__() #: param for regularization constant - self.regParam = Param(self, "regParam", "regularization constant", 0.1) + self.regParam = Param(self, "regParam", "regularization constant") + if None is not None: + self._setDefault(regParam=None) def setRegParam(self, value): """ @@ -74,24 +75,23 @@ def getRegParam(self): """ Gets the value of regParam or its default value. """ - if self.regParam in self.paramMap: - return self.paramMap[self.regParam] - else: - return self.regParam.defaultValue + return self.getOrDefault(self.regParam) class HasFeaturesCol(Params): """ - Params with featuresCol. + Mixin for param featuresCol: features column name. """ # a placeholder to make it appear in the generated doc - featuresCol = Param(Params._dummy(), "featuresCol", "features column name", 'features') + featuresCol = Param(Params._dummy(), "featuresCol", "features column name") def __init__(self): super(HasFeaturesCol, self).__init__() #: param for features column name - self.featuresCol = Param(self, "featuresCol", "features column name", 'features') + self.featuresCol = Param(self, "featuresCol", "features column name") + if 'features' is not None: + self._setDefault(featuresCol='features') def setFeaturesCol(self, value): """ @@ -104,24 +104,23 @@ def getFeaturesCol(self): """ Gets the value of featuresCol or its default value. """ - if self.featuresCol in self.paramMap: - return self.paramMap[self.featuresCol] - else: - return self.featuresCol.defaultValue + return self.getOrDefault(self.featuresCol) class HasLabelCol(Params): """ - Params with labelCol. + Mixin for param labelCol: label column name. """ # a placeholder to make it appear in the generated doc - labelCol = Param(Params._dummy(), "labelCol", "label column name", 'label') + labelCol = Param(Params._dummy(), "labelCol", "label column name") def __init__(self): super(HasLabelCol, self).__init__() #: param for label column name - self.labelCol = Param(self, "labelCol", "label column name", 'label') + self.labelCol = Param(self, "labelCol", "label column name") + if 'label' is not None: + self._setDefault(labelCol='label') def setLabelCol(self, value): """ @@ -134,24 +133,23 @@ def getLabelCol(self): """ Gets the value of labelCol or its default value. """ - if self.labelCol in self.paramMap: - return self.paramMap[self.labelCol] - else: - return self.labelCol.defaultValue + return self.getOrDefault(self.labelCol) class HasPredictionCol(Params): """ - Params with predictionCol. + Mixin for param predictionCol: prediction column name. 
""" # a placeholder to make it appear in the generated doc - predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name", 'prediction') + predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name") def __init__(self): super(HasPredictionCol, self).__init__() #: param for prediction column name - self.predictionCol = Param(self, "predictionCol", "prediction column name", 'prediction') + self.predictionCol = Param(self, "predictionCol", "prediction column name") + if 'prediction' is not None: + self._setDefault(predictionCol='prediction') def setPredictionCol(self, value): """ @@ -164,24 +162,23 @@ def getPredictionCol(self): """ Gets the value of predictionCol or its default value. """ - if self.predictionCol in self.paramMap: - return self.paramMap[self.predictionCol] - else: - return self.predictionCol.defaultValue + return self.getOrDefault(self.predictionCol) class HasInputCol(Params): """ - Params with inputCol. + Mixin for param inputCol: input column name. """ # a placeholder to make it appear in the generated doc - inputCol = Param(Params._dummy(), "inputCol", "input column name", 'input') + inputCol = Param(Params._dummy(), "inputCol", "input column name") def __init__(self): super(HasInputCol, self).__init__() #: param for input column name - self.inputCol = Param(self, "inputCol", "input column name", 'input') + self.inputCol = Param(self, "inputCol", "input column name") + if None is not None: + self._setDefault(inputCol=None) def setInputCol(self, value): """ @@ -194,24 +191,23 @@ def getInputCol(self): """ Gets the value of inputCol or its default value. """ - if self.inputCol in self.paramMap: - return self.paramMap[self.inputCol] - else: - return self.inputCol.defaultValue + return self.getOrDefault(self.inputCol) class HasOutputCol(Params): """ - Params with outputCol. + Mixin for param outputCol: output column name. """ # a placeholder to make it appear in the generated doc - outputCol = Param(Params._dummy(), "outputCol", "output column name", 'output') + outputCol = Param(Params._dummy(), "outputCol", "output column name") def __init__(self): super(HasOutputCol, self).__init__() #: param for output column name - self.outputCol = Param(self, "outputCol", "output column name", 'output') + self.outputCol = Param(self, "outputCol", "output column name") + if None is not None: + self._setDefault(outputCol=None) def setOutputCol(self, value): """ @@ -224,24 +220,23 @@ def getOutputCol(self): """ Gets the value of outputCol or its default value. """ - if self.outputCol in self.paramMap: - return self.paramMap[self.outputCol] - else: - return self.outputCol.defaultValue + return self.getOrDefault(self.outputCol) class HasNumFeatures(Params): """ - Params with numFeatures. + Mixin for param numFeatures: number of features. """ # a placeholder to make it appear in the generated doc - numFeatures = Param(Params._dummy(), "numFeatures", "number of features", 1 << 18) + numFeatures = Param(Params._dummy(), "numFeatures", "number of features") def __init__(self): super(HasNumFeatures, self).__init__() #: param for number of features - self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18) + self.numFeatures = Param(self, "numFeatures", "number of features") + if None is not None: + self._setDefault(numFeatures=None) def setNumFeatures(self, value): """ @@ -254,7 +249,4 @@ def getNumFeatures(self): """ Gets the value of numFeatures or its default value. 
""" - if self.numFeatures in self.paramMap: - return self.paramMap[self.numFeatures] - else: - return self.numFeatures.defaultValue + return self.getOrDefault(self.numFeatures) diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 83880a5afcd1d..d94ecfff09f66 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -124,10 +124,10 @@ def setParams(self, stages=[]): Sets params for Pipeline. """ kwargs = self.setParams._input_kwargs - return self._set_params(**kwargs) + return self._set(**kwargs) def fit(self, dataset, params={}): - paramMap = self._merge_params(params) + paramMap = self.extractParamMap(params) stages = paramMap[self.stages] for stage in stages: if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)): @@ -164,7 +164,7 @@ def __init__(self, transformers): self.transformers = transformers def transform(self, dataset, params={}): - paramMap = self._merge_params(params) + paramMap = self.extractParamMap(params) for t in self.transformers: dataset = t.transform(dataset, paramMap) return dataset diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index b627c2b4e930b..3a42bcf723894 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -33,6 +33,7 @@ from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase from pyspark.sql import DataFrame from pyspark.ml.param import Param +from pyspark.ml.param.shared import HasMaxIter, HasInputCol from pyspark.ml.pipeline import Transformer, Estimator, Pipeline @@ -46,7 +47,7 @@ class MockTransformer(Transformer): def __init__(self): super(MockTransformer, self).__init__() - self.fake = Param(self, "fake", "fake", None) + self.fake = Param(self, "fake", "fake") self.dataset_index = None self.fake_param_value = None @@ -62,7 +63,7 @@ class MockEstimator(Estimator): def __init__(self): super(MockEstimator, self).__init__() - self.fake = Param(self, "fake", "fake", None) + self.fake = Param(self, "fake", "fake") self.dataset_index = None self.fake_param_value = None self.model = None @@ -111,5 +112,52 @@ def test_pipeline(self): self.assertEqual(6, dataset.index) +class TestParams(HasMaxIter, HasInputCol): + """ + A subclass of Params mixed with HasMaxIter and HasInputCol. 
+ """ + + def __init__(self): + super(TestParams, self).__init__() + self._setDefault(maxIter=10) + + +class ParamTests(PySparkTestCase): + + def test_param(self): + testParams = TestParams() + maxIter = testParams.maxIter + self.assertEqual(maxIter.name, "maxIter") + self.assertEqual(maxIter.doc, "max number of iterations") + self.assertTrue(maxIter.parent is testParams) + + def test_params(self): + testParams = TestParams() + maxIter = testParams.maxIter + inputCol = testParams.inputCol + + params = testParams.params + self.assertEqual(params, [inputCol, maxIter]) + + self.assertTrue(testParams.hasDefault(maxIter)) + self.assertFalse(testParams.isSet(maxIter)) + self.assertTrue(testParams.isDefined(maxIter)) + self.assertEqual(testParams.getMaxIter(), 10) + testParams.setMaxIter(100) + self.assertTrue(testParams.isSet(maxIter)) + self.assertEquals(testParams.getMaxIter(), 100) + + self.assertFalse(testParams.hasDefault(inputCol)) + self.assertFalse(testParams.isSet(inputCol)) + self.assertFalse(testParams.isDefined(inputCol)) + with self.assertRaises(KeyError): + testParams.getInputCol() + + self.assertEquals( + testParams.explainParams(), + "\n".join(["inputCol: input column name (undefined)", + "maxIter: max number of iterations (default: 10, current: 100)"])) + + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 6f7f39c40eb5a..d3cb100a9efa5 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -40,8 +40,8 @@ class Identifiable(object): def __init__(self): #: A unique id for the object. The default implementation - #: concatenates the class name, "-", and 8 random hex chars. - self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8] + #: concatenates the class name, "_", and 8 random hex chars. 
+ self.uid = type(self).__name__ + "_" + uuid.uuid4().hex[:8] def __repr__(self): return self.uid diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 31a66b3d2f730..394f23c5e9b12 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -64,7 +64,7 @@ def _transfer_params_to_java(self, params, java_obj): :param params: additional params (overwriting embedded values) :param java_obj: Java object to receive the params """ - paramMap = self._merge_params(params) + paramMap = self.extractParamMap(params) for param in self.params: if param in paramMap: java_obj.set(param.name, paramMap[param]) diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index 6449800d9c120..f2ef573fe9f6f 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -25,7 +25,7 @@ if numpy.version.version < '1.4': raise Exception("MLlib requires NumPy 1.4+") -__all__ = ['classification', 'clustering', 'feature', 'linalg', 'random', +__all__ = ['classification', 'clustering', 'feature', 'fpm', 'linalg', 'random', 'recommendation', 'regression', 'stat', 'tree', 'util'] import sys diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 6766f3ebb8894..2466e8ac43458 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -22,7 +22,7 @@ from pyspark import RDD from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py -from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper from pyspark.mllib.util import Saveable, Loader, inherit_doc @@ -31,13 +31,13 @@ 'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes'] -class LinearBinaryClassificationModel(LinearModel): +class LinearClassificationModel(LinearModel): """ - Represents a linear binary classification model that predicts to whether an - example is positive (1.0) or negative (0.0). + A private abstract class representing a multiclass classification model. + The categories are represented by int values: 0, 1, 2, etc. """ def __init__(self, weights, intercept): - super(LinearBinaryClassificationModel, self).__init__(weights, intercept) + super(LinearClassificationModel, self).__init__(weights, intercept) self._threshold = None def setThreshold(self, value): @@ -47,14 +47,26 @@ def setThreshold(self, value): Sets the threshold that separates positive predictions from negative predictions. An example with prediction score greater than or equal to this threshold is identified as an positive, and negative otherwise. + It is used for binary classification only. """ self._threshold = value + @property + def threshold(self): + """ + .. note:: Experimental + + Returns the threshold (if any) used for converting raw prediction scores + into 0/1 predictions. It is used for binary classification only. + """ + return self._threshold + def clearThreshold(self): """ .. note:: Experimental Clears the threshold so that `predict` will output raw prediction scores. + It is used for binary classification only. """ self._threshold = None @@ -66,7 +78,7 @@ def predict(self, test): raise NotImplementedError -class LogisticRegressionModel(LinearBinaryClassificationModel): +class LogisticRegressionModel(LinearClassificationModel): """A linear binary classification model derived from logistic regression. 
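The `LinearClassificationModel` refactor above exposes the binary-classification threshold as a read-only `threshold` property alongside the existing `setThreshold`/`clearThreshold` methods. A short usage sketch with the patched pyspark on the path; the tiny training set here is only for illustration:

```
from pyspark import SparkContext
from pyspark.mllib.classification import LogisticRegressionWithSGD
from pyspark.mllib.regression import LabeledPoint

sc = SparkContext("local[2]", "threshold example")
data = sc.parallelize([LabeledPoint(0.0, [0.0, 1.0]), LabeledPoint(1.0, [1.0, 0.0])])
model = LogisticRegressionWithSGD.train(data, iterations=10)

model.setThreshold(0.7)           # scores above 0.7 map to class 1
print(model.threshold)            # 0.7, via the new read-only property
print(model.predict([1.0, 0.0]))  # 0 or 1, decided against the threshold

model.clearThreshold()            # binary models now return the raw score
print(model.predict([1.0, 0.0]))  # e.g. a probability around 0.8

sc.stop()
```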
@@ -112,10 +124,39 @@ class LogisticRegressionModel(LinearBinaryClassificationModel): ... os.removedirs(path) ... except: ... pass + >>> multi_class_data = [ + ... LabeledPoint(0.0, [0.0, 1.0, 0.0]), + ... LabeledPoint(1.0, [1.0, 0.0, 0.0]), + ... LabeledPoint(2.0, [0.0, 0.0, 1.0]) + ... ] + >>> mcm = LogisticRegressionWithLBFGS.train(data=sc.parallelize(multi_class_data), numClasses=3) + >>> mcm.predict([0.0, 0.5, 0.0]) + 0 + >>> mcm.predict([0.8, 0.0, 0.0]) + 1 + >>> mcm.predict([0.0, 0.0, 0.3]) + 2 """ - def __init__(self, weights, intercept): + def __init__(self, weights, intercept, numFeatures, numClasses): super(LogisticRegressionModel, self).__init__(weights, intercept) + self._numFeatures = int(numFeatures) + self._numClasses = int(numClasses) self._threshold = 0.5 + if self._numClasses == 2: + self._dataWithBiasSize = None + self._weightsMatrix = None + else: + self._dataWithBiasSize = self._coeff.size / (self._numClasses - 1) + self._weightsMatrix = self._coeff.toArray().reshape(self._numClasses - 1, + self._dataWithBiasSize) + + @property + def numFeatures(self): + return self._numFeatures + + @property + def numClasses(self): + return self._numClasses def predict(self, x): """ @@ -126,20 +167,38 @@ def predict(self, x): return x.map(lambda v: self.predict(v)) x = _convert_to_vector(x) - margin = self.weights.dot(x) + self._intercept - if margin > 0: - prob = 1 / (1 + exp(-margin)) + if self.numClasses == 2: + margin = self.weights.dot(x) + self._intercept + if margin > 0: + prob = 1 / (1 + exp(-margin)) + else: + exp_margin = exp(margin) + prob = exp_margin / (1 + exp_margin) + if self._threshold is None: + return prob + else: + return 1 if prob > self._threshold else 0 else: - exp_margin = exp(margin) - prob = exp_margin / (1 + exp_margin) - if self._threshold is None: - return prob - else: - return 1 if prob > self._threshold else 0 + best_class = 0 + max_margin = 0.0 + if x.size + 1 == self._dataWithBiasSize: + for i in range(0, self._numClasses - 1): + margin = x.dot(self._weightsMatrix[i][0:x.size]) + \ + self._weightsMatrix[i][x.size] + if margin > max_margin: + max_margin = margin + best_class = i + 1 + else: + for i in range(0, self._numClasses - 1): + margin = x.dot(self._weightsMatrix[i]) + if margin > max_margin: + max_margin = margin + best_class = i + 1 + return best_class def save(self, sc, path): java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel( - _py2java(sc, self._coeff), self.intercept) + _py2java(sc, self._coeff), self.intercept, self.numFeatures, self.numClasses) java_model.save(sc._jsc.sc(), path) @classmethod @@ -148,8 +207,10 @@ def load(cls, sc, path): sc._jsc.sc(), path) weights = _java2py(sc, java_model.weights()) intercept = java_model.intercept() + numFeatures = java_model.numFeatures() + numClasses = java_model.numClasses() threshold = java_model.getThreshold().get() - model = LogisticRegressionModel(weights, intercept) + model = LogisticRegressionModel(weights, intercept, numFeatures, numClasses) model.setThreshold(threshold) return model @@ -158,7 +219,8 @@ class LogisticRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, - initialWeights=None, regParam=0.01, regType="l2", intercept=False): + initialWeights=None, regParam=0.01, regType="l2", intercept=False, + validateData=True): """ Train a logistic regression model on the given data. 
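For the multinomial case above, the flat coefficient vector is reshaped into a `(numClasses - 1) x dataWithBiasSize` matrix and the prediction is class 0 unless some row's margin exceeds the running maximum, which starts at 0. A standalone NumPy sketch of that decision rule with made-up weights; it ignores the intercept/bias column the real code also handles:

```
import numpy as np

def predict_multiclass(x, weights_matrix):
    """Pick the class with the largest positive margin; class 0 is the baseline."""
    best_class, max_margin = 0, 0.0
    for i, row in enumerate(weights_matrix):
        margin = float(np.dot(row, x))
        if margin > max_margin:
            max_margin = margin
            best_class = i + 1   # row i scores class i + 1
    return best_class

# Hypothetical 3-class model without intercept: 2 weight rows over 3 features.
W = np.array([[2.0, -1.0, 0.0],    # margin for class 1
              [-1.0, 0.0, 2.0]])   # margin for class 2
print(predict_multiclass(np.array([0.0, 1.0, 0.0]), W))  # 0 (no positive margin)
print(predict_multiclass(np.array([1.0, 0.0, 0.0]), W))  # 1
print(predict_multiclass(np.array([0.0, 0.0, 1.0]), W))  # 2
```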
@@ -184,11 +246,14 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, or not of the augmented representation for training data (i.e. whether bias features are activated or not). + :param validateData: Boolean parameter which indicates if the + algorithm should validate data before training. + (default: True) """ def train(rdd, i): return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations), float(step), float(miniBatchFraction), i, float(regParam), regType, - bool(intercept)) + bool(intercept), bool(validateData)) return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) @@ -197,7 +262,7 @@ class LogisticRegressionWithLBFGS(object): @classmethod def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2", - intercept=False, corrections=10, tolerance=1e-4): + intercept=False, corrections=10, tolerance=1e-4, validateData=True, numClasses=2): """ Train a logistic regression model on the given data. @@ -223,6 +288,11 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType update (default: 10). :param tolerance: The convergence tolerance of iterations for L-BFGS (default: 1e-4). + :param validateData: Boolean parameter which indicates if the + algorithm should validate data before training. + (default: True) + :param numClasses: The number of classes (i.e., outcomes) a label can take + in Multinomial Logistic Regression (default: 2). >>> data = [ ... LabeledPoint(0.0, [0.0, 1.0]), @@ -237,12 +307,20 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType def train(rdd, i): return callMLlibFunc("trainLogisticRegressionModelWithLBFGS", rdd, int(iterations), i, float(regParam), regType, bool(intercept), int(corrections), - float(tolerance)) - + float(tolerance), bool(validateData), int(numClasses)) + + if initialWeights is None: + if numClasses == 2: + initialWeights = [0.0] * len(data.first().features) + else: + if intercept: + initialWeights = [0.0] * (len(data.first().features) + 1) * (numClasses - 1) + else: + initialWeights = [0.0] * len(data.first().features) * (numClasses - 1) return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) -class SVMModel(LinearBinaryClassificationModel): +class SVMModel(LinearClassificationModel): """A support vector machine. @@ -325,7 +403,8 @@ class SVMWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, - miniBatchFraction=1.0, initialWeights=None, regType="l2", intercept=False): + miniBatchFraction=1.0, initialWeights=None, regType="l2", + intercept=False, validateData=True): """ Train a support vector machine on the given data. @@ -351,11 +430,14 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, or not of the augmented representation for training data (i.e. whether bias features are activated or not). + :param validateData: Boolean parameter which indicates if the + algorithm should validate data before training. 
+ (default: True) """ def train(rdd, i): return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step), float(regParam), float(miniBatchFraction), i, regType, - bool(intercept)) + bool(intercept), bool(validateData)) return _regression_train_wrapper(train, SVMModel, data, initialWeights) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 0ffe092a07365..8be819aceec24 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -132,6 +132,22 @@ def transform(self, vector): """ return JavaVectorTransformer.transform(self, vector) + def setWithMean(self, withMean): + """ + Setter of the boolean which decides + whether it uses mean or not + """ + self.call("setWithMean", withMean) + return self + + def setWithStd(self, withStd): + """ + Setter of the boolean which decides + whether it uses std or not + """ + self.call("setWithStd", withStd) + return self + class StandardScaler(object): """ @@ -244,6 +260,12 @@ def transform(self, x): x = _convert_to_vector(x) return JavaVectorTransformer.transform(self, x) + def idf(self): + """ + Returns the current IDF vector. + """ + return self.call('idf') + class IDF(object): """ @@ -331,6 +353,12 @@ def findSynonyms(self, word, num): words, similarity = self.call("findSynonyms", word, num) return zip(words, similarity) + def getVectors(self): + """ + Returns a map of words to their vector representations. + """ + return self.call("getVectors") + class Word2Vec(object): """ @@ -373,6 +401,7 @@ def __init__(self): self.numPartitions = 1 self.numIterations = 1 self.seed = random.randint(0, sys.maxint) + self.minCount = 5 def setVectorSize(self, vectorSize): """ @@ -411,6 +440,14 @@ def setSeed(self, seed): self.seed = seed return self + def setMinCount(self, minCount): + """ + Sets minCount, the minimum number of times a token must appear + to be included in the word2vec model's vocabulary (default: 5). + """ + self.minCount = minCount + return self + def fit(self, data): """ Computes the vector representation of each word in vocabulary. @@ -422,7 +459,8 @@ def fit(self, data): raise TypeError("data should be an RDD of list of string") jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize), float(self.learningRate), int(self.numPartitions), - int(self.numIterations), long(self.seed)) + int(self.numIterations), long(self.seed), + int(self.minCount)) return Word2VecModel(jmodel) diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py new file mode 100644 index 0000000000000..3aa6d79d7093c --- /dev/null +++ b/python/pyspark/mllib/fpm.py @@ -0,0 +1,81 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from pyspark import SparkContext +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc + +__all__ = ['FPGrowth', 'FPGrowthModel'] + + +@inherit_doc +class FPGrowthModel(JavaModelWrapper): + + """ + .. note:: Experimental + + A FP-Growth model for mining frequent itemsets + using the Parallel FP-Growth algorithm. + + >>> data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] + >>> rdd = sc.parallelize(data, 2) + >>> model = FPGrowth.train(rdd, 0.6, 2) + >>> sorted(model.freqItemsets().collect()) + [([u'a'], 4), ([u'c'], 3), ([u'c', u'a'], 3)] + """ + + def freqItemsets(self): + """ + Get the frequent itemsets of this model + """ + return self.call("getFreqItemsets") + + +class FPGrowth(object): + """ + .. note:: Experimental + + A Parallel FP-growth algorithm to mine frequent itemsets. + """ + + @classmethod + def train(cls, data, minSupport=0.3, numPartitions=-1): + """ + Computes an FP-Growth model that contains frequent itemsets. + :param data: The input data set, each element + contains a transaction. + :param minSupport: The minimal support level + (default: `0.3`). + :param numPartitions: The number of partitions used by parallel + FP-growth (default: same as input data). + """ + model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartitions)) + return FPGrowthModel(model) + + +def _test(): + import doctest + import pyspark.mllib.fpm + globs = pyspark.mllib.fpm.__dict__.copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest') + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index f5aad28afda0f..a80320c52d1d0 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -173,7 +173,24 @@ def toArray(self): class DenseVector(Vector): """ - A dense vector represented by a value array. + A dense vector represented by a value array. We use numpy array for + storage and arithmetics will be delegated to the underlying numpy + array. + + >>> v = Vectors.dense([1.0, 2.0]) + >>> u = Vectors.dense([3.0, 4.0]) + >>> v + u + DenseVector([4.0, 6.0]) + >>> 2 - v + DenseVector([1.0, 0.0]) + >>> v / 2 + DenseVector([0.5, 1.0]) + >>> v * u + DenseVector([3.0, 8.0]) + >>> u / v + DenseVector([3.0, 2.0]) + >>> u % 2 + DenseVector([1.0, 0.0]) """ def __init__(self, ar): if isinstance(ar, basestring): @@ -292,6 +309,25 @@ def __ne__(self, other): def __getattr__(self, item): return getattr(self.array, item) + def _delegate(op): + def func(self, other): + if isinstance(other, DenseVector): + other = other.array + return DenseVector(getattr(self.array, op)(other)) + return func + + __neg__ = _delegate("__neg__") + __add__ = _delegate("__add__") + __sub__ = _delegate("__sub__") + __mul__ = _delegate("__mul__") + __div__ = _delegate("__div__") + __mod__ = _delegate("__mod__") + __radd__ = _delegate("__radd__") + __rsub__ = _delegate("__rsub__") + __rmul__ = _delegate("__rmul__") + __rdiv__ = _delegate("__rdiv__") + __rmod__ = _delegate("__rmod__") + class SparseVector(Vector): """ @@ -604,6 +640,15 @@ def toArray(self): """ raise NotImplementedError + @staticmethod + def _convert_to_array(array_like, dtype): + """ + Convert Matrix attributes which are array-like or buffer to array. 
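# --- A toy sketch of the operator-delegation pattern used for DenseVector above: each
# dunder is generated once by `_delegate`, unwraps a vector operand to its numpy array,
# and re-wraps the numpy result. `ToyVector` is a stand-in for illustration, not the
# real DenseVector.
import numpy as np

class ToyVector(object):
    def __init__(self, values):
        self.array = np.asarray(values, dtype=np.float64)

    def _delegate(op):
        def func(self, other):
            if isinstance(other, ToyVector):
                other = other.array
            return ToyVector(getattr(self.array, op)(other))
        return func

    __add__ = _delegate("__add__")
    __mul__ = _delegate("__mul__")

v = ToyVector([1.0, 2.0]) + ToyVector([3.0, 4.0])
print v.array   # [ 4.  6.]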
+ """ + if isinstance(array_like, basestring): + return np.frombuffer(array_like, dtype=dtype) + return np.asarray(array_like, dtype=dtype) + class DenseMatrix(Matrix): """ @@ -611,13 +656,8 @@ class DenseMatrix(Matrix): """ def __init__(self, numRows, numCols, values): Matrix.__init__(self, numRows, numCols) - if isinstance(values, basestring): - values = np.frombuffer(values, dtype=np.float64) - elif not isinstance(values, np.ndarray): - values = np.array(values, dtype=np.float64) + values = self._convert_to_array(values, np.float64) assert len(values) == numRows * numCols - if values.dtype != np.float64: - values.astype(np.float64) self.values = values def __reduce__(self): @@ -634,6 +674,27 @@ def toArray(self): """ return self.values.reshape((self.numRows, self.numCols), order='F') + def toSparse(self): + """Convert to SparseMatrix""" + indices = np.nonzero(self.values)[0] + colCounts = np.bincount(indices / self.numRows) + colPtrs = np.cumsum(np.hstack( + (0, colCounts, np.zeros(self.numCols - colCounts.size)))) + values = self.values[indices] + rowIndices = indices % self.numRows + + return SparseMatrix(self.numRows, self.numCols, colPtrs, rowIndices, values) + + def __getitem__(self, indices): + i, j = indices + if i < 0 or i >= self.numRows: + raise ValueError("Row index %d is out of range [0, %d)" + % (i, self.numRows)) + if j >= self.numCols or j < 0: + raise ValueError("Column index %d is out of range [0, %d)" + % (j, self.numCols)) + return self.values[i + j * self.numRows] + def __eq__(self, other): return (isinstance(other, DenseMatrix) and self.numRows == other.numRows and @@ -641,6 +702,82 @@ def __eq__(self, other): all(self.values == other.values)) +class SparseMatrix(Matrix): + """Sparse Matrix stored in CSC format.""" + def __init__(self, numRows, numCols, colPtrs, rowIndices, values, + isTransposed=False): + Matrix.__init__(self, numRows, numCols) + self.isTransposed = isTransposed + self.colPtrs = self._convert_to_array(colPtrs, np.int32) + self.rowIndices = self._convert_to_array(rowIndices, np.int32) + self.values = self._convert_to_array(values, np.float64) + + if self.isTransposed: + if self.colPtrs.size != numRows + 1: + raise ValueError("Expected colPtrs of size %d, got %d." + % (numRows + 1, self.colPtrs.size)) + else: + if self.colPtrs.size != numCols + 1: + raise ValueError("Expected colPtrs of size %d, got %d." + % (numCols + 1, self.colPtrs.size)) + if self.rowIndices.size != self.values.size: + raise ValueError("Expected rowIndices of length %d, got %d." + % (self.rowIndices.size, self.values.size)) + + def __reduce__(self): + return SparseMatrix, ( + self.numRows, self.numCols, self.colPtrs.tostring(), + self.rowIndices.tostring(), self.values.tostring(), + self.isTransposed) + + def __getitem__(self, indices): + i, j = indices + if i < 0 or i >= self.numRows: + raise ValueError("Row index %d is out of range [0, %d)" + % (i, self.numRows)) + if j < 0 or j >= self.numCols: + raise ValueError("Column index %d is out of range [0, %d)" + % (j, self.numCols)) + + # If a CSR matrix is given, then the row index should be searched + # for in ColPtrs, and the column index should be searched for in the + # corresponding slice obtained from rowIndices. 
+ if self.isTransposed: + j, i = i, j + + colStart = self.colPtrs[j] + colEnd = self.colPtrs[j + 1] + nz = self.rowIndices[colStart: colEnd] + ind = np.searchsorted(nz, i) + colStart + if ind < colEnd and self.rowIndices[ind] == i: + return self.values[ind] + else: + return 0.0 + + def toArray(self): + """ + Return an numpy.ndarray + """ + A = np.zeros((self.numRows, self.numCols), dtype=np.float64, order='F') + for k in xrange(self.colPtrs.size - 1): + startptr = self.colPtrs[k] + endptr = self.colPtrs[k + 1] + if self.isTransposed: + A[k, self.rowIndices[startptr:endptr]] = self.values[startptr:endptr] + else: + A[self.rowIndices[startptr:endptr], k] = self.values[startptr:endptr] + return A + + def toDense(self): + densevals = np.reshape( + self.toArray(), (self.numRows * self.numCols), order='F') + return DenseMatrix(self.numRows, self.numCols, densevals) + + # TODO: More efficient implementation: + def __eq__(self, other): + return np.all(self.toArray == other.toArray) + + class Matrices(object): @staticmethod def dense(numRows, numCols, values): @@ -649,6 +786,13 @@ def dense(numRows, numCols, values): """ return DenseMatrix(numRows, numCols, values) + @staticmethod + def sparse(numRows, numCols, colPtrs, rowIndices, values): + """ + Create a SparseMatrix + """ + return SparseMatrix(numRows, numCols, colPtrs, rowIndices, values) + def _test(): import doctest diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 1a4527b12cef2..c5c4c13dae105 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -52,7 +52,7 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> ratings = sc.parallelize([r1, r2, r3]) >>> model = ALS.trainImplicit(ratings, 1, seed=10) >>> model.predict(2, 2) - 0.43... + 0.4... >>> testset = sc.parallelize([(1, 2), (1, 1)]) >>> model = ALS.train(ratings, 2, seed=0) @@ -82,14 +82,16 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10) >>> model.predict(2,2) - 0.43... + 0.4... >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> model.save(sc, path) >>> sameModel = MatrixFactorizationModel.load(sc, path) >>> sameModel.predict(2,2) - 0.43... + 0.4... + >>> sameModel.predictAll(testset).collect() + [Rating(... >>> try: ... os.removedirs(path) ... except OSError: @@ -111,6 +113,12 @@ def userFeatures(self): def productFeatures(self): return self.call("getProductFeatures") + @classmethod + def load(cls, sc, path): + model = cls._load_java(sc, path) + wrapper = sc._jvm.MatrixFactorizationModelWrapper(model) + return MatrixFactorizationModel(wrapper) + class ALS(object): diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 209f1ee473b5b..cd7310a64f4ae 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -167,13 +167,19 @@ def load(cls, sc, path): # return the result of a call to the appropriate JVM stub. # _regression_train_wrapper is responsible for setup and error checking. 
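# --- A brief round-trip sketch for the converters and factory added above
# (Matrices.sparse, SparseMatrix.toDense, DenseMatrix.toSparse); the shape and values
# are illustrative only.
from pyspark.mllib.linalg import Matrices

sp = Matrices.sparse(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0])
dm = sp.toDense()     # DenseMatrix holding the same 3x4 contents (column-major values)
sp2 = dm.toSparse()   # back to CSC; colPtrs/rowIndices/values round-trip
assert (sp.toArray() == sp2.toArray()).all()   # compare the materialized arrays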
def _regression_train_wrapper(train_func, modelClass, data, initial_weights): + from pyspark.mllib.classification import LogisticRegressionModel first = data.first() if not isinstance(first, LabeledPoint): raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first) if initial_weights is None: initial_weights = [0.0] * len(data.first().features) - weights, intercept = train_func(data, _convert_to_vector(initial_weights)) - return modelClass(weights, intercept) + if (modelClass == LogisticRegressionModel): + weights, intercept, numFeatures, numClasses = train_func( + data, _convert_to_vector(initial_weights)) + return modelClass(weights, intercept, numFeatures, numClasses) + else: + weights, intercept = train_func(data, _convert_to_vector(initial_weights)) + return modelClass(weights, intercept) class LinearRegressionWithSGD(object): diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 218ac148ca992..1d83e9d483f8e 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -49,6 +49,12 @@ def max(self): def min(self): return self.call("min").toArray() + def normL1(self): + return self.call("normL1").toArray() + + def normL2(self): + return self.call("normL2").toArray() + class Statistics(object): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 155019638f806..8eaddcf8b9b5e 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -24,7 +24,7 @@ import tempfile import array as pyarray -from numpy import array, array_equal +from numpy import array, array_equal, zeros from py4j.protocol import Py4JJavaError if sys.version_info[:2] <= (2, 6): @@ -36,11 +36,15 @@ else: import unittest +from pyspark.mllib.common import _to_java_object_rdd from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ - DenseMatrix, Vectors, Matrices + DenseMatrix, SparseMatrix, Vectors, Matrices from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics +from pyspark.mllib.feature import Word2Vec +from pyspark.mllib.feature import IDF +from pyspark.mllib.feature import StandardScaler from pyspark.serializers import PickleSerializer from pyspark.sql import SQLContext from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase @@ -134,6 +138,61 @@ def test_sparse_vector_indexing(self): for ind in [4, -5, 7.8]: self.assertRaises(ValueError, sv.__getitem__, ind) + def test_matrix_indexing(self): + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) + expected = [[0, 6], [1, 8], [4, 10]] + for i in range(3): + for j in range(2): + self.assertEquals(mat[i, j], expected[i][j]) + + def test_sparse_matrix(self): + # Test sparse matrix creation. + sm1 = SparseMatrix( + 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) + self.assertEquals(sm1.numRows, 3) + self.assertEquals(sm1.numCols, 4) + self.assertEquals(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) + self.assertEquals(sm1.rowIndices.tolist(), [1, 2, 1, 2]) + self.assertEquals(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) + + # Test indexing + expected = [ + [0, 0, 0, 0], + [1, 0, 4, 0], + [2, 0, 5, 0]] + + for i in range(3): + for j in range(4): + self.assertEquals(expected[i][j], sm1[i, j]) + self.assertTrue(array_equal(sm1.toArray(), expected)) + + # Test conversion to dense and sparse. 
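# --- A quick sketch of the new column-summary norms (normL1 / normL2) exposed above.
# It assumes a live SparkContext `sc`; the two toy rows are illustrative only.
from pyspark.mllib.stat import Statistics
from pyspark.mllib.linalg import Vectors

rows = sc.parallelize([Vectors.dense([1.0, -2.0]), Vectors.dense([3.0, 4.0])])
summary = Statistics.colStats(rows)
summary.normL1()   # column-wise sum of absolute values, here approximately [4., 6.]
summary.normL2()   # column-wise Euclidean norm, here approximately [3.16, 4.47]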
+ smnew = sm1.toDense().toSparse() + self.assertEquals(sm1.numRows, smnew.numRows) + self.assertEquals(sm1.numCols, smnew.numCols) + self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) + self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) + self.assertTrue(array_equal(sm1.values, smnew.values)) + + sm1t = SparseMatrix( + 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], + isTransposed=True) + self.assertEquals(sm1t.numRows, 3) + self.assertEquals(sm1t.numCols, 4) + self.assertEquals(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) + self.assertEquals(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) + self.assertEquals(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) + + expected = [ + [3, 2, 0, 0], + [0, 0, 4, 0], + [9, 0, 8, 0]] + + for i in range(3): + for j in range(4): + self.assertEquals(expected[i][j], sm1t[i, j]) + self.assertTrue(array_equal(sm1t.toArray(), expected)) + class ListTests(PySparkTestCase): @@ -347,6 +406,19 @@ def test_col_with_different_rdds(self): summary = Statistics.colStats(data) self.assertEqual(10, summary.count()) + def test_col_norms(self): + data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) + summary = Statistics.colStats(data) + self.assertEqual(10, len(summary.normL1())) + self.assertEqual(10, len(summary.normL2())) + + data2 = self.sc.parallelize(xrange(10)).map(lambda x: Vectors.dense(x)) + summary2 = Statistics.colStats(data2) + self.assertEqual(array([45.0]), summary2.normL1()) + import math + expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, xrange(10)))) + self.assertTrue(math.fabs(summary2.normL2()[0] - expectedNormL2) < 1e-14) + class VectorUDTTests(PySparkTestCase): @@ -620,6 +692,83 @@ def test_right_number_of_results(self): self.assertEqual(len(chi), num_cols) self.assertIsNotNone(chi[1000]) + +class SerDeTest(PySparkTestCase): + def test_to_java_object_rdd(self): # SPARK-6660 + data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L) + self.assertEqual(_to_java_object_rdd(data).count(), 10) + + +class FeatureTest(PySparkTestCase): + def test_idf_model(self): + data = [ + Vectors.dense([1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3]), + Vectors.dense([1, 3, 0, 1, 3, 0, 0, 2, 0, 0, 1]), + Vectors.dense([1, 4, 1, 0, 0, 4, 9, 0, 1, 2, 0]), + Vectors.dense([2, 1, 0, 3, 0, 0, 5, 0, 2, 3, 9]) + ] + model = IDF().fit(self.sc.parallelize(data, 2)) + idf = model.idf() + self.assertEqual(len(idf), 11) + + +class Word2VecTests(PySparkTestCase): + def test_word2vec_setters(self): + data = [ + ["I", "have", "a", "pen"], + ["I", "like", "soccer", "very", "much"], + ["I", "live", "in", "Tokyo"] + ] + model = Word2Vec() \ + .setVectorSize(2) \ + .setLearningRate(0.01) \ + .setNumPartitions(2) \ + .setNumIterations(10) \ + .setSeed(1024) \ + .setMinCount(3) + self.assertEquals(model.vectorSize, 2) + self.assertTrue(model.learningRate < 0.02) + self.assertEquals(model.numPartitions, 2) + self.assertEquals(model.numIterations, 10) + self.assertEquals(model.seed, 1024) + self.assertEquals(model.minCount, 3) + + def test_word2vec_get_vectors(self): + data = [ + ["a", "b", "c", "d", "e", "f", "g"], + ["a", "b", "c", "d", "e", "f"], + ["a", "b", "c", "d", "e"], + ["a", "b", "c", "d"], + ["a", "b", "c"], + ["a", "b"], + ["a"] + ] + model = Word2Vec().fit(self.sc.parallelize(data)) + self.assertEquals(len(model.getVectors()), 3) + + +class StandardScalerTests(PySparkTestCase): + def test_model_setters(self): + data = [ + [1.0, 2.0, 3.0], + [2.0, 3.0, 4.0], + [3.0, 4.0, 5.0] + ] + model = StandardScaler().fit(self.sc.parallelize(data)) + 
self.assertIsNotNone(model.setWithMean(True)) + self.assertIsNotNone(model.setWithStd(True)) + self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([-1.0, -1.0, -1.0])) + + def test_model_transform(self): + data = [ + [1.0, 2.0, 3.0], + [2.0, 3.0, 4.0], + [3.0, 4.0, 5.0] + ] + model = StandardScaler().fit(self.sc.parallelize(data)) + self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0])) + + if __name__ == "__main__": if not _have_scipy: print "NOTE: Skipping SciPy tests as it does not seem to be installed" diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index bf288d76447bd..a7a4d2aaf855b 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -286,21 +286,18 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, :param numTrees: Number of trees in the random forest. :param featureSubsetStrategy: Number of features to consider for splits at each node. - Supported: "auto" (default), "all", "sqrt", "log2", - "onethird". - If "auto" is set, this parameter is set based on - numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "sqrt". - :param impurity: Criterion used for information gain - calculation. + Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "sqrt". + :param impurity: Criterion used for information gain calculation. Supported values: "gini" (recommended) or "entropy". :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (default: 4) :param maxBins: maximum number of bins used for splitting features - (default: 100) + (default: 100) :param seed: Random seed for bootstrapping and choosing feature subsets. :return: RandomForestModel that can be used for prediction @@ -365,13 +362,10 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt :param numTrees: Number of trees in the random forest. :param featureSubsetStrategy: Number of features to consider for splits at each node. - Supported: "auto" (default), "all", "sqrt", "log2", - "onethird". - If "auto" is set, this parameter is set based on - numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "onethird" for - regression. + Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "onethird" for regression. :param impurity: Criterion used for information gain calculation. Supported values: "variance". 
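# --- A minimal sketch of the featureSubsetStrategy behaviour documented in the
# RandomForest docstrings above. It assumes a live SparkContext `sc`; the tiny
# single-feature dataset and parameter values are illustrative only.
from pyspark.mllib.tree import RandomForest
from pyspark.mllib.regression import LabeledPoint

data = sc.parallelize([LabeledPoint(0.0, [0.0]), LabeledPoint(1.0, [1.0])] * 50)
# With "auto" and numTrees > 1, classification falls back to "sqrt" feature subsets;
# regression would instead use "onethird".
model = RandomForest.trainClassifier(data, numClasses=2, categoricalFeaturesInfo={},
                                     numTrees=3, featureSubsetStrategy="auto", seed=42)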
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index c337a43c8a7fc..93e658eded9e2 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -41,7 +41,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ - get_used_memory, ExternalSorter + get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync from py4j.java_collections import ListConverter, MapConverter @@ -113,6 +113,7 @@ def _parse_memory(s): def _load_from_socket(port, serializer): sock = socket.socket() + sock.settimeout(3) try: sock.connect(("localhost", port)) rf = sock.makefile("rb", 65536) @@ -572,8 +573,8 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): if numPartitions is None: numPartitions = self._defaultReducePartitions() - spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true') - memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) + spill = self._can_spill() + memory = self._memory_limit() serializer = self._jrdd_deserializer def sortPartition(iterator): @@ -594,7 +595,7 @@ def sortPartition(iterator): maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner fraction = min(maxSampleSize / max(rddSize, 1), 1.0) samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect() - samples = sorted(samples, reverse=(not ascending), key=keyfunc) + samples = sorted(samples, key=keyfunc) # we have numPartitions many parts but one of the them has # an implicit boundary @@ -1196,7 +1197,7 @@ def take(self, num): [91, 92, 93] """ items = [] - totalParts = self._jrdd.partitions().size() + totalParts = self.getNumPartitions() partsScanned = 0 while len(items) < num and partsScanned < totalParts: @@ -1259,7 +1260,7 @@ def isEmpty(self): >>> sc.parallelize([1]).isEmpty() False """ - return self._jrdd.partitions().size() == 0 or len(self.take(1)) == 0 + return self.getNumPartitions() == 0 or len(self.take(1)) == 0 def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None): """ @@ -1698,10 +1699,8 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, numPartitions = self._defaultReducePartitions() serializer = self.ctx.serializer - spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() - == 'true') - memory = _parse_memory(self.ctx._conf.get( - "spark.python.worker.memory", "512m")) + spill = self._can_spill() + memory = self._memory_limit() agg = Aggregator(createCombiner, mergeValue, mergeCombiners) def combineLocally(iterator): @@ -1754,21 +1753,28 @@ def createZero(): return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions) + def _can_spill(self): + return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true" + + def _memory_limit(self): + return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) + # TODO: support variant with custom partitioner def groupByKey(self, numPartitions=None): """ Group the values for each key in the RDD into a single sequence. - Hash-partitions the resulting RDD with into numPartitions partitions. + Hash-partitions the resulting RDD with numPartitions partitions. Note: If you are grouping in order to perform an aggregation (such as a sum or average) over each key, using reduceByKey or aggregateByKey will provide much better performance. 
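# --- A configuration sketch for the spill helpers introduced in rdd.py above
# (_can_spill / _memory_limit): both read the SparkConf, so these two settings control
# whether and when sortByKey / combineByKey / groupByKey spill to disk. The values
# shown are illustrative only.
from pyspark import SparkConf, SparkContext

conf = (SparkConf()
        .set("spark.shuffle.spill", "true")          # allow external (on-disk) merging
        .set("spark.python.worker.memory", "512m"))  # per-worker memory budget
sc = SparkContext(appName="spill-demo", conf=conf)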
>>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect())) + >>> sorted(x.groupByKey().mapValues(len).collect()) + [('a', 2), ('b', 1)] + >>> sorted(x.groupByKey().mapValues(list).collect()) [('a', [1, 1]), ('b', [1])] """ - def createCombiner(x): return [x] @@ -1780,8 +1786,27 @@ def mergeCombiners(a, b): a.extend(b) return a - return self.combineByKey(createCombiner, mergeValue, mergeCombiners, - numPartitions).mapValues(lambda x: ResultIterable(x)) + spill = self._can_spill() + memory = self._memory_limit() + serializer = self._jrdd_deserializer + agg = Aggregator(createCombiner, mergeValue, mergeCombiners) + + def combine(iterator): + merger = ExternalMerger(agg, memory * 0.9, serializer) \ + if spill else InMemoryMerger(agg) + merger.mergeValues(iterator) + return merger.iteritems() + + locally_combined = self.mapPartitions(combine, preservesPartitioning=True) + shuffled = locally_combined.partitionBy(numPartitions) + + def groupByKey(it): + merger = ExternalGroupBy(agg, memory, serializer)\ + if spill else InMemoryMerger(agg) + merger.mergeCombiners(it) + return merger.iteritems() + + return shuffled.mapPartitions(groupByKey, True).mapValues(ResultIterable) def flatMapValues(self, f): """ @@ -2208,13 +2233,11 @@ def toLocalIterator(self): def _prepare_for_python_RDD(sc, command, obj=None): # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() - pickled_command = ser.dumps(command) + pickled_command = ser.dumps((command, sys.version_info[:2])) if len(pickled_command) > (1 << 20): # 1M + # The broadcast will have same life cycle as created PythonRDD broadcast = sc.broadcast(pickled_command) pickled_command = ser.dumps(broadcast) - # tracking the life cycle by obj - if obj is not None: - obj._broadcast = broadcast broadcast_vars = ListConverter().convert( [x._jbroadcast for x in sc._pickled_broadcast_vars], sc._gateway._gateway_client) @@ -2269,12 +2292,9 @@ def pipeline_func(split, iterator): self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False self.partitioner = prev.partitioner if self.preservesPartitioning else None - self._broadcast = None - def __del__(self): - if self._broadcast: - self._broadcast.unpersist() - self._broadcast = None + def getNumPartitions(self): + return self._prev_jrdd.partitions().size() @property def _jrdd(self): diff --git a/python/pyspark/resultiterable.py b/python/pyspark/resultiterable.py index ef04c82866e6c..1ab5ce14c3531 100644 --- a/python/pyspark/resultiterable.py +++ b/python/pyspark/resultiterable.py @@ -15,15 +15,16 @@ # limitations under the License. # -__all__ = ["ResultIterable"] - import collections +__all__ = ["ResultIterable"] + class ResultIterable(collections.Iterable): """ - A special result iterable. This is used because the standard iterator can not be pickled + A special result iterable. This is used because the standard + iterator can not be pickled """ def __init__(self, data): diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 0ffb41d02f6f6..4afa82f4b2973 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -220,6 +220,29 @@ def __repr__(self): return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize) +class FlattenedValuesSerializer(BatchedSerializer): + + """ + Serializes a stream of list of pairs, split the list of values + which contain more than a certain number of objects to make them + have similar sizes. 
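# --- A toy walk-through of the batching rule that FlattenedValuesSerializer implements
# in `_batched` below: each (key, values) pair is re-emitted as several (key, chunk)
# pairs of at most `batch_size` values, so very long value lists serialize in
# similar-sized pieces. Pure-Python stand-in, illustrative only.
def flatten_values(pairs, batch_size=10):
    for key, values in pairs:
        for i in xrange(0, len(values), batch_size):
            yield key, values[i:i + batch_size]

list(flatten_values([("a", range(25))], batch_size=10))
# -> [('a', [0, ..., 9]), ('a', [10, ..., 19]), ('a', [20, ..., 24])]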
+ """ + def __init__(self, serializer, batchSize=10): + BatchedSerializer.__init__(self, serializer, batchSize) + + def _batched(self, iterator): + n = self.batchSize + for key, values in iterator: + for i in xrange(0, len(values), n): + yield key, values[i:i + n] + + def load_stream(self, stream): + return self.serializer.load_stream(stream) + + def __repr__(self): + return "FlattenedValuesSerializer(%d)" % self.batchSize + + class AutoBatchedSerializer(BatchedSerializer): """ Choose the size of batch automatically based on the size of object @@ -251,7 +274,7 @@ def __eq__(self, other): return (isinstance(other, AutoBatchedSerializer) and other.serializer == self.serializer and other.bestSize == self.bestSize) - def __str__(self): + def __repr__(self): return "AutoBatchedSerializer(%s)" % str(self.serializer) diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 1a02fece9c5a5..81aa970a32f76 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -53,9 +53,9 @@ try: # Try to access HiveConf, it will raise exception if Hive is not added sc._jvm.org.apache.hadoop.hive.conf.HiveConf() - sqlCtx = HiveContext(sc) + sqlCtx = sqlContext = HiveContext(sc) except py4j.protocol.Py4JError: - sqlCtx = SQLContext(sc) + sqlCtx = sqlContext = SQLContext(sc) print("""Welcome to ____ __ @@ -68,7 +68,7 @@ platform.python_version(), platform.python_build()[0], platform.python_build()[1])) -print("SparkContext available as sc, %s available as sqlCtx." % sqlCtx.__class__.__name__) +print("SparkContext available as sc, %s available as sqlContext." % sqlContext.__class__.__name__) if add_files is not None: print("Warning: ADD_FILES environment variable is deprecated, use --py-files argument instead") diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 10a7ccd502000..8a6fc627eb383 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -16,28 +16,35 @@ # import os -import sys import platform import shutil import warnings import gc import itertools +import operator import random import pyspark.heapq3 as heapq -from pyspark.serializers import AutoBatchedSerializer, PickleSerializer +from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ + CompressedSerializer, AutoBatchedSerializer + try: import psutil + process = None + def get_used_memory(): """ Return the used memory in MB """ - process = psutil.Process(os.getpid()) + global process + if process is None or process._pid != os.getpid(): + process = psutil.Process(os.getpid()) if hasattr(process, "memory_info"): info = process.memory_info() else: info = process.get_memory_info() return info.rss >> 20 + except ImportError: def get_used_memory(): @@ -46,6 +53,7 @@ def get_used_memory(): for line in open('/proc/self/status'): if line.startswith('VmRSS:'): return int(line.split()[1]) >> 10 + else: warnings.warn("Please install psutil to have better " "support with spilling") @@ -54,6 +62,7 @@ def get_used_memory(): rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss return rss >> 20 # TODO: support windows + return 0 @@ -148,10 +157,16 @@ def mergeCombiners(self, iterator): d[k] = comb(d[k], v) if k in d else v def iteritems(self): - """ Return the merged items ad iterator """ + """ Return the merged items as iterator """ return self.data.iteritems() +def _compressed_serializer(self, serializer=None): + # always use PickleSerializer to simplify implementation + ser = PickleSerializer() + return AutoBatchedSerializer(CompressedSerializer(ser)) + + class 
ExternalMerger(Merger): """ @@ -173,7 +188,7 @@ class ExternalMerger(Merger): dict. Repeat this again until combine all the items. - Before return any items, it will load each partition and - combine them seperately. Yield them before loading next + combine them separately. Yield them before loading next partition. - During loading a partition, if the memory goes over limit, @@ -182,7 +197,7 @@ class ExternalMerger(Merger): `data` and `pdata` are used to hold the merged items in memory. At first, all the data are merged into `data`. Once the used - memory goes over limit, the items in `data` are dumped indo + memory goes over limit, the items in `data` are dumped into disks, `data` will be cleared, all rest of items will be merged into `pdata` and then dumped into disks. Before returning, all the items in `pdata` will be dumped into disks. @@ -193,16 +208,16 @@ class ExternalMerger(Merger): >>> agg = SimpleAggregator(lambda x, y: x + y) >>> merger = ExternalMerger(agg, 10) >>> N = 10000 - >>> merger.mergeValues(zip(xrange(N), xrange(N)) * 10) + >>> merger.mergeValues(zip(xrange(N), xrange(N))) >>> assert merger.spills > 0 >>> sum(v for k,v in merger.iteritems()) - 499950000 + 49995000 >>> merger = ExternalMerger(agg, 10) - >>> merger.mergeCombiners(zip(xrange(N), xrange(N)) * 10) + >>> merger.mergeCombiners(zip(xrange(N), xrange(N))) >>> assert merger.spills > 0 >>> sum(v for k,v in merger.iteritems()) - 499950000 + 49995000 """ # the max total partitions created recursively @@ -212,8 +227,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, localdirs=None, scale=1, partitions=59, batch=1000): Merger.__init__(self, aggregator) self.memory_limit = memory_limit - # default serializer is only used for tests - self.serializer = serializer or AutoBatchedSerializer(PickleSerializer()) + self.serializer = _compressed_serializer(serializer) self.localdirs = localdirs or _get_local_dirs(str(id(self))) # number of partitions when spill data into disks self.partitions = partitions @@ -221,7 +235,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, self.batch = batch # scale is used to scale down the hash of key for recursive hash map self.scale = scale - # unpartitioned merged data + # un-partitioned merged data self.data = {} # partitioned merged data, list of dicts self.pdata = [] @@ -244,72 +258,63 @@ def _next_limit(self): def mergeValues(self, iterator): """ Combine the items by creator and combiner """ - iterator = iter(iterator) # speedup attribute lookup creator, comb = self.agg.createCombiner, self.agg.mergeValue - d, c, batch = self.data, 0, self.batch + c, data, pdata, hfun, batch = 0, self.data, self.pdata, self._partition, self.batch + limit = self.memory_limit for k, v in iterator: + d = pdata[hfun(k)] if pdata else data d[k] = comb(d[k], v) if k in d else creator(v) c += 1 - if c % batch == 0 and get_used_memory() > self.memory_limit: - self._spill() - self._partitioned_mergeValues(iterator, self._next_limit()) - break + if c >= batch: + if get_used_memory() >= limit: + self._spill() + limit = self._next_limit() + batch /= 2 + c = 0 + else: + batch *= 1.5 + + if get_used_memory() >= limit: + self._spill() def _partition(self, key): """ Return the partition for key """ return hash((key, self._seed)) % self.partitions - def _partitioned_mergeValues(self, iterator, limit=0): - """ Partition the items by key, then combine them """ - # speedup attribute lookup - creator, comb = self.agg.createCombiner, self.agg.mergeValue - c, pdata, hfun, batch = 0, 
self.pdata, self._partition, self.batch - - for k, v in iterator: - d = pdata[hfun(k)] - d[k] = comb(d[k], v) if k in d else creator(v) - if not limit: - continue - - c += 1 - if c % batch == 0 and get_used_memory() > limit: - self._spill() - limit = self._next_limit() + def _object_size(self, obj): + """ How much of memory for this obj, assume that all the objects + consume similar bytes of memory + """ + return 1 - def mergeCombiners(self, iterator, check=True): + def mergeCombiners(self, iterator, limit=None): """ Merge (K,V) pair by mergeCombiner """ - iterator = iter(iterator) + if limit is None: + limit = self.memory_limit # speedup attribute lookup - d, comb, batch = self.data, self.agg.mergeCombiners, self.batch - c = 0 - for k, v in iterator: - d[k] = comb(d[k], v) if k in d else v - if not check: - continue - - c += 1 - if c % batch == 0 and get_used_memory() > self.memory_limit: - self._spill() - self._partitioned_mergeCombiners(iterator, self._next_limit()) - break - - def _partitioned_mergeCombiners(self, iterator, limit=0): - """ Partition the items by key, then merge them """ - comb, pdata = self.agg.mergeCombiners, self.pdata - c, hfun = 0, self._partition + comb, hfun, objsize = self.agg.mergeCombiners, self._partition, self._object_size + c, data, pdata, batch = 0, self.data, self.pdata, self.batch for k, v in iterator: - d = pdata[hfun(k)] + d = pdata[hfun(k)] if pdata else data d[k] = comb(d[k], v) if k in d else v if not limit: continue - c += 1 - if c % self.batch == 0 and get_used_memory() > limit: - self._spill() - limit = self._next_limit() + c += objsize(v) + if c > batch: + if get_used_memory() > limit: + self._spill() + limit = self._next_limit() + batch /= 2 + c = 0 + else: + batch *= 1.5 + + if limit and get_used_memory() >= limit: + self._spill() def _spill(self): """ @@ -335,7 +340,7 @@ def _spill(self): for k, v in self.data.iteritems(): h = self._partition(k) - # put one item in batch, make it compatitable with load_stream + # put one item in batch, make it compatible with load_stream # it will increase the memory if dump them in batch self.serializer.dump_stream([(k, v)], streams[h]) @@ -344,7 +349,7 @@ def _spill(self): s.close() self.data.clear() - self.pdata = [{} for i in range(self.partitions)] + self.pdata.extend([{} for i in range(self.partitions)]) else: for i in range(self.partitions): @@ -370,29 +375,12 @@ def _external_items(self): assert not self.data if any(self.pdata): self._spill() - hard_limit = self._next_limit() + # disable partitioning and spilling when merge combiners from disk + self.pdata = [] try: for i in range(self.partitions): - self.data = {} - for j in range(self.spills): - path = self._get_spill_dir(j) - p = os.path.join(path, str(i)) - # do not check memory during merging - self.mergeCombiners(self.serializer.load_stream(open(p)), - False) - - # limit the total partitions - if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS - and j < self.spills - 1 - and get_used_memory() > hard_limit): - self.data.clear() # will read from disk again - gc.collect() # release the memory as much as possible - for v in self._recursive_merged_items(i): - yield v - return - - for v in self.data.iteritems(): + for v in self._merged_items(i): yield v self.data.clear() @@ -400,53 +388,56 @@ def _external_items(self): for j in range(self.spills): path = self._get_spill_dir(j) os.remove(os.path.join(path, str(i))) - finally: self._cleanup() - def _cleanup(self): - """ Clean up all the files in disks """ - for d in self.localdirs: - 
shutil.rmtree(d, True) + def _merged_items(self, index): + self.data = {} + limit = self._next_limit() + for j in range(self.spills): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + # do not check memory during merging + self.mergeCombiners(self.serializer.load_stream(open(p)), 0) + + # limit the total partitions + if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS + and j < self.spills - 1 + and get_used_memory() > limit): + self.data.clear() # will read from disk again + gc.collect() # release the memory as much as possible + return self._recursive_merged_items(index) - def _recursive_merged_items(self, start): + return self.data.iteritems() + + def _recursive_merged_items(self, index): """ merge the partitioned items and return the as iterator If one partition can not be fit in memory, then them will be partitioned and merged recursively. """ - # make sure all the data are dumps into disks. - assert not self.data - if any(self.pdata): - self._spill() - assert self.spills > 0 - - for i in range(start, self.partitions): - subdirs = [os.path.join(d, "parts", str(i)) - for d in self.localdirs] - m = ExternalMerger(self.agg, self.memory_limit, self.serializer, - subdirs, self.scale * self.partitions, self.partitions) - m.pdata = [{} for _ in range(self.partitions)] - limit = self._next_limit() - - for j in range(self.spills): - path = self._get_spill_dir(j) - p = os.path.join(path, str(i)) - m._partitioned_mergeCombiners( - self.serializer.load_stream(open(p))) - - if get_used_memory() > limit: - m._spill() - limit = self._next_limit() + subdirs = [os.path.join(d, "parts", str(index)) for d in self.localdirs] + m = ExternalMerger(self.agg, self.memory_limit, self.serializer, subdirs, + self.scale * self.partitions, self.partitions, self.batch) + m.pdata = [{} for _ in range(self.partitions)] + limit = self._next_limit() + + for j in range(self.spills): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + m.mergeCombiners(self.serializer.load_stream(open(p)), 0) + + if get_used_memory() > limit: + m._spill() + limit = self._next_limit() - for v in m._external_items(): - yield v + return m._external_items() - # remove the merged partition - for j in range(self.spills): - path = self._get_spill_dir(j) - os.remove(os.path.join(path, str(i))) + def _cleanup(self): + """ Clean up all the files in disks """ + for d in self.localdirs: + shutil.rmtree(d, True) class ExternalSorter(object): @@ -457,6 +448,7 @@ class ExternalSorter(object): The spilling will only happen when the used memory goes above the limit. 
+ >>> sorter = ExternalSorter(1) # 1M >>> import random >>> l = range(1024) @@ -469,7 +461,7 @@ class ExternalSorter(object): def __init__(self, memory_limit, serializer=None): self.memory_limit = memory_limit self.local_dirs = _get_local_dirs("sort") - self.serializer = serializer or AutoBatchedSerializer(PickleSerializer()) + self.serializer = _compressed_serializer(serializer) def _get_path(self, n): """ Choose one directory for spill by number n """ @@ -515,6 +507,7 @@ def sorted(self, iterator, key=None, reverse=False): limit = self._next_limit() MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 DiskBytesSpilled += os.path.getsize(path) + os.unlink(path) # data will be deleted after close elif not chunks: batch = min(batch * 2, 10000) @@ -529,6 +522,310 @@ def sorted(self, iterator, key=None, reverse=False): return heapq.merge(chunks, key=key, reverse=reverse) +class ExternalList(object): + """ + ExternalList can have many items which cannot be hold in memory in + the same time. + + >>> l = ExternalList(range(100)) + >>> len(l) + 100 + >>> l.append(10) + >>> len(l) + 101 + >>> for i in range(20240): + ... l.append(i) + >>> len(l) + 20341 + >>> import pickle + >>> l2 = pickle.loads(pickle.dumps(l)) + >>> len(l2) + 20341 + >>> list(l2)[100] + 10 + """ + LIMIT = 10240 + + def __init__(self, values): + self.values = values + self.count = len(values) + self._file = None + self._ser = None + + def __getstate__(self): + if self._file is not None: + self._file.flush() + f = os.fdopen(os.dup(self._file.fileno())) + f.seek(0) + serialized = f.read() + else: + serialized = '' + return self.values, self.count, serialized + + def __setstate__(self, item): + self.values, self.count, serialized = item + if serialized: + self._open_file() + self._file.write(serialized) + else: + self._file = None + self._ser = None + + def __iter__(self): + if self._file is not None: + self._file.flush() + # read all items from disks first + with os.fdopen(os.dup(self._file.fileno()), 'r') as f: + f.seek(0) + for v in self._ser.load_stream(f): + yield v + + for v in self.values: + yield v + + def __len__(self): + return self.count + + def append(self, value): + self.values.append(value) + self.count += 1 + # dump them into disk if the key is huge + if len(self.values) >= self.LIMIT: + self._spill() + + def _open_file(self): + dirs = _get_local_dirs("objects") + d = dirs[id(self) % len(dirs)] + if not os.path.exists(d): + os.makedirs(d) + p = os.path.join(d, str(id)) + self._file = open(p, "w+", 65536) + self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024) + os.unlink(p) + + def _spill(self): + """ dump the values into disk """ + global MemoryBytesSpilled, DiskBytesSpilled + if self._file is None: + self._open_file() + + used_memory = get_used_memory() + pos = self._file.tell() + self._ser.dump_stream(self.values, self._file) + self.values = [] + gc.collect() + DiskBytesSpilled += self._file.tell() - pos + MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + + +class ExternalListOfList(ExternalList): + """ + An external list for list. 
+ + >>> l = ExternalListOfList([[i, i] for i in range(100)]) + >>> len(l) + 200 + >>> l.append(range(10)) + >>> len(l) + 210 + >>> len(list(l)) + 210 + """ + + def __init__(self, values): + ExternalList.__init__(self, values) + self.count = sum(len(i) for i in values) + + def append(self, value): + ExternalList.append(self, value) + # already counted 1 in ExternalList.append + self.count += len(value) - 1 + + def __iter__(self): + for values in ExternalList.__iter__(self): + for v in values: + yield v + + +class GroupByKey(object): + """ + Group a sorted iterator as [(k1, it1), (k2, it2), ...] + + >>> k = [i/3 for i in range(6)] + >>> v = [[i] for i in range(6)] + >>> g = GroupByKey(iter(zip(k, v))) + >>> [(k, list(it)) for k, it in g] + [(0, [0, 1, 2]), (1, [3, 4, 5])] + """ + + def __init__(self, iterator): + self.iterator = iter(iterator) + self.next_item = None + + def __iter__(self): + return self + + def next(self): + key, value = self.next_item if self.next_item else next(self.iterator) + values = ExternalListOfList([value]) + try: + while True: + k, v = next(self.iterator) + if k != key: + self.next_item = (k, v) + break + values.append(v) + except StopIteration: + self.next_item = None + return key, values + + +class ExternalGroupBy(ExternalMerger): + + """ + Group by the items by key. If any partition of them can not been + hold in memory, it will do sort based group by. + + This class works as follows: + + - It repeatedly group the items by key and save them in one dict in + memory. + + - When the used memory goes above memory limit, it will split + the combined data into partitions by hash code, dump them + into disk, one file per partition. If the number of keys + in one partitions is smaller than 1000, it will sort them + by key before dumping into disk. + + - Then it goes through the rest of the iterator, group items + by key into different dict by hash. Until the used memory goes over + memory limit, it dump all the dicts into disks, one file per + dict. Repeat this again until combine all the items. It + also will try to sort the items by key in each partition + before dumping into disks. + + - It will yield the grouped items partitions by partitions. + If the data in one partitions can be hold in memory, then it + will load and combine them in memory and yield. + + - If the dataset in one partition cannot be hold in memory, + it will sort them first. If all the files are already sorted, + it merge them by heap.merge(), so it will do external sort + for all the files. + + - After sorting, `GroupByKey` class will put all the continuous + items with the same key as a group, yield the values as + an iterator. + """ + SORT_KEY_LIMIT = 1000 + + def flattened_serializer(self): + assert isinstance(self.serializer, BatchedSerializer) + ser = self.serializer + return FlattenedValuesSerializer(ser, 20) + + def _object_size(self, obj): + return len(obj) + + def _spill(self): + """ + dump already partitioned data into disks. + """ + global MemoryBytesSpilled, DiskBytesSpilled + path = self._get_spill_dir(self.spills) + if not os.path.exists(path): + os.makedirs(path) + + used_memory = get_used_memory() + if not self.pdata: + # The data has not been partitioned, it will iterator the + # data once, write them into different files, has no + # additional memory. It only called when the memory goes + # above limit at the first time. 
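# --- A toy sketch of the sort-based path described above: spill files that are already
# sorted by key are merged (the real code uses a key-aware heap merge), and the merged
# stream is then grouped by key. The lists and the plain dict below are stand-ins for
# the on-disk runs and for GroupByKey, purely for illustration.
import heapq

run1 = [("a", [1]), ("b", [2])]   # pretend spill file 1, sorted by key
run2 = [("a", [3]), ("c", [4])]   # pretend spill file 2, sorted by key
merged = heapq.merge(run1, run2)  # external merge-sort over the sorted runs
groups = {}
for k, vs in merged:
    groups.setdefault(k, []).extend(vs)
# groups == {'a': [1, 3], 'b': [2], 'c': [4]}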
+ + # open all the files for writing + streams = [open(os.path.join(path, str(i)), 'w') + for i in range(self.partitions)] + + # If the number of keys is small, then the overhead of sort is small + # sort them before dumping into disks + self._sorted = len(self.data) < self.SORT_KEY_LIMIT + if self._sorted: + self.serializer = self.flattened_serializer() + for k in sorted(self.data.keys()): + h = self._partition(k) + self.serializer.dump_stream([(k, self.data[k])], streams[h]) + else: + for k, v in self.data.iteritems(): + h = self._partition(k) + self.serializer.dump_stream([(k, v)], streams[h]) + + for s in streams: + DiskBytesSpilled += s.tell() + s.close() + + self.data.clear() + # self.pdata is cached in `mergeValues` and `mergeCombiners` + self.pdata.extend([{} for i in range(self.partitions)]) + + else: + for i in range(self.partitions): + p = os.path.join(path, str(i)) + with open(p, "w") as f: + # dump items in batch + if self._sorted: + # sort by key only (stable) + sorted_items = sorted(self.pdata[i].iteritems(), key=operator.itemgetter(0)) + self.serializer.dump_stream(sorted_items, f) + else: + self.serializer.dump_stream(self.pdata[i].iteritems(), f) + self.pdata[i].clear() + DiskBytesSpilled += os.path.getsize(p) + + self.spills += 1 + gc.collect() # release the memory as much as possible + MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + + def _merged_items(self, index): + size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index))) + for j in range(self.spills)) + # if the memory can not hold all the partition, + # then use sort based merge. Because of compression, + # the data on disks will be much smaller than needed memory + if (size >> 20) >= self.memory_limit / 10: + return self._merge_sorted_items(index) + + self.data = {} + for j in range(self.spills): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + # do not check memory during merging + self.mergeCombiners(self.serializer.load_stream(open(p)), 0) + return self.data.iteritems() + + def _merge_sorted_items(self, index): + """ load a partition from disk, then sort and group by key """ + def load_partition(j): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + return self.serializer.load_stream(open(p, 'r', 65536)) + + disk_items = [load_partition(j) for j in range(self.spills)] + + if self._sorted: + # all the partitions are already sorted + sorted_items = heapq.merge(disk_items, key=operator.itemgetter(0)) + + else: + # Flatten the combined values, so it will not consume huge + # memory during merging sort. + ser = self.flattened_serializer() + sorter = ExternalSorter(self.memory_limit, ser) + sorted_items = sorter.sorted(itertools.chain(*disk_items), + key=operator.itemgetter(0)) + return ((k, vs) for k, vs in GroupByKey(sorted_items)) + + if __name__ == "__main__": import doctest doctest.testmod() diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index b9ffd6945ea7e..65abb24eed823 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -16,26 +16,32 @@ # """ -public classes of Spark SQL: +Important classes of Spark SQL and DataFrames: - L{SQLContext} - Main entry point for SQL functionality. + Main entry point for :class:`DataFrame` and SQL functionality. - L{DataFrame} - A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In - addition to normal RDD operations, DataFrames also support SQL. 
- - L{GroupedData} + A distributed collection of data grouped into named columns. - L{Column} - Column is a DataFrame with a single column. + A column expression in a :class:`DataFrame`. - L{Row} - A Row of data returned by a Spark SQL query. + A row of data in a :class:`DataFrame`. - L{HiveContext} - Main entry point for accessing data stored in Apache Hive.. + Main entry point for accessing data stored in Apache Hive. + - L{GroupedData} + Aggregation methods, returned by :func:`DataFrame.groupBy`. + - L{DataFrameNaFunctions} + Methods for handling missing data (null values). + - L{functions} + List of built-in functions available for :class:`DataFrame`. + - L{types} + List of data types available. """ from pyspark.sql.context import SQLContext, HiveContext from pyspark.sql.types import Row -from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD +from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions __all__ = [ - 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', + 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', 'DataFrameNaFunctions' ] diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 795ef0dbc4c47..e8529a8f8e3a4 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -34,15 +34,15 @@ except ImportError: has_pandas = False -__all__ = ["SQLContext", "HiveContext"] +__all__ = ["SQLContext", "HiveContext", "UDFRegistration"] -def _monkey_patch_RDD(sqlCtx): +def _monkey_patch_RDD(sqlContext): def toDF(self, schema=None, sampleRatio=None): """ - Convert current :class:`RDD` into a :class:`DataFrame` + Converts current :class:`RDD` into a :class:`DataFrame` - This is a shorthand for `sqlCtx.createDataFrame(rdd, schema, sampleRatio)` + This is a shorthand for ``sqlContext.createDataFrame(rdd, schema, sampleRatio)`` :param schema: a StructType or list of names of columns :param samplingRatio: the sample ratio of rows used for inferring @@ -51,38 +51,37 @@ def toDF(self, schema=None, sampleRatio=None): >>> rdd.toDF().collect() [Row(name=u'Alice', age=1)] """ - return sqlCtx.createDataFrame(self, schema, sampleRatio) + return sqlContext.createDataFrame(self, schema, sampleRatio) RDD.toDF = toDF class SQLContext(object): - """Main entry point for Spark SQL functionality. - A SQLContext can be used create L{DataFrame}, register L{DataFrame} as + A SQLContext can be used create :class:`DataFrame`, register :class:`DataFrame` as tables, execute SQL over tables, cache tables, and read parquet files. - """ - def __init__(self, sparkContext, sqlContext=None): - """Create a new SQLContext. + When created, :class:`SQLContext` adds a method called ``toDF`` to :class:`RDD`, + which could be used to convert an RDD into a DataFrame, it's a shorthand for + :func:`SQLContext.createDataFrame`. - It will add a method called `toDF` to :class:`RDD`, which could be - used to convert an RDD into a DataFrame, it's a shorthand for - :func:`SQLContext.createDataFrame`. - - :param sparkContext: The SparkContext to wrap. - :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new + :param sparkContext: The :class:`SparkContext` backing this SQLContext. + :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new SQLContext in the JVM, instead we make all calls to this object. + """ + + def __init__(self, sparkContext, sqlContext=None): + """Creates a new SQLContext. 
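# --- A minimal sketch of the `toDF` shorthand that creating a SQLContext attaches to
# RDDs, as described above. It assumes a live SparkContext `sc`; the rows and the query
# are illustrative only.
from pyspark.sql import SQLContext, Row

sqlContext = SQLContext(sc)      # also monkey-patches RDD.toDF
people = sc.parallelize([Row(name=u'Alice', age=1), Row(name=u'Bob', age=2)])
df = people.toDF()               # shorthand for sqlContext.createDataFrame(people)
df.registerTempTable("people")
sqlContext.sql("SELECT name FROM people WHERE age > 1").collect()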
>>> from datetime import datetime - >>> sqlCtx = SQLContext(sc) + >>> sqlContext = SQLContext(sc) >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), ... time=datetime(2014, 8, 1, 14, 1, 5))]) >>> df = allTypes.toDF() >>> df.registerTempTable("allTypes") - >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' + >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 'from allTypes where b and i > 0').collect() [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, @@ -118,6 +117,11 @@ def getConf(self, key, defaultValue): """ return self._ssql_ctx.getConf(key, defaultValue) + @property + def udf(self): + """Returns a :class:`UDFRegistration` for UDF registration.""" + return UDFRegistration(self) + def registerFunction(self, name, f, returnType=StringType()): """Registers a lambda function as a UDF so it can be used in SQL statements. @@ -125,13 +129,22 @@ def registerFunction(self, name, f, returnType=StringType()): When the return type is not given it default to a string and conversion will automatically be done. For any other return type, the produced object must match the specified type. - >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) - >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() + :param name: name of the UDF + :param samplingRatio: lambda function + :param returnType: a :class:`DataType` object + + >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) + >>> sqlContext.sql("SELECT stringLengthString('test')").collect() [Row(c0=u'4')] >>> from pyspark.sql.types import IntegerType - >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() + >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() + [Row(c0=4)] + + >>> from pyspark.sql.types import IntegerType + >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(c0=4)] """ func = lambda _, it: imap(lambda x: f(*x), it) @@ -173,63 +186,19 @@ def _inferSchema(self, rdd, samplingRatio=None): return schema def inferSchema(self, rdd, samplingRatio=None): - """Infer and apply a schema to an RDD of L{Row}. - - ::note: - Deprecated in 1.3, use :func:`createDataFrame` instead - - When samplingRatio is specified, the schema is inferred by looking - at the types of each row in the sampled dataset. Otherwise, the - first 100 rows of the RDD are inspected. Nested collections are - supported, which can include array, dict, list, Row, tuple, - namedtuple, or object. - - Each row could be L{pyspark.sql.Row} object or namedtuple or objects. - Using top level dicts is deprecated, as dict is used to represent Maps. - - If a single column has multiple distinct inferred types, it may cause - runtime exceptions. - - >>> rdd = sc.parallelize( - ... [Row(field1=1, field2="row1"), - ... Row(field1=2, field2="row2"), - ... Row(field1=3, field2="row3")]) - >>> df = sqlCtx.inferSchema(rdd) - >>> df.collect()[0] - Row(field1=1, field2=u'row1') + """::note: Deprecated in 1.3, use :func:`createDataFrame` instead. 
""" + warnings.warn("inferSchema is deprecated, please use createDataFrame instead") if isinstance(rdd, DataFrame): raise TypeError("Cannot apply schema to DataFrame") - schema = self._inferSchema(rdd, samplingRatio) - converter = _create_converter(schema) - rdd = rdd.map(converter) - return self.applySchema(rdd, schema) + return self.createDataFrame(rdd, None, samplingRatio) def applySchema(self, rdd, schema): + """::note: Deprecated in 1.3, use :func:`createDataFrame` instead. """ - Applies the given schema to the given RDD of L{tuple} or L{list}. - - ::note: - Deprecated in 1.3, use :func:`createDataFrame` instead - - These tuples or lists can contain complex nested structures like - lists, maps or nested rows. - - The schema should be a StructType. - - It is important that the schema matches the types of the objects - in each row or exceptions could be thrown at runtime. - - >>> from pyspark.sql.types import * - >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) - >>> schema = StructType([StructField("field1", IntegerType(), False), - ... StructField("field2", StringType(), False)]) - >>> df = sqlCtx.applySchema(rdd2, schema) - >>> df.collect() - [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] - """ + warnings.warn("applySchema is deprecated, please use createDataFrame instead") if isinstance(rdd, DataFrame): raise TypeError("Cannot apply schema to DataFrame") @@ -237,67 +206,49 @@ def applySchema(self, rdd, schema): if not isinstance(schema, StructType): raise TypeError("schema should be StructType, but got %s" % schema) - # take the first few rows to verify schema - rows = rdd.take(10) - # Row() cannot been deserialized by Pyrolite - if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': - rdd = rdd.map(tuple) - rows = rdd.take(10) - - for row in rows: - _verify_type(row, schema) - - # convert python objects to sql data - converter = _python_to_sql_converter(schema) - rdd = rdd.map(converter) - - jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) - df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) - return DataFrame(df, self) + return self.createDataFrame(rdd, schema) def createDataFrame(self, data, schema=None, samplingRatio=None): """ - Create a DataFrame from an RDD of tuple/list, list or pandas.DataFrame. - - `schema` could be :class:`StructType` or a list of column names. + Creates a :class:`DataFrame` from an :class:`RDD` of :class:`tuple`/:class:`list`, + list or :class:`pandas.DataFrame`. - When `schema` is a list of column names, the type of each column - will be inferred from `rdd`. + When ``schema`` is a list of column names, the type of each column + will be inferred from ``data``. - When `schema` is None, it will try to infer the column name and type - from `rdd`, which should be an RDD of :class:`Row`, or namedtuple, - or dict. + When ``schema`` is ``None``, it will try to infer the schema (column names and types) + from ``data``, which should be an RDD of :class:`Row`, + or :class:`namedtuple`, or :class:`dict`. - If referring needed, `samplingRatio` is used to determined how many - rows will be used to do referring. The first row will be used if - `samplingRatio` is None. + If schema inference is needed, ``samplingRatio`` is used to determined the ratio of + rows used for schema inference. The first row will be used if ``samplingRatio`` is ``None``. 
- :param data: an RDD of Row/tuple/list/dict, list, or pandas.DataFrame - :param schema: a StructType or list of names of columns + :param data: an RDD of :class:`Row`/:class:`tuple`/:class:`list`/:class:`dict`, + :class:`list`, or :class:`pandas.DataFrame`. + :param schema: a :class:`StructType` or list of column names. default None. :param samplingRatio: the sample ratio of rows used for inferring - :return: a DataFrame >>> l = [('Alice', 1)] - >>> sqlCtx.createDataFrame(l).collect() + >>> sqlContext.createDataFrame(l).collect() [Row(_1=u'Alice', _2=1)] - >>> sqlCtx.createDataFrame(l, ['name', 'age']).collect() + >>> sqlContext.createDataFrame(l, ['name', 'age']).collect() [Row(name=u'Alice', age=1)] >>> d = [{'name': 'Alice', 'age': 1}] - >>> sqlCtx.createDataFrame(d).collect() + >>> sqlContext.createDataFrame(d).collect() [Row(age=1, name=u'Alice')] >>> rdd = sc.parallelize(l) - >>> sqlCtx.createDataFrame(rdd).collect() + >>> sqlContext.createDataFrame(rdd).collect() [Row(_1=u'Alice', _2=1)] - >>> df = sqlCtx.createDataFrame(rdd, ['name', 'age']) + >>> df = sqlContext.createDataFrame(rdd, ['name', 'age']) >>> df.collect() [Row(name=u'Alice', age=1)] >>> from pyspark.sql import Row >>> Person = Row('name', 'age') >>> person = rdd.map(lambda r: Person(*r)) - >>> df2 = sqlCtx.createDataFrame(person) + >>> df2 = sqlContext.createDataFrame(person) >>> df2.collect() [Row(name=u'Alice', age=1)] @@ -305,11 +256,11 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): >>> schema = StructType([ ... StructField("name", StringType(), True), ... StructField("age", IntegerType(), True)]) - >>> df3 = sqlCtx.createDataFrame(rdd, schema) + >>> df3 = sqlContext.createDataFrame(rdd, schema) >>> df3.collect() [Row(name=u'Alice', age=1)] - >>> sqlCtx.createDataFrame(df.toPandas()).collect() # doctest: +SKIP + >>> sqlContext.createDataFrame(df.toPandas()).collect() # doctest: +SKIP [Row(name=u'Alice', age=1)] """ if isinstance(data, DataFrame): @@ -323,45 +274,63 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): if not isinstance(data, RDD): try: # data could be list, tuple, generator ... 
- data = self._sc.parallelize(data) + rdd = self._sc.parallelize(data) except Exception: raise ValueError("cannot create an RDD from type: %s" % type(data)) + else: + rdd = data if schema is None: - return self.inferSchema(data, samplingRatio) + schema = self._inferSchema(rdd, samplingRatio) + converter = _create_converter(schema) + rdd = rdd.map(converter) if isinstance(schema, (list, tuple)): - first = data.first() + first = rdd.first() if not isinstance(first, (list, tuple)): raise ValueError("each row in `rdd` should be list or tuple, " "but got %r" % type(first)) row_cls = Row(*schema) - schema = self._inferSchema(data.map(lambda r: row_cls(*r)), samplingRatio) + schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio) + + # take the first few rows to verify schema + rows = rdd.take(10) + # Row() cannot been deserialized by Pyrolite + if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': + rdd = rdd.map(tuple) + rows = rdd.take(10) + + for row in rows: + _verify_type(row, schema) + + # convert python objects to sql data + converter = _python_to_sql_converter(schema) + rdd = rdd.map(converter) - return self.applySchema(data, schema) + jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) + df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) + return DataFrame(df, self) - def registerDataFrameAsTable(self, rdd, tableName): - """Registers the given RDD as a temporary table in the catalog. + def registerDataFrameAsTable(self, df, tableName): + """Registers the given :class:`DataFrame` as a temporary table in the catalog. - Temporary tables exist only during the lifetime of this instance of - SQLContext. + Temporary tables exist only during the lifetime of this instance of :class:`SQLContext`. - >>> sqlCtx.registerDataFrameAsTable(df, "table1") + >>> sqlContext.registerDataFrameAsTable(df, "table1") """ - if (rdd.__class__ is DataFrame): - df = rdd._jdf - self._ssql_ctx.registerDataFrameAsTable(df, tableName) + if (df.__class__ is DataFrame): + self._ssql_ctx.registerDataFrameAsTable(df._jdf, tableName) else: raise ValueError("Can only register DataFrame as table") def parquetFile(self, *paths): - """Loads a Parquet file, returning the result as a L{DataFrame}. + """Loads a Parquet file, returning the result as a :class:`DataFrame`. >>> import tempfile, shutil >>> parquetFile = tempfile.mkdtemp() >>> shutil.rmtree(parquetFile) >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> df2 = sqlContext.parquetFile(parquetFile) >>> sorted(df.collect()) == sorted(df2.collect()) True """ @@ -373,22 +342,17 @@ def parquetFile(self, *paths): return DataFrame(jdf, self) def jsonFile(self, path, schema=None, samplingRatio=1.0): - """ - Loads a text file storing one JSON object per line as a - L{DataFrame}. - - If the schema is provided, applies the given schema to this - JSON dataset. + """Loads a text file storing one JSON object per line as a :class:`DataFrame`. - Otherwise, it samples the dataset with ratio `samplingRatio` to - determine the schema. + If the schema is provided, applies the given schema to this JSON dataset. + Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema. >>> import tempfile, shutil >>> jsonFile = tempfile.mkdtemp() >>> shutil.rmtree(jsonFile) >>> with open(jsonFile, 'w') as f: ... 
f.writelines(jsonStrings) - >>> df1 = sqlCtx.jsonFile(jsonFile) + >>> df1 = sqlContext.jsonFile(jsonFile) >>> df1.printSchema() root |-- field1: long (nullable = true) @@ -401,7 +365,7 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): ... StructField("field2", StringType()), ... StructField("field3", ... StructType([StructField("field5", ArrayType(IntegerType()))]))]) - >>> df2 = sqlCtx.jsonFile(jsonFile, schema) + >>> df2 = sqlContext.jsonFile(jsonFile, schema) >>> df2.printSchema() root |-- field2: string (nullable = true) @@ -417,19 +381,16 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): return DataFrame(df, self) def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): - """Loads an RDD storing one JSON object per string as a L{DataFrame}. + """Loads an RDD storing one JSON object per string as a :class:`DataFrame`. - If the schema is provided, applies the given schema to this - JSON dataset. + If the schema is provided, applies the given schema to this JSON dataset. + Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema. - Otherwise, it samples the dataset with ratio `samplingRatio` to - determine the schema. - - >>> df1 = sqlCtx.jsonRDD(json) + >>> df1 = sqlContext.jsonRDD(json) >>> df1.first() Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) - >>> df2 = sqlCtx.jsonRDD(json, df1.schema) + >>> df2 = sqlContext.jsonRDD(json, df1.schema) >>> df2.first() Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) @@ -439,10 +400,9 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): ... StructField("field3", ... StructType([StructField("field5", ArrayType(IntegerType()))])) ... ]) - >>> df3 = sqlCtx.jsonRDD(json, schema) + >>> df3 = sqlContext.jsonRDD(json, schema) >>> df3.first() Row(field2=u'row1', field3=Row(field5=None)) - """ def func(iterator): @@ -463,11 +423,11 @@ def func(iterator): return DataFrame(df, self) def load(self, path=None, source=None, schema=None, **options): - """Returns the dataset in a data source as a DataFrame. + """Returns the dataset in a data source as a :class:`DataFrame`. - The data source is specified by the `source` and a set of `options`. - If `source` is not specified, the default data source configured by - spark.sql.sources.default will be used. + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. Optionally, a schema can be provided as the schema of the returned DataFrame. """ @@ -493,11 +453,11 @@ def createExternalTable(self, tableName, path=None, source=None, It returns the DataFrame associated with the external table. - The data source is specified by the `source` and a set of `options`. - If `source` is not specified, the default data source configured by - spark.sql.sources.default will be used. + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. - Optionally, a schema can be provided as the schema of the returned DataFrame and + Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and created external table. 
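`jsonFile`, `jsonRDD`, `load`, and `createExternalTable` keep their behaviour here; only the docstrings move from epytext (`L{...}`) to reST (`:class:`). A small sketch of reading JSON from an RDD with and without an explicit schema, plus the generic `load` entry point; the commented-out path is a made-up example:

```
from pyspark.sql.types import StructType, StructField, StringType

json_rdd = sc.parallelize(['{"name": "Alice", "city": "SF"}',
                           '{"name": "Bob"}'])

# Schema inferred by sampling the strings (samplingRatio defaults to 1.0).
df1 = sqlContext.jsonRDD(json_rdd)

# Explicit schema: only the listed fields are kept, missing values are null.
schema = StructType([StructField("name", StringType(), True)])
df2 = sqlContext.jsonRDD(json_rdd, schema)

# Generic data source entry point; 'parquet' is the default source unless
# spark.sql.sources.default says otherwise. Hypothetical path:
# df3 = sqlContext.load("/tmp/users.parquet", source="parquet")
```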
""" if path is not None: @@ -518,35 +478,35 @@ def createExternalTable(self, tableName, path=None, source=None, return DataFrame(df, self) def sql(self, sqlQuery): - """Return a L{DataFrame} representing the result of the given query. + """Returns a :class:`DataFrame` representing the result of the given query. - >>> sqlCtx.registerDataFrameAsTable(df, "table1") - >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1") >>> df2.collect() [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] """ return DataFrame(self._ssql_ctx.sql(sqlQuery), self) def table(self, tableName): - """Returns the specified table as a L{DataFrame}. + """Returns the specified table as a :class:`DataFrame`. - >>> sqlCtx.registerDataFrameAsTable(df, "table1") - >>> df2 = sqlCtx.table("table1") + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlContext.table("table1") >>> sorted(df.collect()) == sorted(df2.collect()) True """ return DataFrame(self._ssql_ctx.table(tableName), self) def tables(self, dbName=None): - """Returns a DataFrame containing names of tables in the given database. + """Returns a :class:`DataFrame` containing names of tables in the given database. - If `dbName` is not specified, the current database will be used. + If ``dbName`` is not specified, the current database will be used. - The returned DataFrame has two columns, tableName and isTemporary - (a column with BooleanType indicating if a table is a temporary one or not). + The returned DataFrame has two columns: ``tableName`` and ``isTemporary`` + (a column with :class:`BooleanType` indicating if a table is a temporary one or not). - >>> sqlCtx.registerDataFrameAsTable(df, "table1") - >>> df2 = sqlCtx.tables() + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlContext.tables() >>> df2.filter("tableName = 'table1'").first() Row(tableName=u'table1', isTemporary=True) """ @@ -556,14 +516,14 @@ def tables(self, dbName=None): return DataFrame(self._ssql_ctx.tables(dbName), self) def tableNames(self, dbName=None): - """Returns a list of names of tables in the database `dbName`. + """Returns a list of names of tables in the database ``dbName``. - If `dbName` is not specified, the current database will be used. + If ``dbName`` is not specified, the current database will be used. - >>> sqlCtx.registerDataFrameAsTable(df, "table1") - >>> "table1" in sqlCtx.tableNames() + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> "table1" in sqlContext.tableNames() True - >>> "table1" in sqlCtx.tableNames("db") + >>> "table1" in sqlContext.tableNames("db") True """ if dbName is None: @@ -585,22 +545,18 @@ def clearCache(self): class HiveContext(SQLContext): - """A variant of Spark SQL that integrates with data stored in Hive. - Configuration for Hive is read from hive-site.xml on the classpath. + Configuration for Hive is read from ``hive-site.xml`` on the classpath. It supports running both SQL and HiveQL commands. + + :param sparkContext: The SparkContext to wrap. + :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new + :class:`HiveContext` in the JVM, instead we make all calls to this object. """ def __init__(self, sparkContext, hiveContext=None): - """Create a new HiveContext. - - :param sparkContext: The SparkContext to wrap. - :param hiveContext: An optional JVM Scala HiveContext. 
If set, we do not instatiate a new - HiveContext in the JVM, instead we make all calls to this object. - """ SQLContext.__init__(self, sparkContext) - if hiveContext: self._scala_HiveContext = hiveContext @@ -618,6 +574,27 @@ def _ssql_ctx(self): def _get_hive_ctx(self): return self._jvm.HiveContext(self._jsc.sc()) + def refreshTable(self, tableName): + """Invalidate and refresh all the cached the metadata of the given + table. For performance reasons, Spark SQL or the external data source + library it uses might cache certain metadata about a table, such as the + location of blocks. When those change outside of Spark SQL, users should + call this function to invalidate the cache. + """ + self._ssql_ctx.refreshTable(tableName) + + +class UDFRegistration(object): + """Wrapper for user-defined function registration.""" + + def __init__(self, sqlContext): + self.sqlContext = sqlContext + + def register(self, name, f, returnType=StringType()): + return self.sqlContext.registerFunction(name, f, returnType) + + register.__doc__ = SQLContext.registerFunction.__doc__ + def _test(): import doctest @@ -627,13 +604,12 @@ def _test(): globs = pyspark.sql.context.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc - globs['sqlCtx'] = sqlCtx = SQLContext(sc) + globs['sqlContext'] = SQLContext(sc) globs['rdd'] = rdd = sc.parallelize( [Row(field1=1, field2="row1"), Row(field1=2, field2="row2"), Row(field1=3, field2="row3")] ) - _monkey_patch_RDD(sqlCtx) globs['df'] = rdd.toDF() jsonStrings = [ '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d51309f7ef5aa..f2c3b74a185cf 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -31,12 +31,11 @@ from pyspark.sql.types import _create_cls, _parse_datatype_json_string -__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD"] +__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions"] class DataFrame(object): - - """A collection of rows that have the same columns. + """A distributed collection of data grouped into named columns. A :class:`DataFrame` is equivalent to a relational table in Spark SQL, and can be created using various functions in :class:`SQLContext`:: @@ -50,13 +49,6 @@ class DataFrame(object): ageCol = people.age - Note that the :class:`Column` type can also be manipulated - through its various functions:: - - # The following creates a new column that increases everybody's age by 10. - people.age + 10 - - A more concrete example:: # To create DataFrame using SQLContext @@ -76,9 +68,7 @@ def __init__(self, jdf, sql_ctx): @property def rdd(self): - """ - Return the content of the :class:`DataFrame` as an :class:`RDD` - of :class:`Row` s. + """Returns the content as an :class:`pyspark.RDD` of :class:`Row`. """ if not hasattr(self, '_lazy_rdd'): jrdd = self._jdf.javaToPython() @@ -93,8 +83,16 @@ def applySchema(it): return self._lazy_rdd + @property + def na(self): + """Returns a :class:`DataFrameNaFunctions` for handling missing values. + """ + return DataFrameNaFunctions(self) + def toJSON(self, use_unicode=False): - """Convert a :class:`DataFrame` into a MappedRDD of JSON documents; one document per row. + """Converts a :class:`DataFrame` into a :class:`RDD` of string. + + Each row is turned into a JSON document as one element in the returned RDD. 
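On the DataFrame side, the patch adds an `na` property (returning the `DataFrameNaFunctions` wrapper defined later in this file) next to the existing `rdd` property, and re-documents `toJSON` as returning an RDD of JSON strings. A short sketch, assuming `sc`/`sqlContext` and two-column data like the doctest fixtures:

```
df = sqlContext.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])

df.rdd.map(lambda row: row.name).collect()   # [u'Alice', u'Bob']
df.toJSON().first()                          # '{"age":2,"name":"Alice"}'

# df.na is a thin wrapper, so these two calls are equivalent:
df.na.drop(how='any').count()
df.dropna(how='any').count()
```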
>>> df.toJSON().first() '{"age":2,"name":"Alice"}' @@ -103,16 +101,16 @@ def toJSON(self, use_unicode=False): return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) def saveAsParquetFile(self, path): - """Save the contents as a Parquet file, preserving the schema. + """Saves the contents as a Parquet file, preserving the schema. Files that are written out using this method can be read back in as - a :class:`DataFrame` using the L{SQLContext.parquetFile} method. + a :class:`DataFrame` using :func:`SQLContext.parquetFile`. >>> import tempfile, shutil >>> parquetFile = tempfile.mkdtemp() >>> shutil.rmtree(parquetFile) >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> df2 = sqlContext.parquetFile(parquetFile) >>> sorted(df2.collect()) == sorted(df.collect()) True """ @@ -121,18 +119,18 @@ def saveAsParquetFile(self, path): def registerTempTable(self, name): """Registers this RDD as a temporary table using the given name. - The lifetime of this temporary table is tied to the L{SQLContext} - that was used to create this DataFrame. + The lifetime of this temporary table is tied to the :class:`SQLContext` + that was used to create this :class:`DataFrame`. >>> df.registerTempTable("people") - >>> df2 = sqlCtx.sql("select * from people") + >>> df2 = sqlContext.sql("select * from people") >>> sorted(df.collect()) == sorted(df2.collect()) True """ self._jdf.registerTempTable(name) def registerAsTable(self, name): - """DEPRECATED: use registerTempTable() instead""" + """DEPRECATED: use :func:`registerTempTable` instead""" warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) self.registerTempTable(name) @@ -163,22 +161,19 @@ def _java_save_mode(self, mode): return jmode def saveAsTable(self, tableName, source=None, mode="error", **options): - """Saves the contents of the :class:`DataFrame` to a data source as a table. + """Saves the contents of this :class:`DataFrame` to a data source as a table. - The data source is specified by the `source` and a set of `options`. - If `source` is not specified, the default data source configured by - spark.sql.sources.default will be used. + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. Additionally, mode is used to specify the behavior of the saveAsTable operation when table already exists in the data source. There are four modes: - * append: Contents of this :class:`DataFrame` are expected to be appended \ - to existing table. - * overwrite: Data in the existing table is expected to be overwritten by \ - the contents of this DataFrame. - * error: An exception is expected to be thrown. - * ignore: The save operation is expected to not save the contents of the \ - :class:`DataFrame` and to not change the existing table. + * `append`: Append contents of this :class:`DataFrame` to existing data. + * `overwrite`: Overwrite existing data. + * `error`: Throw an exception if data already exists. + * `ignore`: Silently ignore this operation if data already exists. """ if source is None: source = self.sql_ctx.getConf("spark.sql.sources.default", @@ -191,18 +186,17 @@ def saveAsTable(self, tableName, source=None, mode="error", **options): def save(self, path=None, source=None, mode="error", **options): """Saves the contents of the :class:`DataFrame` to a data source. - The data source is specified by the `source` and a set of `options`. 
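The rewritten mode descriptions for `saveAsTable` and `save` now share the same four bullets (`append`, `overwrite`, `error`, `ignore`). A sketch of how the modes compose; the output path and table name are placeholders, not anything from the patch:

```
# Default source is parquet unless spark.sql.sources.default is changed.
# Fails if something already exists at the (hypothetical) path.
df.save("/tmp/people.parquet", source="parquet", mode="error")

# Re-running with mode="ignore" is a no-op if data already exists,
# mode="overwrite" replaces it, and mode="append" adds to it.
df.save("/tmp/people.parquet", source="parquet", mode="overwrite")

# saveAsTable takes the same four modes; depending on the deployment this
# may require a Hive-backed catalog (HiveContext).
# df.saveAsTable("people_backup", mode="ignore")
```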
- If `source` is not specified, the default data source configured by - spark.sql.sources.default will be used. + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. Additionally, mode is used to specify the behavior of the save operation when data already exists in the data source. There are four modes: - * append: Contents of this :class:`DataFrame` are expected to be appended to existing data. - * overwrite: Existing data is expected to be overwritten by the contents of this DataFrame. - * error: An exception is expected to be thrown. - * ignore: The save operation is expected to not save the contents of \ - the :class:`DataFrame` and to not change the existing data. + * `append`: Append contents of this :class:`DataFrame` to existing data. + * `overwrite`: Overwrite existing data. + * `error`: Throw an exception if data already exists. + * `ignore`: Silently ignore this operation if data already exists. """ if path is not None: options["path"] = path @@ -216,8 +210,7 @@ def save(self, path=None, source=None, mode="error", **options): @property def schema(self): - """Returns the schema of this :class:`DataFrame` (represented by - a L{StructType}). + """Returns the schema of this :class:`DataFrame` as a :class:`types.StructType`. >>> df.schema StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true))) @@ -238,11 +231,9 @@ def printSchema(self): print (self._jdf.schema().treeString()) def explain(self, extended=False): - """ - Prints the plans (logical and physical) to the console for - debugging purpose. + """Prints the (logical and physical) plans to the console for debugging purpose. - If extended is False, only prints the physical plan. + :param extended: boolean, default ``False``. If ``False``, prints only the physical plan. >>> df.explain() PhysicalRDD [age#0,name#1], MapPartitionsRDD[...] at mapPartitions at SQLContext.scala:... @@ -264,15 +255,13 @@ def explain(self, extended=False): print self._jdf.queryExecution().executedPlan().toString() def isLocal(self): - """ - Returns True if the `collect` and `take` methods can be run locally + """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally (without any Spark executors). """ return self._jdf.isLocal() def show(self, n=20): - """ - Print the first n rows. + """Prints the first ``n`` rows to the console. >>> df DataFrame[age: int, name: string] @@ -287,11 +276,7 @@ def __repr__(self): return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) def count(self): - """Return the number of elements in this RDD. - - Unlike the base RDD implementation of count, this implementation - leverages the query optimizer to compute the count on the DataFrame, - which supports features such as filter pushdown. + """Returns the number of rows in this :class:`DataFrame`. >>> df.count() 2L @@ -299,10 +284,7 @@ def count(self): return self._jdf.count() def collect(self): - """Return a list that contains all of the rows. - - Each object in the list is a Row, the fields can be accessed as - attributes. + """Returns all the records as a list of :class:`Row`. >>> df.collect() [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] @@ -314,7 +296,7 @@ def collect(self): return [cls(r) for r in rs] def limit(self, num): - """Limit the result count to the number specified. + """Limits the result count to the number specified. 
>>> df.limit(1).collect() [Row(age=2, name=u'Alice')] @@ -325,10 +307,7 @@ def limit(self, num): return DataFrame(jdf, self.sql_ctx) def take(self, num): - """Take the first num rows of the RDD. - - Each object in the list is a Row, the fields can be accessed as - attributes. + """Returns the first ``num`` rows as a :class:`list` of :class:`Row`. >>> df.take(2) [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] @@ -336,9 +315,9 @@ def take(self, num): return self.limit(num).collect() def map(self, f): - """ Return a new RDD by applying a function to each Row + """ Returns a new :class:`RDD` by applying a the ``f`` function to each :class:`Row`. - It's a shorthand for df.rdd.map() + This is a shorthand for ``df.rdd.map()``. >>> df.map(lambda p: p.name).collect() [u'Alice', u'Bob'] @@ -346,10 +325,10 @@ def map(self, f): return self.rdd.map(f) def flatMap(self, f): - """ Return a new RDD by first applying a function to all elements of this, + """ Returns a new :class:`RDD` by first applying the ``f`` function to each :class:`Row`, and then flattening the results. - It's a shorthand for df.rdd.flatMap() + This is a shorthand for ``df.rdd.flatMap()``. >>> df.flatMap(lambda p: p.name).collect() [u'A', u'l', u'i', u'c', u'e', u'B', u'o', u'b'] @@ -357,10 +336,9 @@ def flatMap(self, f): return self.rdd.flatMap(f) def mapPartitions(self, f, preservesPartitioning=False): - """ - Return a new RDD by applying a function to each partition. + """Returns a new :class:`RDD` by applying the ``f`` function to each partition. - It's a shorthand for df.rdd.mapPartitions() + This is a shorthand for ``df.rdd.mapPartitions()``. >>> rdd = sc.parallelize([1, 2, 3, 4], 4) >>> def f(iterator): yield 1 @@ -370,10 +348,9 @@ def mapPartitions(self, f, preservesPartitioning=False): return self.rdd.mapPartitions(f, preservesPartitioning) def foreach(self, f): - """ - Applies a function to all rows of this DataFrame. + """Applies the ``f`` function to all :class:`Row` of this :class:`DataFrame`. - It's a shorthand for df.rdd.foreach() + This is a shorthand for ``df.rdd.foreach()``. >>> def f(person): ... print person.name @@ -382,10 +359,9 @@ def foreach(self, f): return self.rdd.foreach(f) def foreachPartition(self, f): - """ - Applies a function to each partition of this DataFrame. + """Applies the ``f`` function to each partition of this :class:`DataFrame`. - It's a shorthand for df.rdd.foreachPartition() + This a shorthand for ``df.rdd.foreachPartition()``. >>> def f(people): ... for person in people: @@ -395,14 +371,14 @@ def foreachPartition(self, f): return self.rdd.foreachPartition(f) def cache(self): - """ Persist with the default storage level (C{MEMORY_ONLY_SER}). + """ Persists with the default storage level (C{MEMORY_ONLY_SER}). """ self.is_cached = True self._jdf.cache() return self def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): - """ Set the storage level to persist its values across operations + """Sets the storage level to persist its values across operations after the first time it is computed. This can only be used to assign a new storage level if the RDD does not have a storage level set yet. If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). @@ -413,7 +389,7 @@ def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): return self def unpersist(self, blocking=True): - """ Mark it as non-persistent, and remove all blocks for it from + """Marks the :class:`DataFrame` as non-persistent, and remove all blocks for it from memory and disk. 
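The docstrings above reposition `map`, `flatMap`, `mapPartitions`, `foreach`, and `foreachPartition` as plain shorthands for the same calls on `df.rdd`, while `cache`/`persist`/`unpersist` keep MEMORY_ONLY_SER as the default storage level. A compact sketch under the same `sc`/`sqlContext` assumptions:

```
df = sqlContext.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])

# Shorthands for the same operations on df.rdd:
df.map(lambda p: p.name).collect()         # [u'Alice', u'Bob']
df.flatMap(lambda p: p.name).collect()     # characters of every name
df.mapPartitions(lambda it: [sum(1 for _ in it)]).collect()   # rows per partition

# Cache with the default MEMORY_ONLY_SER level, then release it.
df.cache()
df.count()       # materializes the cached data
df.unpersist()
```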
""" self.is_cached = False @@ -425,8 +401,7 @@ def unpersist(self, blocking=True): # return DataFrame(rdd, self.sql_ctx) def repartition(self, numPartitions): - """ Return a new :class:`DataFrame` that has exactly `numPartitions` - partitions. + """Returns a new :class:`DataFrame` that has exactly ``numPartitions`` partitions. >>> df.repartition(10).rdd.getNumPartitions() 10 @@ -434,8 +409,7 @@ def repartition(self, numPartitions): return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx) def distinct(self): - """ - Return a new :class:`DataFrame` containing the distinct rows in this DataFrame. + """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. >>> df.distinct().count() 2L @@ -443,8 +417,7 @@ def distinct(self): return DataFrame(self._jdf.distinct(), self.sql_ctx) def sample(self, withReplacement, fraction, seed=None): - """ - Return a sampled subset of this DataFrame. + """Returns a sampled subset of this :class:`DataFrame`. >>> df.sample(False, 0.5, 97).count() 1L @@ -456,7 +429,7 @@ def sample(self, withReplacement, fraction, seed=None): @property def dtypes(self): - """Return all column names and their data types as a list. + """Returns all column names and their data types as a list. >>> df.dtypes [('age', 'int'), ('name', 'string')] @@ -465,7 +438,7 @@ def dtypes(self): @property def columns(self): - """ Return all column names as a list. + """Returns all column names as a list. >>> df.columns [u'age', u'name'] @@ -473,16 +446,17 @@ def columns(self): return [f.name for f in self.schema.fields] def join(self, other, joinExprs=None, joinType=None): - """ - Join with another :class:`DataFrame`, using the given join expression. - The following performs a full outer join between `df1` and `df2`. + """Joins with another :class:`DataFrame`, using the given join expression. + + The following performs a full outer join between ``df1`` and ``df2``. :param other: Right side of the join :param joinExprs: Join expression - :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. + :param joinType: str, default 'inner'. + One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() - [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] + [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] """ if joinExprs is None: @@ -497,9 +471,9 @@ def join(self, other, joinExprs=None, joinType=None): return DataFrame(jdf, self.sql_ctx) def sort(self, *cols): - """ Return a new :class:`DataFrame` sorted by the specified column(s). + """Returns a new :class:`DataFrame` sorted by the specified column(s). - :param cols: The columns or expressions used for sorting + :param cols: list of :class:`Column` to sort by. >>> df.sort(df.age.desc()).collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] @@ -540,7 +514,9 @@ def describe(self, *cols): return DataFrame(jdf, self.sql_ctx) def head(self, n=None): - """ Return the first `n` rows or the first row if n is None. + """ + Returns the first ``n`` rows as a list of :class:`Row`, + or the first :class:`Row` if ``n`` is ``None.`` >>> df.head() Row(age=2, name=u'Alice') @@ -553,7 +529,7 @@ def head(self, n=None): return self.take(n) def first(self): - """ Return the first row. + """Returns the first row as a :class:`Row`. 
>>> df.first() Row(age=2, name=u'Alice') @@ -561,7 +537,7 @@ def first(self): return self.head() def __getitem__(self, item): - """ Return the column by given name + """Returns the column as a :class:`Column`. >>> df.select(df['age']).collect() [Row(age=2), Row(age=5)] @@ -581,7 +557,7 @@ def __getitem__(self, item): raise IndexError("unexpected index: %s" % item) def __getattr__(self, name): - """ Return the column by given name + """Returns the :class:`Column` denoted by ``name``. >>> df.select(df.age).collect() [Row(age=2), Row(age=5)] @@ -592,7 +568,11 @@ def __getattr__(self, name): return Column(jc) def select(self, *cols): - """ Selecting a set of expressions. + """Projects a set of expressions and returns a new :class:`DataFrame`. + + :param cols: list of column names (string) or expressions (:class:`Column`). + If one of the column names is '*', that column is expanded to include all columns + in the current DataFrame. >>> df.select('*').collect() [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] @@ -607,9 +587,9 @@ def select(self, *cols): return DataFrame(jdf, self.sql_ctx) def selectExpr(self, *expr): - """ - Selects a set of SQL expressions. This is a variant of - `select` that accepts SQL expressions. + """Projects a set of SQL expressions and returns a new :class:`DataFrame`. + + This is a variant of :func:`select` that accepts SQL expressions. >>> df.selectExpr("age * 2", "abs(age)").collect() [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)] @@ -619,10 +599,12 @@ def selectExpr(self, *expr): return DataFrame(jdf, self.sql_ctx) def filter(self, condition): - """ Filtering rows using the given condition, which could be - :class:`Column` expression or string of SQL expression. + """Filters rows using the given condition. + + :func:`where` is an alias for :func:`filter`. - where() is an alias for filter(). + :param condition: a :class:`Column` of :class:`types.BooleanType` + or a string of SQL expression. >>> df.filter(df.age > 3).collect() [Row(age=5, name=u'Bob')] @@ -645,16 +627,19 @@ def filter(self, condition): where = filter def groupBy(self, *cols): - """ Group the :class:`DataFrame` using the specified columns, + """Groups the :class:`DataFrame` using the specified columns, so we can run aggregation on them. See :class:`GroupedData` for all the available aggregate functions. + :param cols: list of columns to group by. + Each element should be a column name (string) or an expression (:class:`Column`). + >>> df.groupBy().avg().collect() [Row(AVG(age)=3.5)] >>> df.groupBy('name').agg({'age': 'mean'}).collect() - [Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)] + [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] >>> df.groupBy(df.name).avg().collect() - [Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)] + [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] """ jcols = ListConverter().convert([_to_java_column(c) for c in cols], self._sc._gateway._gateway_client) @@ -663,7 +648,7 @@ def groupBy(self, *cols): def agg(self, *exprs): """ Aggregate on the entire :class:`DataFrame` without groups - (shorthand for df.groupBy.agg()). + (shorthand for ``df.groupBy.agg()``). >>> df.agg({"age": "max"}).collect() [Row(MAX(age)=5)] @@ -697,8 +682,104 @@ def subtract(self, other): """ return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) + def dropna(self, how='any', thresh=None, subset=None): + """Returns a new :class:`DataFrame` omitting rows with null values. 
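`select`, `selectExpr`, `filter`/`where`, `groupBy`, and `agg` gain parameter documentation but keep their semantics. A sketch tying them together (the `dropna`/`fillna` counterparts introduced in the next hunk get their own sketch below):

```
from pyspark.sql import functions as F

df = sqlContext.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])

df.select("name", df.age + 1).collect()       # mix of column names and expressions
df.selectExpr("age * 2 AS age2").collect()    # SQL-expression variant of select
df.filter(df.age > 3).collect()               # same as df.where("age > 3")

df.groupBy("name").agg({"age": "mean"}).collect()   # dict form
df.groupBy(df.name).agg(F.min(df.age)).collect()    # Column-expression form
```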
+ + This is an alias for ``na.drop()``. + + :param how: 'any' or 'all'. + If 'any', drop a row if it contains any nulls. + If 'all', drop a row only if all its values are null. + :param thresh: int, default None + If specified, drop rows that have less than `thresh` non-null values. + This overwrites the `how` parameter. + :param subset: optional list of column names to consider. + + >>> df4.dropna().show() + age height name + 10 80 Alice + + >>> df4.na.drop().show() + age height name + 10 80 Alice + """ + if how is not None and how not in ['any', 'all']: + raise ValueError("how ('" + how + "') should be 'any' or 'all'") + + if subset is None: + subset = self.columns + elif isinstance(subset, basestring): + subset = [subset] + elif not isinstance(subset, (list, tuple)): + raise ValueError("subset should be a list or tuple of column names") + + if thresh is None: + thresh = len(subset) if how == 'any' else 1 + + cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client) + cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols) + return DataFrame(self._jdf.na().drop(thresh, cols), self.sql_ctx) + + def fillna(self, value, subset=None): + """Replace null values, alias for ``na.fill()``. + + :param value: int, long, float, string, or dict. + Value to replace null values with. + If the value is a dict, then `subset` is ignored and `value` must be a mapping + from column name (string) to replacement value. The replacement value must be + an int, long, float, or string. + :param subset: optional list of column names to consider. + Columns specified in subset that do not have matching data type are ignored. + For example, if `value` is a string, and subset contains a non-string column, + then the non-string column is simply ignored. + + >>> df4.fillna(50).show() + age height name + 10 80 Alice + 5 50 Bob + 50 50 Tom + 50 50 null + + >>> df4.fillna({'age': 50, 'name': 'unknown'}).show() + age height name + 10 80 Alice + 5 null Bob + 50 null Tom + 50 null unknown + + >>> df4.na.fill({'age': 50, 'name': 'unknown'}).show() + age height name + 10 80 Alice + 5 null Bob + 50 null Tom + 50 null unknown + """ + if not isinstance(value, (float, int, long, basestring, dict)): + raise ValueError("value should be a float, int, long, string, or dict") + + if isinstance(value, (int, long)): + value = float(value) + + if isinstance(value, dict): + value = MapConverter().convert(value, self.sql_ctx._sc._gateway._gateway_client) + return DataFrame(self._jdf.na().fill(value), self.sql_ctx) + elif subset is None: + return DataFrame(self._jdf.na().fill(value), self.sql_ctx) + else: + if isinstance(subset, basestring): + subset = [subset] + elif not isinstance(subset, (list, tuple)): + raise ValueError("subset should be a list or tuple of column names") + + cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client) + cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols) + return DataFrame(self._jdf.na().fill(value, cols), self.sql_ctx) + def withColumn(self, colName, col): - """ Return a new :class:`DataFrame` by adding a column. + """Returns a new :class:`DataFrame` by adding a column. + + :param colName: string, name of the new column. + :param col: a :class:`Column` expression for the new column. 
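`dropna`/`fillna` and their `na.drop()`/`na.fill()` aliases are the new missing-data API: `thresh` overrides `how`, and passing a dict to `fillna` maps column names to per-column replacements. A sketch mirroring the `df4` fixture used in the doctests:

```
from pyspark.sql import Row

df4 = sqlContext.createDataFrame([
    Row(name='Alice', age=10, height=80),
    Row(name='Bob', age=5, height=None),
    Row(name='Tom', age=None, height=None),
    Row(name=None, age=None, height=None)])

df4.dropna().count()                       # rows containing any null are removed
df4.na.drop(thresh=2).count()              # keep rows with at least 2 non-null values
df4.fillna(50, subset=['age']).collect()   # numeric fill restricted to one column
df4.na.fill({'age': 50, 'name': 'unknown'}).collect()
```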
>>> df.withColumn('age2', df.age + 2).collect() [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] @@ -706,7 +787,10 @@ def withColumn(self, colName, col): return self.select('*', col.alias(colName)) def withColumnRenamed(self, existing, new): - """ Rename an existing column to a new name + """REturns a new :class:`DataFrame` by renaming an existing column. + + :param existing: string, name of the existing column to rename. + :param col: string, new name of the column. >>> df.withColumnRenamed('age', 'age2').collect() [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')] @@ -717,8 +801,9 @@ def withColumnRenamed(self, existing, new): return self.select(*cols) def toPandas(self): - """ - Collect all the rows and return a `pandas.DataFrame`. + """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. + + This is only available if Pandas is installed and available. >>> df.toPandas() # doctest: +SKIP age name @@ -731,8 +816,7 @@ def toPandas(self): # Having SchemaRDD for backward compatibility (for docs) class SchemaRDD(DataFrame): - """ - SchemaRDD is deprecated, please use DataFrame + """SchemaRDD is deprecated, please use :class:`DataFrame`. """ @@ -759,10 +843,9 @@ def _api(self, *args): class GroupedData(object): - """ A set of methods for aggregations on a :class:`DataFrame`, - created by DataFrame.groupBy(). + created by :func:`DataFrame.groupBy`. """ def __init__(self, jdf, sql_ctx): @@ -770,22 +853,25 @@ def __init__(self, jdf, sql_ctx): self.sql_ctx = sql_ctx def agg(self, *exprs): - """ Compute aggregates by specifying a map from column name - to aggregate methods. + """Compute aggregates and returns the result as a :class:`DataFrame`. + + The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`. + + If ``exprs`` is a single :class:`dict` mapping from string to string, then the key + is the column to perform aggregation on, and the value is the aggregate function. - The available aggregate methods are `avg`, `max`, `min`, - `sum`, `count`. + Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions. - :param exprs: list or aggregate columns or a map from column - name to aggregate methods. + :param exprs: a dict mapping from column name (string) to aggregate functions (string), + or a list of :class:`Column`. >>> gdf = df.groupBy(df.name) >>> gdf.agg({"*": "count"}).collect() - [Row(name=u'Bob', COUNT(1)=1), Row(name=u'Alice', COUNT(1)=1)] + [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)] >>> from pyspark.sql import functions as F >>> gdf.agg(F.min(df.age)).collect() - [Row(MIN(age)=5), Row(MIN(age)=2)] + [Row(MIN(age)=2), Row(MIN(age)=5)] """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): @@ -802,7 +888,7 @@ def agg(self, *exprs): @dfapi def count(self): - """ Count the number of rows for each group. + """Counts the number of records for each group. >>> df.groupBy(df.age).count().collect() [Row(age=2, count=1), Row(age=5, count=1)] @@ -810,8 +896,11 @@ def count(self): @df_varargs_api def mean(self, *cols): - """Compute the average value for each numeric columns - for each group. This is an alias for `avg`. + """Computes average values for each numeric columns for each group. + + :func:`mean` is an alias for :func:`avg`. + + :param cols: list of column names (string). Non-numeric columns are ignored. 
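`GroupedData.agg` now documents both of its forms explicitly (a dict from column name to aggregate function, or a list of `Column` expressions), and `mean` is spelled out as an alias of `avg`. A sketch under the same assumptions as the earlier ones:

```
from pyspark.sql import functions as F

df = sqlContext.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])
gdf = df.groupBy(df.name)

gdf.agg({"*": "count"}).collect()    # dict form: column name -> aggregate function
gdf.agg(F.min(df.age)).collect()     # Column-expression form

df.groupBy().mean("age").collect()   # [Row(AVG(age)=3.5)], same result as avg("age")
```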
>>> df.groupBy().mean('age').collect() [Row(AVG(age)=3.5)] @@ -821,8 +910,11 @@ def mean(self, *cols): @df_varargs_api def avg(self, *cols): - """Compute the average value for each numeric columns - for each group. + """Computes average values for each numeric columns for each group. + + :func:`mean` is an alias for :func:`avg`. + + :param cols: list of column names (string). Non-numeric columns are ignored. >>> df.groupBy().avg('age').collect() [Row(AVG(age)=3.5)] @@ -832,8 +924,7 @@ def avg(self, *cols): @df_varargs_api def max(self, *cols): - """Compute the max value for each numeric columns for - each group. + """Computes the max value for each numeric columns for each group. >>> df.groupBy().max('age').collect() [Row(MAX(age)=5)] @@ -843,8 +934,9 @@ def max(self, *cols): @df_varargs_api def min(self, *cols): - """Compute the min value for each numeric column for - each group. + """Computes the min value for each numeric column for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. >>> df.groupBy().min('age').collect() [Row(MIN(age)=2)] @@ -854,8 +946,9 @@ def min(self, *cols): @df_varargs_api def sum(self, *cols): - """Compute the sum for each numeric columns for each - group. + """Compute the sum for each numeric columns for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. >>> df.groupBy().sum('age').collect() [Row(SUM(age)=7)] @@ -1061,6 +1154,24 @@ def __repr__(self): return 'Column<%s>' % self._jc.toString().encode('utf8') +class DataFrameNaFunctions(object): + """Functionality for working with missing data in :class:`DataFrame`. + """ + + def __init__(self, df): + self.df = df + + def drop(self, how='any', thresh=None, subset=None): + return self.df.dropna(how=how, thresh=thresh, subset=subset) + + drop.__doc__ = DataFrame.dropna.__doc__ + + def fill(self, value, subset=None): + return self.df.fillna(value=value, subset=subset) + + fill.__doc__ = DataFrame.fillna.__doc__ + + def _test(): import doctest from pyspark.context import SparkContext @@ -1069,13 +1180,19 @@ def _test(): globs = pyspark.sql.dataframe.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc - globs['sqlCtx'] = SQLContext(sc) + globs['sqlContext'] = SQLContext(sc) globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')])\ .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF() globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80), Row(name='Bob', age=5, height=85)]).toDF() + + globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80), + Row(name='Bob', age=5, height=None), + Row(name='Tom', age=None, height=None), + Row(name=None, age=None, height=None)]).toDF() + (failure_count, test_count) = doctest.testmod( pyspark.sql.dataframe, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 5873f09ae3275..daeb6916b58bc 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -76,7 +76,7 @@ def _(col): def countDistinct(col, *cols): - """ Return a new Column for distinct count of `col` or `cols` + """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``. 
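The `functions` module gets the same docstring treatment, and the next hunk makes `UserDefinedPythonFunction` fall back to the class name when the wrapped object has no `__name__` (so callable objects and `functools.partial` work, as the new tests check). A sketch of both, assuming `sc`/`sqlContext`:

```
from pyspark.sql import Row
from pyspark.sql.functions import (countDistinct, approxCountDistinct,
                                   udf, UserDefinedFunction)
from pyspark.sql.types import IntegerType, LongType

df = sqlContext.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)])

df.agg(countDistinct(df.age, df.name).alias('c')).collect()   # [Row(c=2)]
df.agg(approxCountDistinct(df.age).alias('c')).collect()      # [Row(c=2)]

# A lambda-based UDF ...
slen = udf(lambda s: len(s), IntegerType())
df.select(slen(df.name).alias('slen')).collect()

# ... and a callable object, which now gets a usable name from its class.
class PlusFour(object):
    def __call__(self, col):
        return col + 4 if col is not None else None

plus_four = UserDefinedFunction(PlusFour(), LongType())
df.select(plus_four(df.age)).collect()
```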
>>> df.agg(countDistinct(df.age, df.name).alias('c')).collect() [Row(c=2)] @@ -91,7 +91,7 @@ def countDistinct(col, *cols): def approxCountDistinct(col, rsd=None): - """ Return a new Column for approximate distinct count of `col` + """Returns a new :class:`Column` for approximate distinct count of ``col``. >>> df.agg(approxCountDistinct(df.age).alias('c')).collect() [Row(c=2)] @@ -123,7 +123,8 @@ def _create_judf(self): pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) jdt = ssql_ctx.parseDataType(self.returnType.json()) - judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env, + fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__ + judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes, sc.pythonExec, broadcast_vars, sc._javaAccumulator, jdt) return judf @@ -142,7 +143,7 @@ def __call__(self, *cols): def udf(f, returnType=StringType()): - """Create a user defined function (UDF) + """Creates a :class:`Column` expression representing a user defined function (UDF). >>> from pyspark.sql.types import IntegerType >>> slen = udf(lambda s: len(s), IntegerType()) @@ -160,7 +161,7 @@ def _test(): globs = pyspark.sql.functions.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc - globs['sqlCtx'] = SQLContext(sc) + globs['sqlContext'] = SQLContext(sc) globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.functions, globs=globs, diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2720439416682..b3a6a2c6a9229 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -25,6 +25,7 @@ import shutil import tempfile import pickle +import functools import py4j @@ -41,6 +42,7 @@ from pyspark.sql.types import * from pyspark.sql.types import UserDefinedType, _infer_type from pyspark.tests import ReusedPySparkTestCase +from pyspark.sql.functions import UserDefinedFunction class ExamplePointUDT(UserDefinedType): @@ -114,6 +116,35 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name, ignore_errors=True) + def test_udf_with_callable(self): + d = [Row(number=i, squared=i**2) for i in range(10)] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + + class PlusFour: + def __call__(self, col): + if col is not None: + return col + 4 + + call = PlusFour() + pudf = UserDefinedFunction(call, LongType()) + res = data.select(pudf(data['number']).alias('plus_four')) + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) + + def test_udf_with_partial_function(self): + d = [Row(number=i, squared=i**2) for i in range(10)] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + + def some_func(col, param): + if col is not None: + return col + param + + pfunc = functools.partial(some_func, param=4) + pudf = UserDefinedFunction(pfunc, LongType()) + res = data.select(pudf(data['number']).alias('plus_four')) + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) + def test_udf(self): self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() @@ -415,6 +446,102 @@ def test_infer_long_type(self): self.assertEqual(_infer_type(2**61), LongType()) self.assertEqual(_infer_type(2**71), LongType()) + def 
test_dropna(self): + schema = StructType([ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", DoubleType(), True)]) + + # shouldn't drop a non-null row + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, 80.1)], schema).dropna().count(), + 1) + + # dropping rows with a single null value + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna().count(), + 0) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna(how='any').count(), + 0) + + # if how = 'all', only drop rows if all values are null + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna(how='all').count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(None, None, None)], schema).dropna(how='all').count(), + 0) + + # how and subset + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, None)], schema).dropna(how='any', subset=['name', 'age']).count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, None)], schema).dropna(how='any', subset=['name', 'age']).count(), + 0) + + # threshold + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna(thresh=2).count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, None)], schema).dropna(thresh=2).count(), + 0) + + # threshold and subset + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, None)], schema).dropna(thresh=2, subset=['name', 'age']).count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 'age']).count(), + 0) + + # thresh should take precedence over how + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, None)], schema).dropna( + how='any', thresh=2, subset=['name', 'age']).count(), + 1) + + def test_fillna(self): + schema = StructType([ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", DoubleType(), True)]) + + # fillna shouldn't change non-null values + row = self.sqlCtx.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first() + self.assertEqual(row.age, 10) + + # fillna with int + row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first() + self.assertEqual(row.age, 50) + self.assertEqual(row.height, 50.0) + + # fillna with double + row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first() + self.assertEqual(row.age, 50) + self.assertEqual(row.height, 50.1) + + # fillna with string + row = self.sqlCtx.createDataFrame([(None, None, None)], schema).fillna("hello").first() + self.assertEqual(row.name, u"hello") + self.assertEqual(row.age, None) + + # fillna with subset specified for numeric cols + row = self.sqlCtx.createDataFrame( + [(None, None, None)], schema).fillna(50, subset=['name', 'age']).first() + self.assertEqual(row.name, None) + self.assertEqual(row.age, 50) + self.assertEqual(row.height, None) + + # fillna with subset specified for numeric cols + row = self.sqlCtx.createDataFrame( + [(None, None, None)], schema).fillna("haha", subset=['name', 'age']).first() + self.assertEqual(row.name, "haha") + self.assertEqual(row.age, None) + self.assertEqual(row.height, None) + class HiveContextSQLTests(ReusedPySparkTestCase): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 0169028ccc4eb..ef76d84c00481 100644 --- 
a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -33,8 +33,7 @@ class DataType(object): - - """Spark SQL DataType""" + """Base class for data types.""" def __repr__(self): return self.__class__.__name__ @@ -67,7 +66,6 @@ def json(self): # This singleton pattern does not work with pickle, you will get # another object after pickle and unpickle class PrimitiveTypeSingleton(type): - """Metaclass for PrimitiveType""" _instances = {} @@ -79,66 +77,45 @@ def __call__(cls): class PrimitiveType(DataType): - """Spark SQL PrimitiveType""" __metaclass__ = PrimitiveTypeSingleton class NullType(PrimitiveType): + """Null type. - """Spark SQL NullType - - The data type representing None, used for the types which has not - been inferred. + The data type representing None, used for the types that cannot be inferred. """ class StringType(PrimitiveType): - - """Spark SQL StringType - - The data type representing string values. + """String data type. """ class BinaryType(PrimitiveType): - - """Spark SQL BinaryType - - The data type representing bytearray values. + """Binary (byte array) data type. """ class BooleanType(PrimitiveType): - - """Spark SQL BooleanType - - The data type representing bool values. + """Boolean data type. """ class DateType(PrimitiveType): - - """Spark SQL DateType - - The data type representing datetime.date values. + """Date (datetime.date) data type. """ class TimestampType(PrimitiveType): - - """Spark SQL TimestampType - - The data type representing datetime.datetime values. + """Timestamp (datetime.datetime) data type. """ class DecimalType(DataType): - - """Spark SQL DecimalType - - The data type representing decimal.Decimal values. + """Decimal (decimal.Decimal) data type. """ def __init__(self, precision=None, scale=None): @@ -166,80 +143,55 @@ def __repr__(self): class DoubleType(PrimitiveType): - - """Spark SQL DoubleType - - The data type representing float values. + """Double data type, representing double precision floats. """ class FloatType(PrimitiveType): - - """Spark SQL FloatType - - The data type representing single precision floating-point values. + """Float data type, representing single precision floats. """ class ByteType(PrimitiveType): - - """Spark SQL ByteType - - The data type representing int values with 1 singed byte. + """Byte data type, i.e. a signed integer in a single byte. """ def simpleString(self): return 'tinyint' class IntegerType(PrimitiveType): - - """Spark SQL IntegerType - - The data type representing int values. + """Int data type, i.e. a signed 32-bit integer. """ def simpleString(self): return 'int' class LongType(PrimitiveType): + """Long data type, i.e. a signed 64-bit integer. - """Spark SQL LongType - - The data type representing long values. If the any value is - beyond the range of [-9223372036854775808, 9223372036854775807], - please use DecimalType. + If the values are beyond the range of [-9223372036854775808, 9223372036854775807], + please use :class:`DecimalType`. """ def simpleString(self): return 'bigint' class ShortType(PrimitiveType): - - """Spark SQL ShortType - - The data type representing int values with 2 signed bytes. + """Short data type, i.e. a signed 16-bit integer. """ def simpleString(self): return 'smallint' class ArrayType(DataType): + """Array data type. - """Spark SQL ArrayType - - The data type representing list values. An ArrayType object - comprises two fields, elementType (a DataType) and containsNull (a bool). - The field of elementType is used to specify the type of array elements. 
- The field of containsNull is used to specify if the array has None values. - + :param elementType: :class:`DataType` of each element in the array. + :param containsNull: boolean, whether the array can contain null (None) values. """ def __init__(self, elementType, containsNull=True): - """Creates an ArrayType - - :param elementType: the data type of elements. - :param containsNull: indicates whether the list contains None values. - + """ >>> ArrayType(StringType()) == ArrayType(StringType(), True) True >>> ArrayType(StringType(), False) == ArrayType(StringType()) @@ -268,29 +220,17 @@ def fromJson(cls, json): class MapType(DataType): + """Map data type. - """Spark SQL MapType - - The data type representing dict values. A MapType object comprises - three fields, keyType (a DataType), valueType (a DataType) and - valueContainsNull (a bool). - - The field of keyType is used to specify the type of keys in the map. - The field of valueType is used to specify the type of values in the map. - The field of valueContainsNull is used to specify if values of this - map has None values. - - For values of a MapType column, keys are not allowed to have None values. + :param keyType: :class:`DataType` of the keys in the map. + :param valueType: :class:`DataType` of the values in the map. + :param valueContainsNull: indicates whether values can contain null (None) values. + Keys in a map data type are not allowed to be null (None). """ def __init__(self, keyType, valueType, valueContainsNull=True): - """Creates a MapType - :param keyType: the data type of keys. - :param valueType: the data type of values. - :param valueContainsNull: indicates whether values contains - null values. - + """ >>> (MapType(StringType(), IntegerType()) ... == MapType(StringType(), IntegerType(), True)) True @@ -325,30 +265,16 @@ def fromJson(cls, json): class StructField(DataType): + """A field in :class:`StructType`. - """Spark SQL StructField - - Represents a field in a StructType. - A StructField object comprises three fields, name (a string), - dataType (a DataType) and nullable (a bool). The field of name - is the name of a StructField. The field of dataType specifies - the data type of a StructField. - - The field of nullable specifies if values of a StructField can - contain None values. - + :param name: string, name of the field. + :param dataType: :class:`DataType` of the field. + :param nullable: boolean, whether the field can be null (None) or not. + :param metadata: a dict from string to simple type that can be serialized to JSON automatically """ def __init__(self, name, dataType, nullable=True, metadata=None): - """Creates a StructField - :param name: the name of this field. - :param dataType: the data type of this field. - :param nullable: indicates whether values of this field - can be null. - :param metadata: metadata of this field, which is a map from string - to simple type that can be serialized to JSON - automatically - + """ >>> (StructField("f1", StringType(), True) ... == StructField("f1", StringType(), True)) True @@ -384,17 +310,13 @@ def fromJson(cls, json): class StructType(DataType): + """Struct type, consisting of a list of :class:`StructField`. - """Spark SQL StructType - - The data type representing rows. - A StructType object comprises a list of L{StructField}. - + This is the data type representing a :class:`Row`. 
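The types module keeps the same classes; the prose is condensed into `:param:`-style docstrings. A sketch of composing the complex types described above (`ArrayType`, `MapType`, `StructField`, `StructType`) and round-tripping a schema through its JSON form, much as the module's own doctests do:

```
from pyspark.sql.types import (ArrayType, MapType, StructField, StructType,
                               StringType, IntegerType,
                               _parse_datatype_json_string)

schema = StructType([
    StructField("name", StringType(), nullable=False),
    StructField("tags", ArrayType(StringType(), containsNull=True), True),
    StructField("scores", MapType(StringType(), IntegerType(),
                                  valueContainsNull=True), True)])

# containsNull defaults to True, so these two are equal.
assert ArrayType(StringType()) == ArrayType(StringType(), True)

# Schemas serialize to JSON and compare structurally after parsing back.
assert _parse_datatype_json_string(schema.json()) == schema
```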
""" def __init__(self, fields): - """Creates a StructType - + """ >>> struct1 = StructType([StructField("f1", StringType(), True)]) >>> struct2 = StructType([StructField("f1", StringType(), True)]) >>> struct1 == struct2 @@ -425,9 +347,9 @@ def fromJson(cls, json): class UserDefinedType(DataType): - """ + """User-defined type (UDT). + .. note:: WARN: Spark Internal Use Only - SQL User-Defined Type (UDT). """ @classmethod @@ -512,7 +434,7 @@ def _parse_datatype_json_string(json_string): >>> def check_datatype(datatype): ... pickled = pickle.loads(pickle.dumps(datatype)) ... assert datatype == pickled - ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) + ... scala_datatype = sqlContext._ssql_ctx.parseDataType(datatype.json()) ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) ... assert datatype == python_datatype >>> for cls in _all_primitive_types.values(): @@ -645,8 +567,8 @@ def _infer_schema(row): elif isinstance(row, (tuple, list)): if hasattr(row, "_fields"): # namedtuple items = zip(row._fields, tuple(row)) - elif hasattr(row, "__FIELDS__"): # Row - items = zip(row.__FIELDS__, tuple(row)) + elif hasattr(row, "__fields__"): # Row + items = zip(row.__fields__, tuple(row)) else: names = ['_%d' % i for i in range(1, len(row) + 1)] items = zip(names, row) @@ -725,7 +647,7 @@ def converter(obj): if isinstance(obj, dict): return tuple(c(obj.get(n)) for n, c in zip(names, converters)) elif isinstance(obj, tuple): - if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"): + if hasattr(obj, "_fields") or hasattr(obj, "__fields__"): return tuple(c(v) for c, v in zip(converters, obj)) elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs d = dict(obj) @@ -1075,12 +997,13 @@ def _restore_object(dataType, obj): # same object in most cases. 
k = id(dataType) cls = _cached_cls.get(k) - if cls is None: + if cls is None or cls.__datatype is not dataType: # use dataType as key to avoid create multiple class cls = _cached_cls.get(dataType) if cls is None: cls = _create_cls(dataType) _cached_cls[dataType] = cls + cls.__datatype = dataType _cached_cls[k] = cls return cls(obj) @@ -1197,8 +1120,8 @@ def Dict(d): class Row(tuple): """ Row in DataFrame """ - __DATATYPE__ = dataType - __FIELDS__ = tuple(f.name for f in dataType.fields) + __datatype = dataType + __fields__ = tuple(f.name for f in dataType.fields) __slots__ = () # create property for fast access @@ -1206,22 +1129,22 @@ class Row(tuple): def asDict(self): """ Return as a dict """ - return dict((n, getattr(self, n)) for n in self.__FIELDS__) + return dict((n, getattr(self, n)) for n in self.__fields__) def __repr__(self): # call collect __repr__ for nested objects return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) - for n in self.__FIELDS__)) + for n in self.__fields__)) def __reduce__(self): - return (_restore_object, (self.__DATATYPE__, tuple(self))) + return (_restore_object, (self.__datatype, tuple(self))) return Row def _create_row(fields, values): row = Row(*values) - row.__FIELDS__ = fields + row.__fields__ = fields return row @@ -1261,7 +1184,7 @@ def __new__(self, *args, **kwargs): # create row objects names = sorted(kwargs.keys()) row = tuple.__new__(self, [kwargs[n] for n in names]) - row.__FIELDS__ = names + row.__fields__ = names return row else: @@ -1271,11 +1194,11 @@ def asDict(self): """ Return as an dict """ - if not hasattr(self, "__FIELDS__"): + if not hasattr(self, "__fields__"): raise TypeError("Cannot convert a Row class into dict") - return dict(zip(self.__FIELDS__, self)) + return dict(zip(self.__fields__, self)) - # let obect acs like class + # let object acts like class def __call__(self, *args): """create new Row object""" return _create_row(self, args) @@ -1286,21 +1209,21 @@ def __getattr__(self, item): try: # it will be slow when it has many fields, # but this will not be used in normal cases - idx = self.__FIELDS__.index(item) + idx = self.__fields__.index(item) return self[idx] except IndexError: raise AttributeError(item) def __reduce__(self): - if hasattr(self, "__FIELDS__"): - return (_create_row, (self.__FIELDS__, tuple(self))) + if hasattr(self, "__fields__"): + return (_create_row, (self.__fields__, tuple(self))) else: return tuple.__reduce__(self) def __repr__(self): - if hasattr(self, "__FIELDS__"): + if hasattr(self, "__fields__"): return "Row(%s)" % ", ".join("%s=%r" % (k, v) - for k, v in zip(self.__FIELDS__, self)) + for k, v in zip(self.__fields__, tuple(self))) else: return "" % ", ".join(self) @@ -1315,7 +1238,7 @@ def _test(): globs = pyspark.sql.types.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc - globs['sqlCtx'] = sqlCtx = SQLContext(sc) + globs['sqlContext'] = SQLContext(sc) globs['ExamplePoint'] = ExamplePoint globs['ExamplePointUDT'] = ExamplePointUDT (failure_count, test_count) = doctest.testmod( diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 608f8e26473a6..9b4635e49020b 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -23,13 +23,16 @@ import tempfile import struct +from py4j.java_collections import MapConverter + from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext +from pyspark.streaming.kafka import KafkaUtils class 
PySparkStreamingTestCase(unittest.TestCase): - timeout = 10 # seconds + timeout = 20 # seconds duration = 1 def setUp(self): @@ -556,5 +559,43 @@ def check_output(n): check_output(3) +class KafkaStreamTests(PySparkStreamingTestCase): + + def setUp(self): + super(KafkaStreamTests, self).setUp() + + kafkaTestUtilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ + .loadClass("org.apache.spark.streaming.kafka.KafkaTestUtils") + self._kafkaTestUtils = kafkaTestUtilsClz.newInstance() + self._kafkaTestUtils.setup() + + def tearDown(self): + if self._kafkaTestUtils is not None: + self._kafkaTestUtils.teardown() + self._kafkaTestUtils = None + + super(KafkaStreamTests, self).tearDown() + + def test_kafka_stream(self): + """Test the Python Kafka stream API.""" + topic = "topic1" + sendData = {"a": 3, "b": 5, "c": 10} + jSendData = MapConverter().convert(sendData, + self.ssc.sparkContext._gateway._gateway_client) + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, jSendData) + + stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(), + "test-streaming-consumer", {topic: 1}, + {"auto.offset.reset": "smallest"}) + + result = {} + for i in chain.from_iterable(self._collect(stream.map(lambda x: x[1]), + sum(sendData.values()))): + result[i] = result.get(i, 0) + 1 + + self.assertEqual(sendData, result) + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index dd8d3b1c53733..ee67e80d539f8 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -31,9 +31,12 @@ import time import zipfile import random +import itertools import threading import hashlib +from py4j.protocol import Py4JJavaError + if sys.version_info[:2] <= (2, 6): try: import unittest2 as unittest @@ -76,7 +79,7 @@ class MergerTests(unittest.TestCase): def setUp(self): - self.N = 1 << 14 + self.N = 1 << 12 self.l = [i for i in xrange(self.N)] self.data = zip(self.l, self.l) self.agg = Aggregator(lambda x: [x], @@ -108,7 +111,7 @@ def test_small_dataset(self): sum(xrange(self.N))) def test_medium_dataset(self): - m = ExternalMerger(self.agg, 10) + m = ExternalMerger(self.agg, 30) m.mergeValues(self.data) self.assertTrue(m.spills >= 1) self.assertEqual(sum(sum(v) for k, v in m.iteritems()), @@ -124,10 +127,36 @@ def test_huge_dataset(self): m = ExternalMerger(self.agg, 10, partitions=3) m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10)) self.assertTrue(m.spills >= 1) - self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)), + self.assertEqual(sum(len(v) for k, v in m.iteritems()), self.N * 10) m._cleanup() + def test_group_by_key(self): + + def gen_data(N, step): + for i in range(1, N + 1, step): + for j in range(i): + yield (i, [j]) + + def gen_gs(N, step=1): + return shuffle.GroupByKey(gen_data(N, step)) + + self.assertEqual(1, len(list(gen_gs(1)))) + self.assertEqual(2, len(list(gen_gs(2)))) + self.assertEqual(100, len(list(gen_gs(100)))) + self.assertEqual(range(1, 101), [k for k, _ in gen_gs(100)]) + self.assertTrue(all(range(k) == list(vs) for k, vs in gen_gs(100))) + + for k, vs in gen_gs(50002, 10000): + self.assertEqual(k, len(vs)) + self.assertEqual(range(k), list(vs)) + + ser = PickleSerializer() + l = ser.loads(ser.dumps(list(gen_gs(50002, 30000)))) + for k, vs in l: + self.assertEqual(k, len(vs)) + self.assertEqual(range(k), list(vs)) + class SorterTests(unittest.TestCase): def test_in_memory_sort(self): @@ -521,10 +550,8 @@ def 
test_large_closure(self): data = [float(i) for i in xrange(N)] rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data)) self.assertEquals(N, rdd.first()) - self.assertTrue(rdd._broadcast is not None) - rdd = self.sc.parallelize(range(1), 1).map(lambda x: 1) - self.assertEqual(1, rdd.first()) - self.assertTrue(rdd._broadcast is None) + # regression test for SPARK-6886 + self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count()) def test_zip_with_different_serializers(self): a = self.sc.parallelize(range(5)) @@ -702,6 +729,21 @@ def test_distinct(self): self.assertEquals(result.getNumPartitions(), 5) self.assertEquals(result.count(), 3) + def test_external_group_by_key(self): + self.sc._conf.set("spark.python.worker.memory", "5m") + N = 200001 + kv = self.sc.parallelize(range(N)).map(lambda x: (x % 3, x)) + gkv = kv.groupByKey().cache() + self.assertEqual(3, gkv.count()) + filtered = gkv.filter(lambda (k, vs): k == 1) + self.assertEqual(1, filtered.count()) + self.assertEqual([(1, N/3)], filtered.mapValues(len).collect()) + self.assertEqual([(N/3, N/3)], + filtered.values().map(lambda x: (len(x), len(list(x)))).collect()) + result = filtered.collect()[0][1] + self.assertEqual(N/3, len(result)) + self.assertTrue(isinstance(result.data, shuffle.ExternalList)) + def test_sort_on_empty_rdd(self): self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect()) @@ -752,9 +794,9 @@ def test_narrow_dependency_in_join(self): self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions()) self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions()) - self.sc.setJobGroup("test1", "test", True) tracker = self.sc.statusTracker() + self.sc.setJobGroup("test1", "test", True) d = sorted(parted.join(parted).collect()) self.assertEqual(10, len(d)) self.assertEqual((0, (0, 0)), d[0]) @@ -787,6 +829,17 @@ def test_take_on_jrdd(self): rdd = self.sc.parallelize(range(1 << 20)).map(lambda x: str(x)) rdd._jrdd.first() + def test_sortByKey_uses_all_partitions_not_only_first_and_last(self): + # Regression test for SPARK-5969 + seq = [(i * 59 % 101, i) for i in range(101)] # unsorted sequence + rdd = self.sc.parallelize(seq) + for ascending in [True, False]: + sort = rdd.sortByKey(ascending=ascending, numPartitions=5) + self.assertEqual(sort.collect(), sorted(seq, reverse=not ascending)) + sizes = sort.glom().map(len).collect() + for size in sizes: + self.assertGreater(size, 0) + class ProfilerTests(PySparkTestCase): @@ -1441,6 +1494,20 @@ def count(): self.assertTrue(not t.isAlive()) self.assertEqual(100000, rdd.count()) + def test_with_different_versions_of_python(self): + rdd = self.sc.parallelize(range(10)) + rdd.count() + version = sys.version_info + sys.version_info = (2, 0, 0) + log4j = self.sc._jvm.org.apache.log4j + old_level = log4j.LogManager.getRootLogger().getLevel() + log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL) + try: + self.assertRaises(Py4JJavaError, lambda: rdd.count()) + finally: + sys.version_info = version + log4j.LogManager.getRootLogger().setLevel(old_level) + class SparkSubmitTests(unittest.TestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 8a93c320ec5d3..452d6fabdcc17 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -88,7 +88,11 @@ def main(infile, outfile): command = pickleSer._read_with_length(infile) if isinstance(command, Broadcast): command = pickleSer.loads(command.value) - (func, profiler, deserializer, serializer) = command + (func, profiler, 
deserializer, serializer), version = command + if version != sys.version_info[:2]: + raise Exception(("Python in worker has different version %s than that in " + + "driver %s, PySpark cannot run with different minor versions") % + (sys.version_info[:2], version)) init_time = time.time() def process(): diff --git a/python/run-tests b/python/run-tests index b7630c356cfae..f3a07d8aba562 100755 --- a/python/run-tests +++ b/python/run-tests @@ -21,6 +21,8 @@ # Figure out where the Spark framework is installed FWDIR="$(cd "`dirname "$0"`"; cd ../; pwd)" +. "$FWDIR"/bin/load-spark-env.sh + # CD into the python directory to find things on the right path cd "$FWDIR/python" @@ -57,7 +59,7 @@ function run_core_tests() { PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py" PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py" run_test "pyspark/serializers.py" - run_test "pyspark/profiler.py" + run_test "pyspark/profiler.py" run_test "pyspark/shuffle.py" run_test "pyspark/tests.py" } @@ -77,6 +79,7 @@ function run_mllib_tests() { run_test "pyspark/mllib/clustering.py" run_test "pyspark/mllib/evaluation.py" run_test "pyspark/mllib/feature.py" + run_test "pyspark/mllib/fpm.py" run_test "pyspark/mllib/linalg.py" run_test "pyspark/mllib/rand.py" run_test "pyspark/mllib/recommendation.py" @@ -96,6 +99,21 @@ function run_ml_tests() { function run_streaming_tests() { echo "Run streaming tests ..." + + KAFKA_ASSEMBLY_DIR="$FWDIR"/external/kafka-assembly + JAR_PATH="${KAFKA_ASSEMBLY_DIR}/target/scala-${SPARK_SCALA_VERSION}" + for f in "${JAR_PATH}"/spark-streaming-kafka-assembly-*.jar; do + if [[ ! -e "$f" ]]; then + echo "Failed to find Spark Streaming Kafka assembly jar in $KAFKA_ASSEMBLY_DIR" 1>&2 + echo "You need to build Spark with " \ + "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or" \ + "'build/mvn package' before running this program" 1>&2 + exit 1 + fi + KAFKA_ASSEMBLY_JAR="$f" + done + + export PYSPARK_SUBMIT_ARGS="--jars ${KAFKA_ASSEMBLY_JAR} pyspark-shell" run_test "pyspark/streaming/util.py" run_test "pyspark/streaming/tests.py" } diff --git a/repl/src/test/resources/log4j.properties b/repl/src/test/resources/log4j.properties index e7e4a4113174a..e2ee9c963a4da 100644 --- a/repl/src/test/resources/log4j.properties +++ b/repl/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index 92e76a3fe6ca2..d8e0facb81169 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -29,7 +29,7 @@ # SPARK_NICENESS The scheduling priority for daemons. Defaults to 0. ## -usage="Usage: spark-daemon.sh [--config ] (start|stop) " +usage="Usage: spark-daemon.sh [--config ] (start|stop|status) " # if no args specified, show usage if [ $# -le 1 ]; then @@ -195,6 +195,23 @@ case $option in fi ;; + (status) + + if [ -f $pid ]; then + TARGET_ID="$(cat "$pid")" + if [[ $(ps -p "$TARGET_ID" -o comm=) =~ "java" ]]; then + echo $command is running. + exit 0 + else + echo $pid file is present but $command not running + exit 1 + fi + else + echo $command not running. 
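The `worker.py` hunk above adds a guard that the driver and the worker run the same Python minor version. A standalone sketch of that check, with `version` standing in for the tuple the driver now ships alongside the serialized command:

```python
import sys

# Hypothetical value read off the wire; in worker.py it arrives together with
# (func, profiler, deserializer, serializer).
version = (2, 7)

if version != sys.version_info[:2]:
    raise Exception(("Python in worker has different version %s than that in " +
                     "driver %s, PySpark cannot run with different minor versions")
                    % (sys.version_info[:2], version))
```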
+ exit 2 + fi + ;; + (*) echo $usage exit 1 diff --git a/sbin/start-slave.sh b/sbin/start-slave.sh index 2fc35309f4ca5..4c919ff76a8f5 100755 --- a/sbin/start-slave.sh +++ b/sbin/start-slave.sh @@ -17,10 +17,69 @@ # limitations under the License. # -# Usage: start-slave.sh -# where is like "spark://localhost:7077" +# Starts a slave on the machine this script is executed on. +# +# Environment Variables +# +# SPARK_WORKER_INSTANCES The number of worker instances to run on this +# slave. Default is 1. +# SPARK_WORKER_PORT The base port number for the first worker. If set, +# subsequent workers will increment this number. If +# unset, Spark will find a valid port number, but +# with no guarantee of a predictable pattern. +# SPARK_WORKER_WEBUI_PORT The base port for the web interface of the first +# worker. Subsequent workers will increment this +# number. Default is 8081. + +usage="Usage: start-slave.sh where is like spark://localhost:7077" + +if [ $# -lt 1 ]; then + echo $usage + echo Called as start-slave.sh $* + exit 1 +fi sbin="`dirname "$0"`" sbin="`cd "$sbin"; pwd`" -"$sbin"/spark-daemon.sh start org.apache.spark.deploy.worker.Worker "$@" +. "$sbin/spark-config.sh" + +. "$SPARK_PREFIX/bin/load-spark-env.sh" + +# First argument should be the master; we need to store it aside because we may +# need to insert arguments between it and the other arguments +MASTER=$1 +shift + +# Determine desired worker port +if [ "$SPARK_WORKER_WEBUI_PORT" = "" ]; then + SPARK_WORKER_WEBUI_PORT=8081 +fi + +# Start up the appropriate number of workers on this machine. +# quick local function to start a worker +function start_instance { + WORKER_NUM=$1 + shift + + if [ "$SPARK_WORKER_PORT" = "" ]; then + PORT_FLAG= + PORT_NUM= + else + PORT_FLAG="--port" + PORT_NUM=$(( $SPARK_WORKER_PORT + $WORKER_NUM - 1 )) + fi + WEBUI_PORT=$(( $SPARK_WORKER_WEBUI_PORT + $WORKER_NUM - 1 )) + + "$sbin"/spark-daemon.sh start org.apache.spark.deploy.worker.Worker $WORKER_NUM \ + --webui-port "$WEBUI_PORT" $PORT_FLAG $PORT_NUM $MASTER "$@" +} + +if [ "$SPARK_WORKER_INSTANCES" = "" ]; then + start_instance 1 "$@" +else + for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do + start_instance $(( 1 + $i )) "$@" + done +fi + diff --git a/sbin/start-slaves.sh b/sbin/start-slaves.sh index 76316a3067c93..24d6268815ed3 100755 --- a/sbin/start-slaves.sh +++ b/sbin/start-slaves.sh @@ -17,6 +17,8 @@ # limitations under the License. # +# Starts a slave instance on each machine specified in the conf/slaves file. + sbin="`dirname "$0"`" sbin="`cd "$sbin"; pwd`" @@ -57,13 +59,4 @@ if [ "$START_TACHYON" == "true" ]; then fi # Launch the slaves -if [ "$SPARK_WORKER_INSTANCES" = "" ]; then - exec "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" 1 "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" -else - if [ "$SPARK_WORKER_WEBUI_PORT" = "" ]; then - SPARK_WORKER_WEBUI_PORT=8081 - fi - for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do - "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" $(( $i + 1 )) --webui-port $(( $SPARK_WORKER_WEBUI_PORT + $i )) "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" - done -fi +"$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" diff --git a/sbin/stop-slave.sh b/sbin/stop-slave.sh new file mode 100755 index 0000000000000..3d1da5b254f2a --- /dev/null +++ b/sbin/stop-slave.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# A shell script to stop all workers on a single slave +# +# Environment variables +# +# SPARK_WORKER_INSTANCES The number of worker instances that should be +# running on this slave. Default is 1. + +# Usage: stop-slave.sh +# Stops all slaves on this worker machine + +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +. "$sbin/spark-config.sh" + +. "$SPARK_PREFIX/bin/load-spark-env.sh" + +if [ "$SPARK_WORKER_INSTANCES" = "" ]; then + "$sbin"/spark-daemon.sh stop org.apache.spark.deploy.worker.Worker 1 +else + for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do + "$sbin"/spark-daemon.sh stop org.apache.spark.deploy.worker.Worker $(( $i + 1 )) + done +fi diff --git a/sbin/stop-slaves.sh b/sbin/stop-slaves.sh index 7c2201100ef97..54c9bd46803a9 100755 --- a/sbin/stop-slaves.sh +++ b/sbin/stop-slaves.sh @@ -17,8 +17,8 @@ # limitations under the License. # -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" . "$sbin/spark-config.sh" @@ -29,10 +29,4 @@ if [ -e "$sbin"/../tachyon/bin/tachyon ]; then "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin"/../tachyon/bin/tachyon killAll tachyon.worker.Worker fi -if [ "$SPARK_WORKER_INSTANCES" = "" ]; then - "$sbin"/spark-daemons.sh stop org.apache.spark.deploy.worker.Worker 1 -else - for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do - "$sbin"/spark-daemons.sh stop org.apache.spark.deploy.worker.Worker $(( $i + 1 )) - done -fi +"$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin"/stop-slave.sh diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 459a5035d4984..7168d5b2a8e26 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -137,7 +137,7 @@ - + diff --git a/sql/README.md b/sql/README.md index fbb3200a3a4b4..237620e3fa808 100644 --- a/sql/README.md +++ b/sql/README.md @@ -56,6 +56,6 @@ res2: Array[org.apache.spark.sql.Row] = Array([238,val_238], [86,val_86], [311,v You can also build further queries on top of these `DataFrames` using the query DSL. 
``` -scala> query.where('key > 30).select(avg('key)).collect() +scala> query.where(query("key") > 30).select(avg(query("key"))).collect() res3: Array[org.apache.spark.sql.Row] = Array([274.79025423728814]) ``` diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index 34fedead44db3..f9992185a4563 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -30,7 +30,7 @@ class AnalysisException protected[sql] ( val startPosition: Option[Int] = None) extends Exception with Serializable { - def withPosition(line: Option[Int], startPosition: Option[Int]) = { + def withPosition(line: Option[Int], startPosition: Option[Int]): AnalysisException = { val newException = new AnalysisException(message, line, startPosition) newException.setStackTrace(getStackTrace) newException diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index d794f034f5578..ac8a782976465 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.util.hashing.MurmurHash3 import org.apache.spark.sql.catalyst.expressions.GenericRow -import org.apache.spark.sql.types.{StructType, DateUtils} +import org.apache.spark.sql.types.StructType object Row { /** @@ -257,6 +257,7 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ + // TODO(davies): This is not the right default implementation, we use Int as Date internally def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date] /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala new file mode 100644 index 0000000000000..d4f9fdacda4fb --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -0,0 +1,332 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import java.util.{Map => JavaMap} + +import scala.collection.mutable.HashMap + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +/** + * Functions to convert Scala types to Catalyst types and vice versa. + */ +object CatalystTypeConverters { + // The Predef.Map is scala.collection.immutable.Map. + // Since the map values can be mutable, we explicitly import scala.collection.Map at here. 
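The `sql/README.md` hunk above switches the example DSL from Scala symbols to explicit column references. The same query written against the Python DataFrame API, as a hedged sketch assuming the `sqlContext` created by the PySpark shell and a hypothetical table in place of Hive's `src`:

```python
from pyspark.sql.functions import avg

# Hypothetical stand-in for the README's Hive 'src' table.
df = sqlContext.createDataFrame([(k, "val_%d" % k) for k in range(100)],
                                ["key", "value"])
df.where(df["key"] > 30).select(avg(df["key"])).collect()
# -> a single row whose value is 65.0 (the mean of keys 31..99)
```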
+ import scala.collection.Map + + /** + * Converts Scala objects to catalyst rows / types. This method is slow, and for batch + * conversion you should be using converter produced by createToCatalystConverter. + * Note: This is always called after schemaFor has been called. + * This ordering is important for UDT registration. + */ + def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { + // Check UDT first since UDTs can override other types + case (obj, udt: UserDefinedType[_]) => + udt.serialize(obj) + + case (o: Option[_], _) => + o.map(convertToCatalyst(_, dataType)).orNull + + case (s: Seq[_], arrayType: ArrayType) => + s.map(convertToCatalyst(_, arrayType.elementType)) + + case (s: Array[_], arrayType: ArrayType) => + s.toSeq.map(convertToCatalyst(_, arrayType.elementType)) + + case (m: Map[_, _], mapType: MapType) => + m.map { case (k, v) => + convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) + } + + case (jmap: JavaMap[_, _], mapType: MapType) => + val iter = jmap.entrySet.iterator + var listOfEntries: List[(Any, Any)] = List() + while (iter.hasNext) { + val entry = iter.next() + listOfEntries :+= (convertToCatalyst(entry.getKey, mapType.keyType), + convertToCatalyst(entry.getValue, mapType.valueType)) + } + listOfEntries.toMap + + case (p: Product, structType: StructType) => + val ar = new Array[Any](structType.size) + val iter = p.productIterator + var idx = 0 + while (idx < structType.size) { + ar(idx) = convertToCatalyst(iter.next(), structType.fields(idx).dataType) + idx += 1 + } + new GenericRowWithSchema(ar, structType) + + case (d: String, _) => + UTF8String(d) + + case (d: BigDecimal, _) => + Decimal(d) + + case (d: java.math.BigDecimal, _) => + Decimal(d) + + case (d: java.sql.Date, _) => + DateUtils.fromJavaDate(d) + + case (r: Row, structType: StructType) => + val converters = structType.fields.map { + f => (item: Any) => convertToCatalyst(item, f.dataType) + } + convertRowWithConverters(r, structType, converters) + + case (other, _) => + other + } + + /** + * Creates a converter function that will convert Scala objects to the specified catalyst type. + * Typical use case would be converting a collection of rows that have the same schema. You will + * call this function once to get a converter, and apply it to every row. 
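`createToCatalystConverter` trades the per-value type dispatch of `convertToCatalyst` for a closure built once per data type and reused for every row. The pattern itself is language-independent; a small hedged sketch of the same idea in Python (the string type tags are stand-ins, not the Catalyst API):

```python
from decimal import Decimal

def make_converter(field_type):
    # Stand-in type tags; the real code dispatches on Catalyst DataTypes.
    if field_type == "string":
        return lambda v: None if v is None else v.encode("utf-8")
    if field_type == "decimal":
        return lambda v: None if v is None else Decimal(str(v))
    return lambda v: v                      # pass-through for everything else

def make_row_converter(field_types):
    # Build the per-field converters once, then apply them to every row.
    converters = [make_converter(t) for t in field_types]
    return lambda row: tuple(c(v) for c, v in zip(converters, row))

convert = make_row_converter(["string", "decimal", "int"])
print(convert(("alice", 1.5, 7)))
# -> ('alice', Decimal('1.5'), 7)   (the string shows as bytes on Python 3)
```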
+ */ + private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = { + def extractOption(item: Any): Any = item match { + case opt: Option[_] => opt.orNull + case other => other + } + + dataType match { + // Check UDT first since UDTs can override other types + case udt: UserDefinedType[_] => + (item) => extractOption(item) match { + case null => null + case other => udt.serialize(other) + } + + case arrayType: ArrayType => + val elementConverter = createToCatalystConverter(arrayType.elementType) + (item: Any) => { + extractOption(item) match { + case a: Array[_] => a.toSeq.map(elementConverter) + case s: Seq[_] => s.map(elementConverter) + case null => null + } + } + + case mapType: MapType => + val keyConverter = createToCatalystConverter(mapType.keyType) + val valueConverter = createToCatalystConverter(mapType.valueType) + (item: Any) => { + extractOption(item) match { + case m: Map[_, _] => + m.map { case (k, v) => + keyConverter(k) -> valueConverter(v) + } + + case jmap: JavaMap[_, _] => + val iter = jmap.entrySet.iterator + val convertedMap: HashMap[Any, Any] = HashMap() + while (iter.hasNext) { + val entry = iter.next() + convertedMap(keyConverter(entry.getKey)) = valueConverter(entry.getValue) + } + convertedMap + + case null => null + } + } + + case structType: StructType => + val converters = structType.fields.map(f => createToCatalystConverter(f.dataType)) + (item: Any) => { + extractOption(item) match { + case r: Row => + convertRowWithConverters(r, structType, converters) + + case p: Product => + val ar = new Array[Any](structType.size) + val iter = p.productIterator + var idx = 0 + while (idx < structType.size) { + ar(idx) = converters(idx)(iter.next()) + idx += 1 + } + new GenericRowWithSchema(ar, structType) + + case null => + null + } + } + + case dateType: DateType => (item: Any) => extractOption(item) match { + case d: java.sql.Date => DateUtils.fromJavaDate(d) + case other => other + } + + case dataType: StringType => (item: Any) => extractOption(item) match { + case s: String => UTF8String(s) + case other => other + } + + case _ => + (item: Any) => extractOption(item) match { + case d: BigDecimal => Decimal(d) + case d: java.math.BigDecimal => Decimal(d) + case other => other + } + } + } + + /** + * Converts Scala objects to catalyst rows / types. + * + * Note: This should be called before do evaluation on Row + * (It does not support UDT) + * This is used to create an RDD or test results with correct types for Catalyst. + */ + def convertToCatalyst(a: Any): Any = a match { + case s: String => UTF8String(s) + case d: java.sql.Date => DateUtils.fromJavaDate(d) + case d: BigDecimal => Decimal(d) + case d: java.math.BigDecimal => Decimal(d) + case seq: Seq[Any] => seq.map(convertToCatalyst) + case r: Row => Row(r.toSeq.map(convertToCatalyst): _*) + case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray + case m: Map[Any, Any] => + m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap + case other => other + } + + /** + * Converts Catalyst types used internally in rows to standard Scala types + * This method is slow, and for batch conversion you should be using converter + * produced by createToScalaConverter. 
+ */ + def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { + // Check UDT first since UDTs can override other types + case (d, udt: UserDefinedType[_]) => + udt.deserialize(d) + + case (s: Seq[_], arrayType: ArrayType) => + s.map(convertToScala(_, arrayType.elementType)) + + case (m: Map[_, _], mapType: MapType) => + m.map { case (k, v) => + convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) + } + + case (r: Row, s: StructType) => + convertRowToScala(r, s) + + case (d: Decimal, _: DecimalType) => + d.toJavaBigDecimal + + case (i: Int, DateType) => + DateUtils.toJavaDate(i) + + case (s: UTF8String, StringType) => + s.toString() + + case (other, _) => + other + } + + /** + * Creates a converter function that will convert Catalyst types to Scala type. + * Typical use case would be converting a collection of rows that have the same schema. You will + * call this function once to get a converter, and apply it to every row. + */ + private[sql] def createToScalaConverter(dataType: DataType): Any => Any = dataType match { + // Check UDT first since UDTs can override other types + case udt: UserDefinedType[_] => + (item: Any) => if (item == null) null else udt.deserialize(item) + + case arrayType: ArrayType => + val elementConverter = createToScalaConverter(arrayType.elementType) + (item: Any) => if (item == null) null else item.asInstanceOf[Seq[_]].map(elementConverter) + + case mapType: MapType => + val keyConverter = createToScalaConverter(mapType.keyType) + val valueConverter = createToScalaConverter(mapType.valueType) + (item: Any) => if (item == null) { + null + } else { + item.asInstanceOf[Map[_, _]].map { case (k, v) => + keyConverter(k) -> valueConverter(v) + } + } + + case s: StructType => + val converters = s.fields.map(f => createToScalaConverter(f.dataType)) + (item: Any) => { + if (item == null) { + null + } else { + convertRowWithConverters(item.asInstanceOf[Row], s, converters) + } + } + + case _: DecimalType => + (item: Any) => item match { + case d: Decimal => d.toJavaBigDecimal + case other => other + } + + case DateType => + (item: Any) => item match { + case i: Int => DateUtils.toJavaDate(i) + case other => other + } + + case StringType => + (item: Any) => item match { + case s: UTF8String => s.toString() + case other => other + } + + case other => + (item: Any) => item + } + + def convertRowToScala(r: Row, schema: StructType): Row = { + val ar = new Array[Any](r.size) + var idx = 0 + while (idx < r.size) { + ar(idx) = convertToScala(r(idx), schema.fields(idx).dataType) + idx += 1 + } + new GenericRowWithSchema(ar, schema) + } + + /** + * Converts a row by applying the provided set of converter functions. It is used for both + * toScala and toCatalyst conversions. 
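`convertToScala` and `createToScalaConverter` undo the internal encodings: `UTF8String` back to `String`, `Decimal` back to `java.math.BigDecimal`, and the `Int` that `DateType` is stored as back to `java.sql.Date`. A quick hedged check of the days-since-epoch encoding, ignoring the time-zone handling `DateUtils` performs:

```python
import datetime

epoch = datetime.date(1970, 1, 1)

def to_days(d):        # rough analogue of DateUtils.fromJavaDate
    return (d - epoch).days

def from_days(n):      # rough analogue of DateUtils.toJavaDate
    return epoch + datetime.timedelta(days=n)

d = datetime.date(2015, 4, 1)
assert from_days(to_days(d)) == d
```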
+ */ + private[sql] def convertRowWithConverters( + row: Row, + schema: StructType, + converters: Array[Any => Any]): Row = { + val ar = new Array[Any](row.size) + var idx = 0 + while (idx < row.size) { + ar(idx) = converters(idx)(row(idx)) + idx += 1 + } + new GenericRowWithSchema(ar, schema) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d6126c24fc50d..d9521953cad73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -46,56 +46,6 @@ trait ScalaReflection { case class Schema(dataType: DataType, nullable: Boolean) - /** - * Converts Scala objects to catalyst rows / types. - * Note: This is always called after schemaFor has been called. - * This ordering is important for UDT registration. - */ - def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { - // Check UDT first since UDTs can override other types - case (obj, udt: UserDefinedType[_]) => udt.serialize(obj) - case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull - case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType)) - case (s: Array[_], arrayType: ArrayType) => if (arrayType.elementType.isPrimitive) { - s.toSeq - } else { - s.toSeq.map(convertToCatalyst(_, arrayType.elementType)) - } - case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => - convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) - } - case (p: Product, structType: StructType) => - new GenericRow( - p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) => - convertToCatalyst(elem, field.dataType) - }.toArray) - case (d: BigDecimal, _) => Decimal(d) - case (d: java.math.BigDecimal, _) => Decimal(d) - case (d: java.sql.Date, _) => DateUtils.fromJavaDate(d) - case (other, _) => other - } - - /** Converts Catalyst types used internally in rows to standard Scala types */ - def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { - // Check UDT first since UDTs can override other types - case (d, udt: UserDefinedType[_]) => udt.deserialize(d) - case (s: Seq[_], arrayType: ArrayType) => s.map(convertToScala(_, arrayType.elementType)) - case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => - convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) - } - case (r: Row, s: StructType) => convertRowToScala(r, s) - case (d: Decimal, _: DecimalType) => d.toJavaBigDecimal - case (i: Int, DateType) => DateUtils.toJavaDate(i) - case (other, _) => other - } - - def convertRowToScala(r: Row, schema: StructType): Row = { - // TODO: This is very slow!!! - new GenericRowWithSchema( - r.toSeq.zip(schema.fields.map(_.dataType)) - .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray, schema) - } - /** Returns a Sequence of attributes for the given case class type. 
*/ def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { case Schema(s: StructType, _) => @@ -179,6 +129,8 @@ trait ScalaReflection { case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + case other => + throw new UnsupportedOperationException(s"Schema for type $other is not supported") } } @@ -186,6 +138,7 @@ trait ScalaReflection { // The data type can be determined without ambiguity. case obj: BooleanType.JvmType => BooleanType case obj: BinaryType.JvmType => BinaryType + case obj: String => StringType case obj: StringType.JvmType => StringType case obj: ByteType.JvmType => ByteType case obj: ShortType.JvmType => ShortType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index ea7d44a3723d1..0af969cc5cc67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -111,6 +111,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected val UPPER = Keyword("UPPER") protected val WHEN = Keyword("WHEN") protected val WHERE = Keyword("WHERE") + protected val WITH = Keyword("WITH") protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = { exprs.zipWithIndex.map { @@ -120,13 +121,14 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } protected lazy val start: Parser[LogicalPlan] = - ( (select | ("(" ~> select <~ ")")) * - ( UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } - | INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) } - | EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} - | UNION ~ DISTINCT.? ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } - ) - | insert + start1 | insert | cte + + protected lazy val start1: Parser[LogicalPlan] = + (select | ("(" ~> select <~ ")")) * + ( UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } + | INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) } + | EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} + | UNION ~ DISTINCT.? ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } ) protected lazy val select: Parser[LogicalPlan] = @@ -139,7 +141,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { sortType.? ~ (LIMIT ~> expression).? 
^^ { case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => - val base = r.getOrElse(NoRelation) + val base = r.getOrElse(OneRowRelation) val withFilter = f.map(Filter(_, base)).getOrElse(base) val withProjection = g .map(Aggregate(_, assignAliases(p), withFilter)) @@ -153,7 +155,12 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val insert: Parser[LogicalPlan] = INSERT ~> (OVERWRITE ^^^ true | INTO ^^^ false) ~ (TABLE ~> relation) ~ select ^^ { - case o ~ r ~ s => InsertIntoTable(r, Map.empty[String, Option[String]], s, o) + case o ~ r ~ s => InsertIntoTable(r, Map.empty[String, Option[String]], s, o, false) + } + + protected lazy val cte: Parser[LogicalPlan] = + WITH ~> rep1sep(ident ~ ( AS ~ "(" ~> start1 <~ ")"), ",") ~ (start1 | insert) ^^ { + case r ~ s => With(s, r.map({case n ~ s => (n, Subquery(n, s))}).toMap) } protected lazy val projection: Parser[Expression] = @@ -316,13 +323,13 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val literal: Parser[Literal] = ( numericLiteral | booleanLiteral - | stringLit ^^ {case s => Literal(s, StringType) } - | NULL ^^^ Literal(null, NullType) + | stringLit ^^ {case s => Literal.create(s, StringType) } + | NULL ^^^ Literal.create(null, NullType) ) protected lazy val booleanLiteral: Parser[Literal] = - ( TRUE ^^^ Literal(true, BooleanType) - | FALSE ^^^ Literal(false, BooleanType) + ( TRUE ^^^ Literal.create(true, BooleanType) + | FALSE ^^^ Literal.create(false, BooleanType) ) protected lazy val numericLiteral: Parser[Literal] = @@ -374,13 +381,13 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { | "(" ~> expression <~ ")" | function | dotExpressionHeader - | ident ^^ UnresolvedAttribute + | ident ^^ {case i => UnresolvedAttribute.quoted(i)} | signedPrimary | "~" ~> expression ^^ BitwiseNot ) protected lazy val dotExpressionHeader: Parser[Expression] = (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ { - case i1 ~ i2 ~ rest => UnresolvedAttribute((Seq(i1, i2) ++ rest).mkString(".")) + case i1 ~ i2 ~ rest => UnresolvedAttribute(Seq(i1, i2) ++ rest) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 44eceb0b372e6..cb49e5ad5586f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types._ /** * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing - * when all relations are already filled in and the analyser needs only to resolve attribute + * when all relations are already filled in and the analyzer needs only to resolve attribute * references. */ object SimpleAnalyzer extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, true) @@ -37,11 +37,12 @@ object SimpleAnalyzer extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, true * [[UnresolvedRelation]]s into fully typed objects using information in a schema [[Catalog]] and * a [[FunctionRegistry]]. 
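The new `cte` production in `SqlParser`, together with the analyzer change that follows, makes `WITH` clauses usable from any frontend. A hedged example through PySpark, assuming the shell's `sqlContext` and the Spark SQL dialect:

```python
df = sqlContext.createDataFrame([(i, i * 2) for i in range(10)], ["key", "value"])
df.registerTempTable("t")

# The CTE name acts like a table; per the resolution change below, a CTE
# definition shadows a catalog table with the same name.
rows = sqlContext.sql(
    "WITH small AS (SELECT key, value FROM t WHERE key < 3) "
    "SELECT key FROM small ORDER BY key").collect()
print(rows)   # [Row(key=0), Row(key=1), Row(key=2)]
```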
*/ -class Analyzer(catalog: Catalog, - registry: FunctionRegistry, - caseSensitive: Boolean, - maxIterations: Int = 100) - extends RuleExecutor[LogicalPlan] with HiveTypeCoercion { +class Analyzer( + catalog: Catalog, + registry: FunctionRegistry, + caseSensitive: Boolean, + maxIterations: Int = 100) + extends RuleExecutor[LogicalPlan] with HiveTypeCoercion with CheckAnalysis { val resolver = if (caseSensitive) caseSensitiveResolution else caseInsensitiveResolution @@ -139,10 +140,10 @@ class Analyzer(catalog: Catalog, case x: Expression if nonSelectedGroupExprSet.contains(x) => // if the input attribute in the Invalid Grouping Expression set of for this group // replace it with constant null - Literal(null, expr.dataType) + Literal.create(null, expr.dataType) case x if x == g.gid => // replace the groupingId with concrete value (the bit mask) - Literal(bitmask, IntegerType) + Literal.create(bitmask, IntegerType) }) result += GroupExpression(substitution) @@ -168,21 +169,36 @@ class Analyzer(catalog: Catalog, * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. */ object ResolveRelations extends Rule[LogicalPlan] { - def getTable(u: UnresolvedRelation): LogicalPlan = { + def getTable(u: UnresolvedRelation, cteRelations: Map[String, LogicalPlan]): LogicalPlan = { try { - catalog.lookupRelation(u.tableIdentifier, u.alias) + // In hive, if there is same table name in database and CTE definition, + // hive will use the table in database, not the CTE one. + // Taking into account the reasonableness and the implementation complexity, + // here use the CTE definition first, check table name only and ignore database name + cteRelations.get(u.tableIdentifier.last) + .map(relation => u.alias.map(Subquery(_, relation)).getOrElse(relation)) + .getOrElse(catalog.lookupRelation(u.tableIdentifier, u.alias)) } catch { case _: NoSuchTableException => u.failAnalysis(s"no such table ${u.tableName}") } } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _) => - i.copy( - table = EliminateSubQueries(getTable(u))) - case u: UnresolvedRelation => - getTable(u) + def apply(plan: LogicalPlan): LogicalPlan = { + val (realPlan, cteRelations) = plan match { + // TODO allow subquery to define CTE + // Add cte table to a temp relation map,drop `with` plan and keep its child + case With(child, relations) => (child, relations) + case other => (other, Map.empty[String, LogicalPlan]) + } + + realPlan transform { + case i@InsertIntoTable(u: UnresolvedRelation, _, _, _, _) => + i.copy( + table = EliminateSubQueries(getTable(u, cteRelations))) + case u: UnresolvedRelation => + getTable(u, cteRelations) + } } } @@ -212,6 +228,12 @@ class Analyzer(catalog: Catalog, case o => o :: Nil } Alias(c.copy(children = expandedArgs), name)() :: Nil + case Alias(c @ CreateStruct(args), name) if containsStar(args) => + val expandedArgs = args.flatMap { + case s: Star => s.expand(child.output, resolver) + case o => o :: Nil + } + Alias(c.copy(children = expandedArgs), name)() :: Nil case o => o :: Nil }, child) @@ -252,7 +274,15 @@ class Analyzer(catalog: Catalog, case oldVersion @ Aggregate(_, aggregateExpressions, _) if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) - }.head // Only handle first case found, others will be fixed on the next pass. 
+ }.headOption.getOrElse { // Only handle first case, others will be fixed on the next pass. + sys.error( + s""" + |Failure when resolving conflicting references in Join: + |$plan + | + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + """.stripMargin) + } val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) val newRight = right transformUp { @@ -267,18 +297,19 @@ class Analyzer(catalog: Catalog, case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressionsUp { - case u @ UnresolvedAttribute(name) if resolver(name, VirtualColumn.groupingIdName) && + case u @ UnresolvedAttribute(nameParts) if nameParts.length == 1 && + resolver(nameParts(0), VirtualColumn.groupingIdName) && q.isInstanceOf[GroupingAnalytics] => // Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics q.asInstanceOf[GroupingAnalytics].gid - case u @ UnresolvedAttribute(name) => + case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = - withPosition(u) { q.resolveChildren(name, resolver).getOrElse(u) } + withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } logDebug(s"Resolving $u to $result") result case UnresolvedGetField(child, fieldName) if child.resolved => - resolveGetField(child, fieldName) + GetField(child, fieldName, resolver) } } @@ -298,36 +329,6 @@ class Analyzer(catalog: Catalog, */ protected def containsStar(exprs: Seq[Expression]): Boolean = exprs.exists(_.collect { case _: Star => true }.nonEmpty) - - /** - * Returns the resolved `GetField`, and report error if no desired field or over one - * desired fields are found. - */ - protected def resolveGetField(expr: Expression, fieldName: String): Expression = { - def findField(fields: Array[StructField]): Int = { - val checkField = (f: StructField) => resolver(f.name, fieldName) - val ordinal = fields.indexWhere(checkField) - if (ordinal == -1) { - throw new AnalysisException( - s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}") - } else if (fields.indexWhere(checkField, ordinal + 1) != -1) { - throw new AnalysisException( - s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}") - } else { - ordinal - } - } - expr.dataType match { - case StructType(fields) => - val ordinal = findField(fields) - StructGetField(expr, fields(ordinal), ordinal) - case ArrayType(StructType(fields), containsNull) => - val ordinal = findField(fields) - ArrayGetField(expr, fields(ordinal), ordinal, containsNull) - case otherType => - throw new AnalysisException(s"GetField is not valid on fields of type $otherType") - } - } } /** @@ -340,19 +341,16 @@ class Analyzer(catalog: Catalog, def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case s @ Sort(ordering, global, p @ Project(projectList, child)) if !s.resolved && p.resolved => - val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) - val resolved = unresolved.flatMap(child.resolve(_, resolver)) - val requiredAttributes = - AttributeSet(resolved.flatMap(_.collect { case a: Attribute => a })) + val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, p, child) - val missingInProject = requiredAttributes -- p.output - if (missingInProject.nonEmpty) { + // If this rule was not a no-op, return the transformed plan, otherwise return the original. + if (missing.nonEmpty) { // Add missing attributes and then project them away after the sort. 
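The reworked `ResolveSortReferences` (now built on `resolveAndFindMissing`) is what lets an `ORDER BY` mention columns the projection dropped: the missing attributes are added beneath the sort and projected away again afterwards. From the Python side, again assuming the shell's `sqlContext`:

```python
df = sqlContext.createDataFrame([("alice", 5), ("bob", 3)], ["name", "age"])
df.registerTempTable("people")

# 'age' is not in the SELECT list, but the analyzer threads it through the sort.
print(sqlContext.sql("SELECT name FROM people ORDER BY age").collect())
# [Row(name='bob'), Row(name='alice')]  (sorted by the unselected age column)
```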
- Project(projectList.map(_.toAttribute), - Sort(ordering, global, - Project(projectList ++ missingInProject, child))) + Project(p.output, + Sort(resolvedOrdering, global, + Project(projectList ++ missing, child))) } else { - logDebug(s"Failed to find $missingInProject in ${p.output.mkString(", ")}") + logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}") s // Nothing we can do here. Return original plan. } case s @ Sort(ordering, global, a @ Aggregate(grouping, aggs, child)) @@ -364,18 +362,54 @@ class Analyzer(catalog: Catalog, grouping.collect { case ne: NamedExpression => ne.toAttribute } ) - logDebug(s"Grouping expressions: $groupingRelation") - val resolved = unresolved.flatMap(groupingRelation.resolve(_, resolver)) - val missingInAggs = resolved.filterNot(a.outputSet.contains) - logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs") - if (missingInAggs.nonEmpty) { + val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, a, groupingRelation) + + if (missing.nonEmpty) { // Add missing grouping exprs and then project them away after the sort. Project(a.output, - Sort(ordering, global, Aggregate(grouping, aggs ++ missingInAggs, child))) + Sort(resolvedOrdering, global, + Aggregate(grouping, aggs ++ missing, child))) } else { s // Nothing we can do here. Return original plan. } } + + /** + * Given a child and a grandchild that are present beneath a sort operator, returns + * a resolved sort ordering and a list of attributes that are missing from the child + * but are present in the grandchild. + */ + def resolveAndFindMissing( + ordering: Seq[SortOrder], + child: LogicalPlan, + grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { + // Find any attributes that remain unresolved in the sort. + val unresolved: Seq[Seq[String]] = + ordering.flatMap(_.collect { case UnresolvedAttribute(nameParts) => nameParts }) + + // Create a map from name, to resolved attributes, when the desired name can be found + // prior to the projection. + val resolved: Map[Seq[String], NamedExpression] = + unresolved.flatMap(u => grandchild.resolve(u, resolver).map(a => u -> a)).toMap + + // Construct a set that contains all of the attributes that we need to evaluate the + // ordering. + val requiredAttributes = AttributeSet(resolved.values) + + // Figure out which ones are missing from the projection, so that we can add them and + // remove them after the sort. + val missingInProject = requiredAttributes -- child.output + + // Now that we have all the attributes we need, reconstruct a resolved ordering. + // It is important to do it here, instead of waiting for the standard resolved as adding + // attributes to the project below can actually introduce ambiquity that was not present + // before. + val resolvedOrdering = ordering.map(_ transform { + case u @ UnresolvedAttribute(name) => resolved.getOrElse(name, u) + }).asInstanceOf[Seq[SortOrder]] + + (resolvedOrdering, missingInProject.toSeq) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 5eb7dff0cede8..b2f8157a1a61f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery} /** - * Thrown by a catalog when a table cannot be found. 
The analzyer will rethrow the exception + * Thrown by a catalog when a table cannot be found. The analyzer will rethrow the exception * as an AnalysisException with the correct position information. */ class NoSuchTableException extends Exception @@ -201,7 +201,7 @@ trait OverrideCatalog extends Catalog { /** * A trivial catalog that returns an error when a relation is requested. Used for testing when all - * relations are already filled in and the analyser needs only to resolve attribute references. + * relations are already filled in and the analyzer needs only to resolve attribute references. */ object EmptyCatalog extends Catalog { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 40472a1cbb3b4..1155dac28fc78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -25,7 +25,8 @@ import org.apache.spark.sql.types._ /** * Throws user facing errors when passed invalid queries that fail to analyze. */ -class CheckAnalysis { +trait CheckAnalysis { + self: Analyzer => /** * Override to provide additional checks for correct analysis. @@ -33,17 +34,26 @@ class CheckAnalysis { */ val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil - def failAnalysis(msg: String): Nothing = { + protected def failAnalysis(msg: String): Nothing = { throw new AnalysisException(msg) } - def apply(plan: LogicalPlan): Unit = { + def checkAnalysis(plan: LogicalPlan): Unit = { // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. plan.foreachUp { case operator: LogicalPlan => operator transformExpressionsUp { case a: Attribute if !a.resolved => + if (operator.childrenResolved) { + val nameParts = a match { + case UnresolvedAttribute(nameParts) => nameParts + case _ => Seq(a.name) + } + // Throw errors for specific problems with get field. + operator.resolveChildren(nameParts, resolver, throwErrors = true) + } + val from = operator.inputSet.map(_.name).mkString(", ") a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index c43ea55899695..16ca5bcd57a72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -57,8 +57,8 @@ class SimpleFunctionRegistry(val caseSensitive: Boolean) extends FunctionRegistr } /** - * A trivial catalog that returns an error when a function is requested. Used for testing when all - * functions are already filled in and the analyser needs only to resolve attribute references. + * A trivial catalog that returns an error when a function is requested. Used for testing when all + * functions are already filled in and the analyzer needs only to resolve attribute references. 
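With `CheckAnalysis` mixed into the `Analyzer`, unresolved references surface as user-facing errors at analysis time. A hedged sketch from PySpark; the exact exception type raised on the Python side varies by version, so a broad `except` is used:

```python
df = sqlContext.createDataFrame([(1, "a")], ["key", "value"])
try:
    df.select("no_such_column").collect()
except Exception as e:
    # The message comes from CheckAnalysis, e.g.:
    #   cannot resolve 'no_such_column' given input columns key, value
    print(str(e))
```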
*/ object EmptyFunctionRegistry extends FunctionRegistry { override def registerFunction(name: String, builder: FunctionBuilder): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 34ef7d28cc7f2..35c7f00d4e42a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -78,6 +78,7 @@ trait HiveTypeCoercion { FunctionArgumentConversion :: CaseWhenCoercion :: Division :: + PropagateTypes :: Nil /** @@ -114,7 +115,7 @@ trait HiveTypeCoercion { * the appropriate numeric equivalent. */ object ConvertNaNs extends Rule[LogicalPlan] { - val stringNaN = Literal("NaN", StringType) + val stringNaN = Literal("NaN") def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressions { @@ -284,6 +285,7 @@ trait HiveTypeCoercion { * Calculates and propagates precision for fixed-precision decimals. Hive has a number of * rules for this based on the SQL standard and MS SQL: * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf + * https://msdn.microsoft.com/en-us/library/ms190476.aspx * * In particular, if we have expressions e1 and e2 with precision/scale p1/s2 and p2/s2 * respectively, then the following operations have the following precision / scale: @@ -295,6 +297,7 @@ trait HiveTypeCoercion { * e1 * e2 p1 + p2 + 1 s1 + s2 * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1) * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2) + * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2) * sum(e1) p1 + 10 s1 * avg(e1) p1 + 4 s1 + 4 * @@ -310,7 +313,12 @@ trait HiveTypeCoercion { * - SHORT gets turned into DECIMAL(5, 0) * - INT gets turned into DECIMAL(10, 0) * - LONG gets turned into DECIMAL(20, 0) - * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive, + * - FLOAT and DOUBLE + * 1. Union operation: + * FLOAT gets turned into DECIMAL(7, 7), DOUBLE gets turned into DECIMAL(15, 15) (this is the + * same as Hive) + * 2. 
Other operation: + * FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive, * but note that unlimited decimals are considered bigger than doubles in WidenTypes) */ // scalastyle:on @@ -327,76 +335,127 @@ trait HiveTypeCoercion { def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - // Skip nodes whose children have not been resolved yet - case e if !e.childrenResolved => e + // Conversion rules for float and double into fixed-precision decimals + val floatTypeToFixed: Map[DataType, DecimalType] = Map( + FloatType -> DecimalType(7, 7), + DoubleType -> DecimalType(15, 15) + ) - case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - ) - - case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - ) - - case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(p1 + p2 + 1, s1 + s2) - ) - - case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1)) - ) - - case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) - ) - - case LessThan(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case GreaterThan(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - // Promote integers inside a binary expression with fixed-precision decimals to decimals, - // and fixed-precision decimals in an expression with floats / doubles to doubles - case b: BinaryExpression if b.left.dataType != b.right.dataType => - (b.left.dataType, b.right.dataType) match { - case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => - b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right)) - case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => - b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t)))) - case (t, DecimalType.Fixed(p, s)) if isFloat(t) => - b.makeCopy(Array(b.left, Cast(b.right, DoubleType))) - case (DecimalType.Fixed(p, s), t) if isFloat(t) => - b.makeCopy(Array(Cast(b.left, DoubleType), b.right)) - case _ => - b + 
def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // fix decimal precision for union + case u @ Union(left, right) if u.childrenResolved && !u.resolved => + val castedInput = left.output.zip(right.output).map { + case (l, r) if l.dataType != r.dataType => + (l.dataType, r.dataType) match { + case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) => + // Union decimals with precision/scale p1/s2 and p2/s2 will be promoted to + // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)) + val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2)) + (Alias(Cast(l, fixedType), l.name)(), Alias(Cast(r, fixedType), r.name)()) + case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => + (Alias(Cast(l, intTypeToFixed(t)), l.name)(), r) + case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => + (l, Alias(Cast(r, intTypeToFixed(t)), r.name)()) + case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) => + (Alias(Cast(l, floatTypeToFixed(t)), l.name)(), r) + case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) => + (l, Alias(Cast(r, floatTypeToFixed(t)), r.name)()) + case _ => (l, r) + } + case other => other } - // TODO: MaxOf, MinOf, etc might want other rules + val (castedLeft, castedRight) = castedInput.unzip - // SUM and AVERAGE are handled by the implementations of those expressions + val newLeft = + if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { + Project(castedLeft, left) + } else { + left + } + + val newRight = + if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { + Project(castedRight, right) + } else { + right + } + + Union(newLeft, newRight) + + // fix decimal precision for expressions + case q => q.transformExpressions { + // Skip nodes whose children have not been resolved yet + case e if !e.childrenResolved => e + + case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + ) + + case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + ) + + case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(p1 + p2 + 1, s1 + s2) + ) + + case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1)) + ) + + case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + ) + + case LessThan(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case GreaterThan(e1 @ DecimalType.Expression(p1, s1), + e2 
@ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + // Promote integers inside a binary expression with fixed-precision decimals to decimals, + // and fixed-precision decimals in an expression with floats / doubles to doubles + case b: BinaryExpression if b.left.dataType != b.right.dataType => + (b.left.dataType, b.right.dataType) match { + case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => + b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right)) + case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => + b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t)))) + case (t, DecimalType.Fixed(p, s)) if isFloat(t) => + b.makeCopy(Array(b.left, Cast(b.right, DoubleType))) + case (DecimalType.Fixed(p, s), t) if isFloat(t) => + b.makeCopy(Array(Cast(b.left, DoubleType), b.right)) + case _ => + b + } + + // TODO: MaxOf, MinOf, etc might want other rules + + // SUM and AVERAGE are handled by the implementations of those expressions + } } + } /** @@ -504,6 +563,10 @@ trait HiveTypeCoercion { case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType)) case Average(e @ TimestampType()) => Average(Cast(e, DoubleType)) + // Compatible with Hive + case Substring(e, start, len) if e.dataType != StringType => + Substring(Cast(e, StringType), start, len) + // Coalesce should return the first non-null value, which could be any column // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala index 894c3500cf533..35b74024a4cab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala @@ -30,5 +30,5 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan * of itself with globally unique expression ids. */ trait MultiInstanceRelation { - def newInstance(): this.type + def newInstance(): LogicalPlan } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala index c61c395cb4bb1..7731336d247db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -44,7 +44,7 @@ package object analysis { } /** Catches any AnalysisExceptions thrown by `f` and attaches `t`'s position if any. 
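As an aside on the DecimalPrecision union rule added earlier in this hunk: the promotion `DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2))` is easiest to see with concrete precisions. A worked example, with values chosen purely for illustration:

```
import scala.math.max
import org.apache.spark.sql.types.DecimalType

// e1: DECIMAL(10, 2) -> 8 integer digits, scale 2
// e2: DECIMAL(7, 4)  -> 3 integer digits, scale 4
// A union keeps the wider scale and enough integer digits, with no +1 carry digit:
val unionType = DecimalType(max(2, 4) + max(10 - 2, 7 - 4), max(2, 4)) // DecimalType(12, 4)
// The same operands under Add/Subtract would widen to DECIMAL(13, 4) because of the carry digit.
```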
*/ - def withPosition[A](t: TreeNode[_])(f: => A) = { + def withPosition[A](t: TreeNode[_])(f: => A): A = { try f catch { case a: AnalysisException => throw a.withPosition(t.origin.line, t.origin.startPosition) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 300e9ba187bc5..3f567e3e8b2a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -49,7 +49,12 @@ case class UnresolvedRelation( /** * Holds the name of an attribute that has yet to be resolved. */ -case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { +case class UnresolvedAttribute(nameParts: Seq[String]) + extends Attribute with trees.LeafNode[Expression] { + + def name: String = + nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") @@ -59,7 +64,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo override def newInstance(): UnresolvedAttribute = this override def withNullability(newNullability: Boolean): UnresolvedAttribute = this override def withQualifiers(newQualifiers: Seq[String]): UnresolvedAttribute = this - override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute(name) + override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) // Unresolved attributes are transient at compile time and don't get evaluated during execution. 
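To make the new `nameParts` representation concrete: `UnresolvedAttribute.apply` (in the companion object shown just below) splits a dotted name into parts, `quoted` keeps the whole string as a single part, and `name` re-quotes any part that itself contains a dot. For example:

```
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute

UnresolvedAttribute("a.b.c").nameParts         // Seq("a", "b", "c")  -- split on "."
UnresolvedAttribute.quoted("a.b.c").nameParts  // Seq("a.b.c")        -- one column literally named "a.b.c"
UnresolvedAttribute.quoted("a.b.c").name       // "`a.b.c`"           -- backticks re-added by name
```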
override def eval(input: Row = null): EvaluatedType = @@ -68,6 +73,11 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo override def toString: String = s"'$name" } +object UnresolvedAttribute { + def apply(name: String): UnresolvedAttribute = new UnresolvedAttribute(name.split("\\.")) + def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name)) +} + case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression { override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 145f062dd6817..21c15ad14fd19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -293,7 +293,7 @@ package object dsl { def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( - analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite) + analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite, false) def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer(logicalPlan)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index 11b4eb5c888be..5345696570b41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -34,7 +34,7 @@ object AttributeSet { def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a))) /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */ - def apply(baseSet: Seq[Expression]): AttributeSet = { + def apply(baseSet: Iterable[Expression]): AttributeSet = { new AttributeSet( baseSet .flatMap(_.references) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 31f1a5fdc7e53..adf941ab2a45f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,7 +21,6 @@ import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.types._ /** Cast the child expression to the target data type. 
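The Cast hunk that follows swaps `java.lang.String` for the internal `UTF8String` representation: values are wrapped with `UTF8String(...)` when casting to string, and unwrapped with `.toString` before parsing numbers, dates and timestamps. A minimal sketch of the round trip; `UTF8String` is introduced elsewhere in this patch and is assumed here to be reachable via `org.apache.spark.sql.types`, matching Cast's wildcard import:

```
import org.apache.spark.sql.types.UTF8String

val s = UTF8String("123")           // strings now flow through expressions as UTF8String
val asLong   = s.toString.toLong    // castToLong: unwrap, then parse -> 123L
val asBytes  = s.getBytes           // castToBinary: raw UTF-8 bytes
val nonEmpty = s.length() != 0      // castToBoolean: non-empty -> true
```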
*/ @@ -112,21 +111,21 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // UDFToString private[this] def castToString(from: DataType): Any => Any = from match { - case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8")) - case DateType => buildCast[Int](_, d => DateUtils.toString(d)) - case TimestampType => buildCast[Timestamp](_, timestampToString) - case _ => buildCast[Any](_, _.toString) + case BinaryType => buildCast[Array[Byte]](_, UTF8String(_)) + case DateType => buildCast[Int](_, d => UTF8String(DateUtils.toString(d))) + case TimestampType => buildCast[Timestamp](_, t => UTF8String(timestampToString(t))) + case _ => buildCast[Any](_, o => UTF8String(o.toString)) } // BinaryConverter private[this] def castToBinary(from: DataType): Any => Any = from match { - case StringType => buildCast[String](_, _.getBytes("UTF-8")) + case StringType => buildCast[UTF8String](_, _.getBytes) } // UDFToBoolean private[this] def castToBoolean(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, _.length() != 0) + buildCast[UTF8String](_, _.length() != 0) case TimestampType => buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0) case DateType => @@ -151,8 +150,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // TimestampConverter private[this] def castToTimestamp(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => { + buildCast[UTF8String](_, utfs => { // Throw away extra if more than 9 decimal places + val s = utfs.toString val periodIdx = s.indexOf(".") var n = s if (periodIdx != -1 && n.length() - periodIdx > 9) { @@ -227,8 +227,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // DateConverter private[this] def castToDate(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => - try DateUtils.fromJavaDate(Date.valueOf(s)) + buildCast[UTF8String](_, s => + try DateUtils.fromJavaDate(Date.valueOf(s.toString)) catch { case _: java.lang.IllegalArgumentException => null } ) case TimestampType => @@ -245,7 +245,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // LongConverter private[this] def castToLong(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toLong catch { + buildCast[UTF8String](_, s => try s.toString.toLong catch { case _: NumberFormatException => null }) case BooleanType => @@ -261,7 +261,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // IntConverter private[this] def castToInt(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toInt catch { + buildCast[UTF8String](_, s => try s.toString.toInt catch { case _: NumberFormatException => null }) case BooleanType => @@ -277,7 +277,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // ShortConverter private[this] def castToShort(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toShort catch { + buildCast[UTF8String](_, s => try s.toString.toShort catch { case _: NumberFormatException => null }) case BooleanType => @@ -293,7 +293,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // ByteConverter private[this] def castToByte(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toByte catch { 
+ buildCast[UTF8String](_, s => try s.toString.toByte catch { case _: NumberFormatException => null }) case BooleanType => @@ -323,7 +323,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { case StringType => - buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch { + buildCast[UTF8String](_, s => try { + changePrecision(Decimal(s.toString.toDouble), target) + } catch { case _: NumberFormatException => null }) case BooleanType => @@ -348,7 +350,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // DoubleConverter private[this] def castToDouble(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toDouble catch { + buildCast[UTF8String](_, s => try s.toString.toDouble catch { case _: NumberFormatException => null }) case BooleanType => @@ -364,7 +366,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // FloatConverter private[this] def castToFloat(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toFloat catch { + buildCast[UTF8String](_, s => try s.toString.toFloat catch { case _: NumberFormatException => null }) case BooleanType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 389dc4f745723..9a77ca624ebe2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.types.DataType /** @@ -39,12 +39,14 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi (1 to 22).map { x => val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _) - val childs = (0 to x - 1).map(x => s"val child$x = children($x)").reduce(_ + "\n " + _) - val evals = (0 to x - 1).map(x => s"ScalaReflection.convertToScala(child$x.eval(input), child$x.dataType)").reduce(_ + ",\n " + _) + val childs = (0 to x - 1).map(x => s"val child$x = children($x)").reduce(_ + "\n " + _) + lazy val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n " + _) + val evals = (0 to x - 1).map(x => s"converter$x(child$x.eval(input))").reduce(_ + ",\n " + _) - s""" case $x => + s"""case $x => val func = function.asInstanceOf[($anys) => Any] $childs + $converters (input: Row) => { func( $evals) @@ -60,51 +62,61 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi (input: Row) => { func() } - + case 1 => val func = function.asInstanceOf[(Any) => Any] val child0 = children(0) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType)) + converter0(child0.eval(input))) } - + case 2 => val func = function.asInstanceOf[(Any, Any) => Any] val child0 = children(0) val child1 = children(1) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = 
CatalystTypeConverters.createToScalaConverter(child1.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input))) } - + case 3 => val func = function.asInstanceOf[(Any, Any, Any) => Any] val child0 = children(0) val child1 = children(1) val child2 = children(2) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input))) } - + case 4 => val func = function.asInstanceOf[(Any, Any, Any, Any) => Any] val child0 = children(0) val child1 = children(1) val child2 = children(2) val child3 = children(3) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input))) } - + case 5 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -112,15 +124,20 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child2 = children(2) val child3 = children(3) val child4 = children(4) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input))) } - + case 6 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -129,16 +146,22 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child3 = children(3) val child4 = children(4) val 
child5 = children(5) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input))) } - + case 7 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -148,17 +171,24 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child4 = children(4) val child5 = children(5) val child6 = children(6) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input))) } - + case 8 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -169,18 +199,26 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child5 = children(5) val child6 = children(6) val child7 = children(7) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = 
CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input))) } - + case 9 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -192,19 +230,28 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child6 = children(6) val child7 = children(7) val child8 = children(8) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input))) } - + case 10 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -217,20 +264,30 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child7 = children(7) val child8 = children(8) val child9 = 
children(9) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input))) } - + case 11 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -244,21 +301,32 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child8 = children(8) val child9 = children(9) val child10 = children(10) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - 
ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input))) } - + case 12 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -273,22 +341,34 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child9 = children(9) val child10 = children(10) val child11 = children(11) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input))) } - + case 13 
=> val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -304,23 +384,36 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child10 = children(10) val child11 = children(11) val child12 = children(12) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input))) } - + case 14 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -337,24 +430,38 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child11 = children(11) val child12 = children(12) val child13 = children(13) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val 
converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input))) } - + case 15 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -372,25 +479,40 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child12 = children(12) val child13 = children(13) val child14 = children(14) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = 
CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - ScalaReflection.convertToScala(child14.eval(input), child14.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input))) } - + case 16 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -409,26 +531,42 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child13 = children(13) val child14 = children(14) val child15 = children(15) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = 
CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - ScalaReflection.convertToScala(child14.eval(input), child14.dataType), - ScalaReflection.convertToScala(child15.eval(input), child15.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input))) } - + case 17 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -448,27 +586,44 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child14 = children(14) val child15 = children(15) val child16 = children(16) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + 
lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - ScalaReflection.convertToScala(child14.eval(input), child14.dataType), - ScalaReflection.convertToScala(child15.eval(input), child15.dataType), - ScalaReflection.convertToScala(child16.eval(input), child16.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input))) } - + case 18 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -489,28 +644,46 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child15 = children(15) val child16 = children(16) val child17 = children(17) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = 
CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) + lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - ScalaReflection.convertToScala(child14.eval(input), child14.dataType), - ScalaReflection.convertToScala(child15.eval(input), child15.dataType), - ScalaReflection.convertToScala(child16.eval(input), child16.dataType), - ScalaReflection.convertToScala(child17.eval(input), child17.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input)), + converter17(child17.eval(input))) } - + case 19 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -532,29 +705,48 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child16 = children(16) val child17 = children(17) val child18 = children(18) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = 
CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) + lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) + lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - ScalaReflection.convertToScala(child14.eval(input), child14.dataType), - ScalaReflection.convertToScala(child15.eval(input), child15.dataType), - ScalaReflection.convertToScala(child16.eval(input), child16.dataType), - ScalaReflection.convertToScala(child17.eval(input), child17.dataType), - ScalaReflection.convertToScala(child18.eval(input), child18.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input)), + converter17(child17.eval(input)), + converter18(child18.eval(input))) } - + case 20 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -577,30 +769,50 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child17 = children(17) val child18 = children(18) val child19 = children(19) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy 
val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) + lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) + lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) + lazy val converter19 = CatalystTypeConverters.createToScalaConverter(child19.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - ScalaReflection.convertToScala(child14.eval(input), child14.dataType), - ScalaReflection.convertToScala(child15.eval(input), child15.dataType), - ScalaReflection.convertToScala(child16.eval(input), child16.dataType), - ScalaReflection.convertToScala(child17.eval(input), child17.dataType), - ScalaReflection.convertToScala(child18.eval(input), child18.dataType), - ScalaReflection.convertToScala(child19.eval(input), child19.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + 
converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input)), + converter17(child17.eval(input)), + converter18(child18.eval(input)), + converter19(child19.eval(input))) } - + case 21 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -624,31 +836,52 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child18 = children(18) val child19 = children(19) val child20 = children(20) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) + lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) + lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) + lazy val converter19 = CatalystTypeConverters.createToScalaConverter(child19.dataType) + lazy val converter20 = CatalystTypeConverters.createToScalaConverter(child20.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - 
ScalaReflection.convertToScala(child14.eval(input), child14.dataType), - ScalaReflection.convertToScala(child15.eval(input), child15.dataType), - ScalaReflection.convertToScala(child16.eval(input), child16.dataType), - ScalaReflection.convertToScala(child17.eval(input), child17.dataType), - ScalaReflection.convertToScala(child18.eval(input), child18.dataType), - ScalaReflection.convertToScala(child19.eval(input), child19.dataType), - ScalaReflection.convertToScala(child20.eval(input), child20.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input)), + converter17(child17.eval(input)), + converter18(child18.eval(input)), + converter19(child19.eval(input)), + converter20(child20.eval(input))) } - + case 22 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -673,35 +906,57 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child19 = children(19) val child20 = children(20) val child21 = children(21) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) + lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) + lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) + lazy val converter19 = CatalystTypeConverters.createToScalaConverter(child19.dataType) + lazy val converter20 = CatalystTypeConverters.createToScalaConverter(child20.dataType) + lazy val converter21 = 
CatalystTypeConverters.createToScalaConverter(child21.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - ScalaReflection.convertToScala(child14.eval(input), child14.dataType), - ScalaReflection.convertToScala(child15.eval(input), child15.dataType), - ScalaReflection.convertToScala(child16.eval(input), child16.dataType), - ScalaReflection.convertToScala(child17.eval(input), child17.dataType), - ScalaReflection.convertToScala(child18.eval(input), child18.dataType), - ScalaReflection.convertToScala(child19.eval(input), child19.dataType), - ScalaReflection.convertToScala(child20.eval(input), child20.dataType), - ScalaReflection.convertToScala(child21.eval(input), child21.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input)), + converter17(child17.eval(input)), + converter18(child18.eval(input)), + converter19(child19.eval(input)), + converter20(child20.eval(input)), + converter21(child21.eval(input))) } } - + // scalastyle:on - - override def eval(input: Row): Any = ScalaReflection.convertToCatalyst(f(input), dataType) + + override def eval(input: Row): Any = CatalystTypeConverters.convertToCatalyst(f(input), dataType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 47b6f358ed1b1..3475ed05f4454 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -230,13 +230,17 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR new GenericRow(newValues) } - override def update(ordinal: Int, value: Any): Unit = { - if (value == null) setNullAt(ordinal) else values(ordinal).update(value) + override def update(ordinal: Int, value: Any) { + if (value == null) { + setNullAt(ordinal) + } else { + values(ordinal).update(value) + } } - override def 
setString(ordinal: Int, value: String): Unit = update(ordinal, value) + override def setString(ordinal: Int, value: String): Unit = update(ordinal, UTF8String(value)) - override def getString(ordinal: Int): String = apply(ordinal).asInstanceOf[String] + override def getString(ordinal: Int): String = apply(ordinal).toString override def setInt(ordinal: Int, value: Int): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableInt] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 30da4faa3f1c6..14a855054b94d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -189,9 +189,10 @@ case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpress override def children: Seq[Expression] = expressions override def nullable: Boolean = false - override def dataType: ArrayType = ArrayType(expressions.head.dataType) + override def dataType: OpenHashSetUDT = new OpenHashSetUDT(expressions.head.dataType) override def toString: String = s"AddToHashSet(${expressions.mkString(",")})" - override def newInstance(): CollectHashSetFunction = new CollectHashSetFunction(expressions, this) + override def newInstance(): CollectHashSetFunction = + new CollectHashSetFunction(expressions, this) } case class CollectHashSetFunction( @@ -250,11 +251,28 @@ case class CombineSetsAndCountFunction( override def eval(input: Row): Any = seen.size.toLong } +/** The data type of ApproxCountDistinctPartition since its output is a HyperLogLog object. */ +private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] { + + override def sqlType: DataType = BinaryType + + /** Since we are using HyperLogLog internally, usually it will not be called. */ + override def serialize(obj: Any): Array[Byte] = + obj.asInstanceOf[HyperLogLog].getBytes + + + /** Since we are using HyperLogLog internally, usually it will not be called. 
*/ + override def deserialize(datum: Any): HyperLogLog = + HyperLogLog.Builder.build(datum.asInstanceOf[Array[Byte]]) + + override def userClass: Class[HyperLogLog] = classOf[HyperLogLog] +} + case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { override def nullable: Boolean = false - override def dataType: DataType = child.dataType + override def dataType: DataType = HyperLogLogUDT override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" override def newInstance(): ApproxCountDistinctPartitionFunction = { new ApproxCountDistinctPartitionFunction(child, this, relativeSD) @@ -505,7 +523,8 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) private var count: Long = _ private val sum = MutableLiteral(zero.eval(null), calcType) - private def addFunction(value: Any) = Add(sum, Cast(Literal(value, expr.dataType), calcType)) + private def addFunction(value: Any) = Add(sum, + Cast(Literal.create(value, expr.dataType), calcType)) override def eval(input: Row): Any = { if (count == 0L) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 1f6526ef66c56..566b34f7c3a6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -369,6 +369,51 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { override def toString: String = s"MaxOf($left, $right)" } +case class MinOf(left: Expression, right: Expression) extends Expression { + type EvaluatedType = Any + + override def foldable: Boolean = left.foldable && right.foldable + + override def nullable: Boolean = left.nullable && right.nullable + + override def children: Seq[Expression] = left :: right :: Nil + + override lazy val resolved = + left.resolved && right.resolved && + left.dataType == right.dataType + + override def dataType: DataType = { + if (!resolved) { + throw new UnresolvedException(this, + s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + } + left.dataType + } + + lazy val ordering = left.dataType match { + case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]] + case other => sys.error(s"Type $other does not support ordered operations") + } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + val evalE2 = right.eval(input) + if (evalE1 == null) { + evalE2 + } else if (evalE2 == null) { + evalE1 + } else { + if (ordering.compare(evalE1, evalE2) < 0) { + evalE1 + } else { + evalE2 + } + } + } + + override def toString: String = s"MinOf($left, $right)" +} + /** * A function that get the absolute value of the numeric value. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d1abf3c0b64a5..be2c101d63a63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -216,10 +216,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val $primitiveTerm: ${termForType(dataType)} = $value """.children - case expressions.Literal(value: String, dataType) => + case expressions.Literal(value: UTF8String, dataType) => q""" val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value + val $primitiveTerm: ${termForType(dataType)} = + org.apache.spark.sql.types.UTF8String(${value.getBytes}) """.children case expressions.Literal(value: Int, dataType) => @@ -243,11 +244,14 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin if($nullTerm) ${defaultPrimitive(StringType)} else - new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) + org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) """.children case Cast(child @ DateType(), StringType) => - child.castOrNull(c => q"org.apache.spark.sql.types.DateUtils.toString($c)", StringType) + child.castOrNull(c => + q"""org.apache.spark.sql.types.UTF8String( + org.apache.spark.sql.types.DateUtils.toString($c))""", + StringType) case Cast(child @ NumericType(), IntegerType) => child.castOrNull(c => q"$c.toInt", IntegerType) @@ -272,9 +276,18 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin if($nullTerm) ${defaultPrimitive(StringType)} else - ${eval.primitiveTerm}.toString + org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString) """.children + case EqualTo(e1: BinaryType, e2: BinaryType) => + (e1, e2).evaluateAs (BooleanType) { + case (eval1, eval2) => + q""" + java.util.Arrays.equals($eval1.asInstanceOf[Array[Byte]], + $eval2.asInstanceOf[Array[Byte]]) + """ + } + case EqualTo(e1, e2) => (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" } @@ -464,7 +477,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val itemEval = expressionEvaluator(item) val setEval = expressionEvaluator(set) - val ArrayType(elementType, _) = set.dataType + val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType itemEval.code ++ setEval.code ++ q""" @@ -482,7 +495,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val leftEval = expressionEvaluator(left) val rightEval = expressionEvaluator(right) - val ArrayType(elementType, _) = left.dataType + val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType leftEval.code ++ rightEval.code ++ q""" @@ -524,6 +537,30 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } """.children + case MinOf(e1, e2) => + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) + + eval1.code ++ eval2.code ++ + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)} + + if (${eval1.nullTerm}) { + $nullTerm = ${eval2.nullTerm} + $primitiveTerm = ${eval2.primitiveTerm} + } else if (${eval2.nullTerm}) { + $nullTerm = ${eval1.nullTerm} + $primitiveTerm = 
${eval1.primitiveTerm} + } else { + if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) { + $primitiveTerm = ${eval1.primitiveTerm} + } else { + $primitiveTerm = ${eval2.primitiveTerm} + } + } + """.children + case UnscaledValue(child) => val childEval = expressionEvaluator(child) @@ -573,7 +610,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val localLogger = log val localLoggerTree = reify { localLogger } q""" - $localLoggerTree.debug(${e.toString} + ": " + (if($nullTerm) "null" else $primitiveTerm)) + $localLoggerTree.debug( + ${e.toString} + ": " + (if ($nullTerm) "null" else $primitiveTerm.toString)) """ :: Nil } else { Nil @@ -584,6 +622,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { dataType match { + case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]" case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)" case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]" } @@ -595,6 +634,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin ordinal: Int, value: TermName) = { dataType match { + case StringType => q"$destinationRow.update($ordinal, $value)" case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" case _ => q"$destinationRow.update($ordinal, $value)" } @@ -618,13 +658,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin case DoubleType => "Double" case FloatType => "Float" case BooleanType => "Boolean" - case StringType => "String" + case StringType => "org.apache.spark.sql.types.UTF8String" } protected def defaultPrimitive(dt: DataType) = dt match { case BooleanType => ru.Literal(Constant(false)) case FloatType => ru.Literal(Constant(-1.0.toFloat)) - case StringType => ru.Literal(Constant("")) + case StringType => q"""org.apache.spark.sql.types.UTF8String("")""" case ShortType => ru.Literal(Constant(-1.toShort)) case LongType => ru.Literal(Constant(-1L)) case ByteType => ru.Literal(Constant(-1.toByte)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 69397a73a8880..6f572ff959fb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -111,36 +111,54 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val specificAccessorFunctions = NativeType.all.map { dataType => val ifStatements = expressions.zipWithIndex.flatMap { - case (e, i) if e.dataType == dataType => + // getString() is not used by expressions + case (e, i) if e.dataType == dataType && dataType != StringType => val elementName = newTermName(s"c$i") // TODO: The string of ifs gets pretty inefficient as the row grows in size. // TODO: Optional null checks? 
q"if(i == $i) return $elementName" :: Nil case _ => Nil } - - q""" - override def ${accessorForType(dataType)}(i: Int):${termForType(dataType)} = { - ..$ifStatements; - $accessorFailure - }""" + dataType match { + // Row() need this interface to compile + case StringType => + q""" + override def getString(i: Int): String = { + $accessorFailure + }""" + case other => + q""" + override def ${accessorForType(dataType)}(i: Int): ${termForType(dataType)} = { + ..$ifStatements; + $accessorFailure + }""" + } } val specificMutatorFunctions = NativeType.all.map { dataType => val ifStatements = expressions.zipWithIndex.flatMap { - case (e, i) if e.dataType == dataType => + // setString() is not used by expressions + case (e, i) if e.dataType == dataType && dataType != StringType => val elementName = newTermName(s"c$i") // TODO: The string of ifs gets pretty inefficient as the row grows in size. // TODO: Optional null checks? q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil case _ => Nil } - - q""" - override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}): Unit = { - ..$ifStatements; - $accessorFailure - }""" + dataType match { + case StringType => + // MutableRow() need this interface to compile + q""" + override def setString(i: Int, value: String) { + $accessorFailure + }""" + case other => + q""" + override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}) { + ..$ifStatements; + $accessorFailure + }""" + } } val hashValues = expressions.zipWithIndex.map { case (e,i) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 3fd78db297462..fc1f69655963d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.types._ /** @@ -81,6 +83,41 @@ trait GetField extends UnaryExpression { def field: StructField } +object GetField { + /** + * Returns the resolved `GetField`, and report error if no desired field or over one + * desired fields are found. + */ + def apply( + expr: Expression, + fieldName: String, + resolver: Resolver): GetField = { + def findField(fields: Array[StructField]): Int = { + val checkField = (f: StructField) => resolver(f.name, fieldName) + val ordinal = fields.indexWhere(checkField) + if (ordinal == -1) { + throw new AnalysisException( + s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}") + } else if (fields.indexWhere(checkField, ordinal + 1) != -1) { + throw new AnalysisException( + s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}") + } else { + ordinal + } + } + expr.dataType match { + case StructType(fields) => + val ordinal = findField(fields) + StructGetField(expr, fields(ordinal), ordinal) + case ArrayType(StructType(fields), containsNull) => + val ordinal = findField(fields) + ArrayGetField(expr, fields(ordinal), ordinal, containsNull) + case otherType => + throw new AnalysisException(s"GetField is not valid on fields of type $otherType") + } + } +} + /** * Returns the value of fields in the Struct `child`. 
*/ @@ -120,7 +157,7 @@ case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, co case class CreateArray(children: Seq[Expression]) extends Expression { override type EvaluatedType = Any - override def foldable: Boolean = !children.exists(!_.foldable) + override def foldable: Boolean = children.forall(_.foldable) lazy val childTypes = children.map(_.dataType).distinct @@ -142,3 +179,30 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def toString: String = s"Array(${children.mkString(",")})" } + +/** + * Returns a Row containing the evaluation of all children expressions. + * TODO: [[CreateStruct]] does not support codegen. + */ +case class CreateStruct(children: Seq[NamedExpression]) extends Expression { + override type EvaluatedType = Row + + override def foldable: Boolean = children.forall(_.foldable) + + override lazy val resolved: Boolean = childrenResolved + + override lazy val dataType: StructType = { + assert(resolved, + s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.") + val fields = children.map { child => + StructField(child.name, child.dataType, child.nullable, child.metadata) + } + StructType(fields) + } + + override def nullable: Boolean = false + + override def eval(input: Row): EvaluatedType = { + Row(children.map(_.eval(input)): _*) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 860b72fad38b3..67caadb839ff9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map -import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} import org.apache.spark.sql.types._ /** @@ -85,8 +85,11 @@ case class UserDefinedGenerator( override protected def makeOutput(): Seq[Attribute] = schema override def eval(input: Row): TraversableOnce[Row] = { + // TODO(davies): improve this + // Convert the objects into Scala Type before calling function, we need schema to support UDT + val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) val inputRow = new InterpretedProjection(children) - function(inputRow(input)) + function(CatalystTypeConverters.convertToScala(inputRow(input), inputSchema).asInstanceOf[Row]) } override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 19f3fc9c2291a..18cba4cc46707 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.types._ object Literal { @@ -29,7 +30,7 @@ object Literal { case f: Float => Literal(f, FloatType) case b: Byte => Literal(b, ByteType) case s: Short => Literal(s, ShortType) - case s: String => Literal(s, StringType) + case s: String => Literal(UTF8String(s), 
StringType) case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) @@ -41,6 +42,10 @@ object Literal { case _ => throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } + + def create(v: Any, dataType: DataType): Literal = { + Literal(CatalystTypeConverters.convertToCatalyst(v), dataType) + } } /** @@ -62,7 +67,10 @@ object IntegerLiteral { } } -case class Literal(value: Any, dataType: DataType) extends LeafExpression { +/** + * In order to do type checking, use Literal.create() instead of constructor + */ +case class Literal protected (value: Any, dataType: DataType) extends LeafExpression { override def foldable: Boolean = true override def nullable: Boolean = value == null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index d1f3d4f4ee9ee..f9161cf34f0c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -35,7 +35,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { override def toString: String = s"Coalesce(${children.mkString(",")})" - def dataType: DataType = if (resolved) { + override def dataType: DataType = if (resolved) { children.head.dataType } else { val childTypes = children.map(c => s"$c: ${c.dataType}").mkString(", ") @@ -74,3 +74,26 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E child.eval(input) != null } } + +/** + * A predicate that is evaluated to be true if there are at least `n` non-null values. 
+ */ +case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate { + override def nullable: Boolean = false + override def foldable: Boolean = false + override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})" + + private[this] val childrenArray = children.toArray + + override def eval(input: Row): Boolean = { + var numNonNulls = 0 + var i = 0 + while (i < childrenArray.length && numNonNulls < n) { + if (childrenArray(i).eval(input) != null) { + numNonNulls += 1 + } + i += 1 + } + numNonNulls >= n + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 7e47cb3fffe12..fcd6352079b4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -179,8 +179,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison val r = right.eval(input) if (r == null) null else if (left.dataType != BinaryType) l == r - else BinaryType.ordering.compare( - l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) == 0 + else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index a8983df208318..b6ec7d3417ef8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.types.{StructType, NativeType} - +import org.apache.spark.sql.types.{UTF8String, DataType, StructType, NativeType} /** * An extended interface to [[Row]] that allows the values for each column to be updated. Setting @@ -37,6 +36,7 @@ trait MutableRow extends Row { def setByte(ordinal: Int, value: Byte) def setFloat(ordinal: Int, value: Float) def setString(ordinal: Int, value: String) + // TODO(davies): add setDate() and setDecimal() } /** @@ -114,9 +114,15 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { } override def getString(i: Int): String = { - values(i).asInstanceOf[String] + values(i) match { + case null => null + case s: String => s + case utf8: UTF8String => utf8.toString + } } + // TODO(davies): add getDate and getDecimal + // Custom hashCode function that matches the efficient code generated version. 
override def hashCode: Int = { var result: Int = 37 @@ -189,8 +195,7 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value } override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value } override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value } - override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = value } - + override def setString(ordinal: Int, value: String) { values(ordinal) = UTF8String(value)} override def setNullAt(i: Int): Unit = { values(i) = null } override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value } @@ -224,6 +229,7 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { n.ordering.asInstanceOf[Ordering[Any]].compare(left, right) case n: NativeType if order.direction == Descending => n.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) + case other => sys.error(s"Type $other does not support ordered operations") } if (comparison != 0) return comparison } @@ -232,3 +238,10 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { return 0 } } + +object RowOrdering { + def forSchema(dataTypes: Seq[DataType]): RowOrdering = + new RowOrdering(dataTypes.zipWithIndex.map { + case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) + }) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 35faa00782e80..4c44182278207 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -20,6 +20,33 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet +/** The data type for expressions returning an OpenHashSet as the result. */ +private[sql] class OpenHashSetUDT( + val elementType: DataType) extends UserDefinedType[OpenHashSet[Any]] { + + override def sqlType: DataType = ArrayType(elementType) + + /** Since we are using OpenHashSet internally, usually it will not be called. */ + override def serialize(obj: Any): Seq[Any] = { + obj.asInstanceOf[OpenHashSet[Any]].iterator.toSeq + } + + /** Since we are using OpenHashSet internally, usually it will not be called. */ + override def deserialize(datum: Any): OpenHashSet[Any] = { + val iterator = datum.asInstanceOf[Seq[Any]].iterator + val set = new OpenHashSet[Any] + while(iterator.hasNext) { + set.add(iterator.next()) + } + + set + } + + override def userClass: Class[OpenHashSet[Any]] = classOf[OpenHashSet[Any]] + + private[spark] override def asNullable: OpenHashSetUDT = this +} + /** * Creates a new set of the specified type */ @@ -28,9 +55,7 @@ case class NewSet(elementType: DataType) extends LeafExpression { override def nullable: Boolean = false - // We are currently only using these Expressions internally for aggregation. However, if we ever - // expose these to users we'll want to create a proper type instead of hijacking ArrayType. 
- override def dataType: DataType = ArrayType(elementType) + override def dataType: OpenHashSetUDT = new OpenHashSetUDT(elementType) override def eval(input: Row): Any = { new OpenHashSet[Any]() @@ -50,7 +75,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { override def nullable: Boolean = set.nullable - override def dataType: DataType = set.dataType + override def dataType: OpenHashSetUDT = set.dataType.asInstanceOf[OpenHashSetUDT] override def eval(input: Row): Any = { val itemEval = item.eval(input) @@ -80,7 +105,7 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres override def nullable: Boolean = left.nullable || right.nullable - override def dataType: DataType = left.dataType + override def dataType: OpenHashSetUDT = left.dataType.asInstanceOf[OpenHashSetUDT] override def symbol: String = "++=" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 3cdca4e9dd2d1..d597bf7ce756a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -19,11 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import java.util.regex.Pattern -import scala.collection.IndexedSeqOptimized - - import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType, StringType} +import org.apache.spark.sql.types._ trait StringRegexExpression { self: BinaryExpression => @@ -60,38 +57,17 @@ trait StringRegexExpression { if(r == null) { null } else { - val regex = pattern(r.asInstanceOf[String]) + val regex = pattern(r.asInstanceOf[UTF8String].toString) if(regex == null) { null } else { - matches(regex, l.asInstanceOf[String]) + matches(regex, l.asInstanceOf[UTF8String].toString) } } } } } -trait CaseConversionExpression { - self: UnaryExpression => - - type EvaluatedType = Any - - def convert(v: String): String - - override def foldable: Boolean = child.foldable - def nullable: Boolean = child.nullable - def dataType: DataType = StringType - - override def eval(input: Row): Any = { - val evaluated = child.eval(input) - if (evaluated == null) { - null - } else { - convert(evaluated.toString) - } - } -} - /** * Simple RegEx pattern matching function */ @@ -134,12 +110,33 @@ case class RLike(left: Expression, right: Expression) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) } +trait CaseConversionExpression { + self: UnaryExpression => + + type EvaluatedType = Any + + def convert(v: UTF8String): UTF8String + + override def foldable: Boolean = child.foldable + def nullable: Boolean = child.nullable + def dataType: DataType = StringType + + override def eval(input: Row): Any = { + val evaluated = child.eval(input) + if (evaluated == null) { + null + } else { + convert(evaluated.asInstanceOf[UTF8String]) + } + } +} + /** * A function that converts the characters of a string to uppercase. 
*/ case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: String): String = v.toUpperCase() + override def convert(v: UTF8String): UTF8String = v.toUpperCase override def toString: String = s"Upper($child)" } @@ -149,29 +146,29 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE */ case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: String): String = v.toLowerCase() + override def convert(v: UTF8String): UTF8String = v.toLowerCase override def toString: String = s"Lower($child)" } /** A base trait for functions that compare two strings, returning a boolean. */ trait StringComparison { - self: BinaryExpression => + self: BinaryPredicate => - type EvaluatedType = Any + override type EvaluatedType = Any override def nullable: Boolean = left.nullable || right.nullable - override def dataType: DataType = BooleanType - def compare(l: String, r: String): Boolean + def compare(l: UTF8String, r: UTF8String): Boolean override def eval(input: Row): Any = { - val leftEval = left.eval(input).asInstanceOf[String] + val leftEval = left.eval(input) if(leftEval == null) { null } else { - val rightEval = right.eval(input).asInstanceOf[String] - if (rightEval == null) null else compare(leftEval, rightEval) + val rightEval = right.eval(input) + if (rightEval == null) null + else compare(leftEval.asInstanceOf[UTF8String], rightEval.asInstanceOf[UTF8String]) } } @@ -184,24 +181,24 @@ trait StringComparison { * A function that returns true if the string `left` contains the string `right`. */ case class Contains(left: Expression, right: Expression) - extends BinaryExpression with StringComparison { - override def compare(l: String, r: String): Boolean = l.contains(r) + extends BinaryPredicate with StringComparison { + override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) } /** * A function that returns true if the string `left` starts with the string `right`. */ case class StartsWith(left: Expression, right: Expression) - extends BinaryExpression with StringComparison { - override def compare(l: String, r: String): Boolean = l.startsWith(r) + extends BinaryPredicate with StringComparison { + override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) } /** * A function that returns true if the string `left` ends with the string `right`. */ case class EndsWith(left: Expression, right: Expression) - extends BinaryExpression with StringComparison { - override def compare(l: String, r: String): Boolean = l.endsWith(r) + extends BinaryPredicate with StringComparison { + override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) } /** @@ -225,9 +222,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends override def children: Seq[Expression] = str :: pos :: len :: Nil @inline - def slice[T, C <: Any](str: C, startPos: Int, sliceLen: Int) - (implicit ev: (C=>IndexedSeqOptimized[T,_])): Any = { - val len = str.length + def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = { // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and // negative indices for start positions. If a start index i is greater than 0, it // refers to element i-1 in the sequence. 
If a start index i is less than 0, it refers @@ -236,7 +231,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends val start = startPos match { case pos if pos > 0 => pos - 1 - case neg if neg < 0 => len + neg + case neg if neg < 0 => length() + neg case _ => 0 } @@ -245,12 +240,11 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends case x => start + x } - str.slice(start, end) + (start, end) } override def eval(input: Row): Any = { val string = str.eval(input) - val po = pos.eval(input) val ln = len.eval(input) @@ -258,11 +252,14 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends null } else { val start = po.asInstanceOf[Int] - val length = ln.asInstanceOf[Int] - + val length = ln.asInstanceOf[Int] string match { - case ba: Array[Byte] => slice(ba, start, length) - case other => slice(other.toString, start, length) + case ba: Array[Byte] => + val (st, end) = slicePos(start, length, () => ba.length) + ba.slice(st, end) + case s: UTF8String => + val (st, end) = slicePos(start, length, () => s.length) + s.slice(st, end) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c23d3b61887c6..7c80634d2c852 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -198,14 +198,19 @@ object LikeSimplification extends Rule[LogicalPlan] { val equalTo = "([^_%]*)".r def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case Like(l, Literal(startsWith(pattern), StringType)) if !pattern.endsWith("\\") => - StartsWith(l, Literal(pattern)) - case Like(l, Literal(endsWith(pattern), StringType)) => - EndsWith(l, Literal(pattern)) - case Like(l, Literal(contains(pattern), StringType)) if !pattern.endsWith("\\") => - Contains(l, Literal(pattern)) - case Like(l, Literal(equalTo(pattern), StringType)) => - EqualTo(l, Literal(pattern)) + case Like(l, Literal(utf, StringType)) => + utf.toString match { + case startsWith(pattern) if !pattern.endsWith("\\") => + StartsWith(l, Literal(pattern)) + case endsWith(pattern) => + EndsWith(l, Literal(pattern)) + case contains(pattern) if !pattern.endsWith("\\") => + Contains(l, Literal(pattern)) + case equalTo(pattern) => + EqualTo(l, Literal(pattern)) + case _ => + Like(l, Literal.create(utf, StringType)) + } } } @@ -218,12 +223,12 @@ object NullPropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType) - case e @ IsNull(c) if !c.nullable => Literal(false, BooleanType) - case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType) - case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType) - case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType) - case e @ StructGetField(Literal(null, _), _, _) => Literal(null, e.dataType) - case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal(null, e.dataType) + case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) + case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) + case e @ GetItem(Literal(null, _), _) => Literal.create(null, e.dataType) + case e @ GetItem(_, Literal(null, _)) => Literal.create(null, e.dataType) + case e @ 
StructGetField(Literal(null, _), _, _) => Literal.create(null, e.dataType) + case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal.create(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) case e @ Count(expr) if !expr.nullable => Count(Literal(1)) @@ -235,36 +240,36 @@ object NullPropagation extends Rule[LogicalPlan] { case _ => true } if (newChildren.length == 0) { - Literal(null, e.dataType) + Literal.create(null, e.dataType) } else if (newChildren.length == 1) { newChildren(0) } else { Coalesce(newChildren) } - case e @ Substring(Literal(null, _), _, _) => Literal(null, e.dataType) - case e @ Substring(_, Literal(null, _), _) => Literal(null, e.dataType) - case e @ Substring(_, _, Literal(null, _)) => Literal(null, e.dataType) + case e @ Substring(Literal(null, _), _, _) => Literal.create(null, e.dataType) + case e @ Substring(_, Literal(null, _), _) => Literal.create(null, e.dataType) + case e @ Substring(_, _, Literal(null, _)) => Literal.create(null, e.dataType) // Put exceptional cases above if any case e: BinaryArithmetic => e.children match { - case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) + case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } case e: BinaryComparison => e.children match { - case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) + case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } case e: StringRegexExpression => e.children match { - case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) + case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } case e: StringComparison => e.children match { - case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) + case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } } @@ -284,13 +289,13 @@ object ConstantFolding extends Rule[LogicalPlan] { case l: Literal => l // Fold expressions that are foldable. - case e if e.foldable => Literal(e.eval(null), e.dataType) + case e if e.foldable => Literal.create(e.eval(null), e.dataType) // Fold "literal in (item1, item2, ..., literal, ...)" into true directly. 
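The `NullPropagation` and `ConstantFolding` hunks above all follow one recipe: when a subtree's value is statically known (a null operand, or an expression whose children are all literals), replace the subtree with a literal of the expression's data type, now built through `Literal.create` so the value is also converted to Catalyst's internal representation. A minimal standalone sketch of that fold-when-known idea on a toy expression tree (the case classes are illustrative, not the Catalyst ones):

```scala
// Toy expression tree: fold subtrees whose value is statically known.
sealed trait Expr { def foldable: Boolean; def eval(): Any }

case class Lit(value: Any) extends Expr {
  val foldable = true
  def eval(): Any = value
}

case class Plus(left: Expr, right: Expr) extends Expr {
  def foldable: Boolean = left.foldable && right.foldable
  // Null propagation: any null operand makes the whole result null.
  def eval(): Any = (left.eval(), right.eval()) match {
    case (null, _) | (_, null) => null
    case (a: Int, b: Int)      => a + b
  }
}

def fold(e: Expr): Expr = e match {
  case lit: Lit => lit                        // already a literal, keep it
  case Plus(l, r) =>
    val folded = Plus(fold(l), fold(r))
    // Once every child is a literal, evaluate eagerly and keep the result.
    if (folded.foldable) Lit(folded.eval()) else folded
}

fold(Plus(Lit(1), Plus(Lit(2), Lit(null))))  // Lit(null)
fold(Plus(Lit(1), Plus(Lit(2), Lit(3))))     // Lit(6)
```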
case In(Literal(v, _), list) if list.exists { case Literal(candidate, _) if candidate == v => true case _ => false - } => Literal(true, BooleanType) + } => Literal.create(true, BooleanType) } } } @@ -647,7 +652,7 @@ object DecimalAggregates extends Rule[LogicalPlan] { case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => Cast( - Divide(Average(UnscaledValue(e)), Literal(math.pow(10.0, scale), DoubleType)), + Divide(Average(UnscaledValue(e)), Literal.create(math.pow(10.0, scale), DoubleType)), DecimalType(prec + 4, scale + 4)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 02f7c26a8ab6e..7967189cacb24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -150,7 +150,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy }.toSeq } - def schema: StructType = StructType.fromAttributes(output) + lazy val schema: StructType = StructType.fromAttributes(output) /** Returns the output schema in the tree format. */ def schemaString: String = schema.treeString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index bb79dc340553b..e3e070f0ff307 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, analysis} import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.types.{DataTypeConversions, StructType, StructField} +import org.apache.spark.sql.types.{StructType, StructField} object LocalRelation { def apply(output: Attribute*): LocalRelation = new LocalRelation(output) @@ -31,7 +31,8 @@ object LocalRelation { def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = { val schema = StructType.fromAttributes(output) - LocalRelation(output, data.map(row => DataTypeConversions.productToRow(row, schema))) + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + LocalRelation(output, data.map(converter(_).asInstanceOf[Row])) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index b01a61d7bf8d6..ae4620a4e5abf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, Resolver} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, EliminateSubQueries, Resolver} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import 
org.apache.spark.sql.catalyst.trees.TreeNode @@ -109,16 +109,22 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * nodes of this LogicalPlan. The attribute is expressed as * as string in the following form: `[scope].AttributeName.[nested].[fields]...`. */ - def resolveChildren(name: String, resolver: Resolver): Option[NamedExpression] = - resolve(name, children.flatMap(_.output), resolver) + def resolveChildren( + nameParts: Seq[String], + resolver: Resolver, + throwErrors: Boolean = false): Option[NamedExpression] = + resolve(nameParts, children.flatMap(_.output), resolver, throwErrors) /** * Optionally resolves the given string to a [[NamedExpression]] based on the output of this * LogicalPlan. The attribute is expressed as string in the following form: * `[scope].AttributeName.[nested].[fields]...`. */ - def resolve(name: String, resolver: Resolver): Option[NamedExpression] = - resolve(name, output, resolver) + def resolve( + nameParts: Seq[String], + resolver: Resolver, + throwErrors: Boolean = false): Option[NamedExpression] = + resolve(nameParts, output, resolver, throwErrors) /** * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. @@ -128,7 +134,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * See the comment above `candidates` variable in resolve() for semantics the returned data. */ private def resolveAsTableColumn( - nameParts: Array[String], + nameParts: Seq[String], resolver: Resolver, attribute: Attribute): Option[(Attribute, List[String])] = { assert(nameParts.length > 1) @@ -148,7 +154,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * See the comment above `candidates` variable in resolve() for semantics the returned data. */ private def resolveAsColumn( - nameParts: Array[String], + nameParts: Seq[String], resolver: Resolver, attribute: Attribute): Option[(Attribute, List[String])] = { if (resolver(attribute.name, nameParts.head)) { @@ -160,11 +166,10 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { /** Performs attribute resolution given a name and a sequence of possible attributes. */ protected def resolve( - name: String, + nameParts: Seq[String], input: Seq[Attribute], - resolver: Resolver): Option[NamedExpression] = { - - val parts = name.split("\\.") + resolver: Resolver, + throwErrors: Boolean): Option[NamedExpression] = { // A sequence of possible candidate matches. // Each candidate is a tuple. The first element is a resolved attribute, followed by a list @@ -174,9 +179,9 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // and the second element will be List("c"). var candidates: Seq[(Attribute, List[String])] = { // If the name has 2 or more parts, try to resolve it as `table.column` first. - if (parts.length > 1) { + if (nameParts.length > 1) { input.flatMap { option => - resolveAsTableColumn(parts, resolver, option) + resolveAsTableColumn(nameParts, resolver, option) } } else { Seq.empty @@ -186,24 +191,30 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // If none of attributes match `table.column` pattern, we try to resolve it as a column. if (candidates.isEmpty) { candidates = input.flatMap { candidate => - resolveAsColumn(parts, resolver, candidate) + resolveAsColumn(nameParts, resolver, candidate) } } + def name = UnresolvedAttribute(nameParts).name + candidates.distinct match { // One match, no nested fields, use it. 
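The reworked `resolve` above now takes the name pre-split into `nameParts`, tries the `table.column` interpretation first, falls back to a bare column, and keeps any leftover parts as nested struct fields to be turned into `GetField`s. A standalone sketch of that candidate-selection logic, with an illustrative `Attr` class and a plain function standing in for Catalyst's attributes and `Resolver`:

```scala
// Illustrative attribute: an optional table qualifier plus a column name.
case class Attr(qualifier: Option[String], name: String)

def resolveParts(
    nameParts: Seq[String],
    input: Seq[Attr],
    resolver: (String, String) => Boolean): Option[(Attr, List[String])] = {
  // 1. With two or more parts, try "table.column" and carry the rest along
  //    as nested fields, e.g. Seq("t", "top", "x") -> (t.top, List("x")).
  val asTableColumn =
    if (nameParts.length > 1) {
      input.collect {
        case a if a.qualifier.exists(resolver(_, nameParts.head)) &&
                  resolver(a.name, nameParts(1)) =>
          (a, nameParts.drop(2).toList)
      }
    } else {
      Seq.empty
    }
  // 2. Otherwise treat the first part as the column name.
  val candidates =
    if (asTableColumn.nonEmpty) asTableColumn
    else input.collect {
      case a if resolver(a.name, nameParts.head) => (a, nameParts.tail.toList)
    }
  candidates match {
    case Seq(single) => Some(single)   // exactly one match
    case Seq()       => None           // unresolved
    case ambiguous   =>
      val names = ambiguous.map(_._1.name).mkString(", ")
      throw new RuntimeException(
        s"Reference '${nameParts.mkString(".")}' is ambiguous, could be: $names.")
  }
}

val caseInsensitive = (a: String, b: String) => a.equalsIgnoreCase(b)
resolveParts(Seq("t", "top", "duplicateField"),
             Seq(Attr(Some("t"), "top")), caseInsensitive)
// => Some((Attr(Some("t"), "top"), List("duplicateField")))
```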
case Seq((a, Nil)) => Some(a) // One match, but we also need to extract the requested nested field. case Seq((a, nestedFields)) => - // The foldLeft adds UnresolvedGetField for every remaining parts of the name, - // and aliased it with the last part of the name. - // For example, consider name "a.b.c", where "a" is resolved to an existing attribute. - // Then this will add UnresolvedGetField("b") and UnresolvedGetField("c"), and alias - // the final expression as "c". - val fieldExprs = nestedFields.foldLeft(a: Expression)(UnresolvedGetField) - val aliasName = nestedFields.last - Some(Alias(fieldExprs, aliasName)()) + try { + // The foldLeft adds GetFields for every remaining parts of the identifier, + // and aliases it with the last part of the identifier. + // For example, consider "a.b.c", where "a" is resolved to an existing attribute. + // Then this will add GetField("c", GetField("b", a)), and alias + // the final expression as "c". + val fieldExprs = nestedFields.foldLeft(a: Expression)(GetField(_, _, resolver)) + val aliasName = nestedFields.last + Some(Alias(fieldExprs, aliasName)()) + } catch { + case a: AnalysisException if !throwErrors => None + } // No matches. case Seq() => @@ -212,7 +223,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // More than one match. case ambiguousReferences => - val referenceNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ") + val referenceNames = ambiguousReferences.map(_._1).mkString(", ") throw new AnalysisException( s"Reference '$name' is ambiguous, could be: $referenceNames.") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 4d9e41a2b5d85..17522976dc2c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -80,7 +80,7 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override lazy val resolved: Boolean = childrenResolved && - !left.output.zip(right.output).exists { case (l,r) => l.dataType != r.dataType } + left.output.zip(right.output).forall { case (l,r) => l.dataType == r.dataType } override def statistics: Statistics = { val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes @@ -125,12 +125,14 @@ case class InsertIntoTable( table: LogicalPlan, partition: Map[String, Option[String]], child: LogicalPlan, - overwrite: Boolean) + overwrite: Boolean, + ifNotExists: Boolean) extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil override def output: Seq[Attribute] = child.output + assert(overwrite || !ifNotExists) override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall { case (childAttr, tableAttr) => DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType) @@ -147,6 +149,18 @@ case class CreateTableAsSelect[T]( override lazy val resolved: Boolean = databaseName != None && childrenResolved } +/** + * A container for holding named common table expressions (CTEs) and a query plan. + * This operator will be removed during analysis and the relations will be substituted into child. + * @param child The final query of this CTE. + * @param cteRelations Queries that this CTE defined, + * key is the alias of the CTE definition, + * value is the CTE definition. 
+ */ +case class With(child: LogicalPlan, cteRelations: Map[String, Subquery]) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + case class WriteToFile( path: String, child: LogicalPlan) extends UnaryNode { @@ -287,7 +301,10 @@ case class Distinct(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } -case object NoRelation extends LeafNode { +/** + * A relation with one row. This is used in "SELECT ..." without a from clause. + */ +case object OneRowRelation extends LeafNode { override def output: Seq[Attribute] = Nil /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 288c11f69fe22..fb4217a44807b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -94,6 +94,9 @@ sealed trait Partitioning { * only compatible if the `numPartitions` of them is the same. */ def compatibleWith(other: Partitioning): Boolean + + /** Returns the expressions that are used to key the partitioning. */ + def keyExpressions: Seq[Expression] } case class UnknownPartitioning(numPartitions: Int) extends Partitioning { @@ -106,6 +109,8 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning { case UnknownPartitioning(_) => true case _ => false } + + override def keyExpressions: Seq[Expression] = Nil } case object SinglePartition extends Partitioning { @@ -117,6 +122,8 @@ case object SinglePartition extends Partitioning { case SinglePartition => true case _ => false } + + override def keyExpressions: Seq[Expression] = Nil } case object BroadcastPartitioning extends Partitioning { @@ -128,6 +135,8 @@ case object BroadcastPartitioning extends Partitioning { case SinglePartition => true case _ => false } + + override def keyExpressions: Seq[Expression] = Nil } /** @@ -158,6 +167,8 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } + override def keyExpressions: Seq[Expression] = expressions + override def eval(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } @@ -200,6 +211,8 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case _ => false } + override def keyExpressions: Seq[Expression] = ordering.map(_.child) + override def eval(input: Row): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala deleted file mode 100644 index c243be07a91b6..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.types - -import java.text.SimpleDateFormat - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow - - -protected[sql] object DataTypeConversions { - - def productToRow(product: Product, schema: StructType): Row = { - val mutableRow = new GenericMutableRow(product.productArity) - val schemaFields = schema.fields.toArray - - var i = 0 - while (i < mutableRow.length) { - mutableRow(i) = - ScalaReflection.convertToCatalyst(product.productElement(i), schemaFields(i).dataType) - i += 1 - } - - mutableRow - } - - def stringToTime(s: String): java.util.Date = { - if (!s.contains('T')) { - // JDBC escape string - if (s.contains(' ')) { - java.sql.Timestamp.valueOf(s) - } else { - java.sql.Date.valueOf(s) - } - } else if (s.endsWith("Z")) { - // this is zero timezone of ISO8601 - stringToTime(s.substring(0, s.length - 1) + "GMT-00:00") - } else if (s.indexOf("GMT") == -1) { - // timezone with ISO8601 - val inset = "+00.00".length - val s0 = s.substring(0, s.length - inset) - val s1 = s.substring(s.length - inset, s.length) - if (s0.substring(s0.lastIndexOf(':')).contains('.')) { - stringToTime(s0 + "GMT" + s1) - } else { - stringToTime(s0 + ".0GMT" + s1) - } - } else { - // ISO8601 with GMT insert - val ISO8601GMT: SimpleDateFormat = new SimpleDateFormat( "yyyy-MM-dd'T'HH:mm:ss.SSSz" ) - ISO8601GMT.parse(s) - } - } - - /** Converts Java objects to catalyst rows / types */ - def convertJavaToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { - case (obj, udt: UserDefinedType[_]) => ScalaReflection.convertToCatalyst(obj, udt) // Scala type - case (d: java.math.BigDecimal, _) => Decimal(d) - case (other, _) => other - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala index 89278f7dbc806..5163f05879e42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala @@ -40,7 +40,7 @@ private[sql] trait DataTypeParser extends StandardTokenParsers { protected lazy val primitiveType: Parser[DataType] = "(?i)string".r ^^^ StringType | "(?i)float".r ^^^ FloatType | - "(?i)int".r ^^^ IntegerType | + "(?i)(?:int|integer)".r ^^^ IntegerType | "(?i)tinyint".r ^^^ ByteType | "(?i)smallint".r ^^^ ShortType | "(?i)double".r ^^^ DoubleType | @@ -112,4 +112,4 @@ private[sql] object DataTypeParser { } /** The exception thrown from the [[DataTypeParser]]. 
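The `stringToTime` helper deleted from `DataTypeConversions` above (and re-added to `DateUtils` further down) normalizes ISO 8601 strings so a single `SimpleDateFormat` pattern can parse them: a trailing `Z` becomes an explicit `GMT-00:00`, a bare `+hh:mm` offset gets `GMT` prepended, and missing fractional seconds are filled in as `.0`. A simplified standalone sketch of just that normalization (it ignores the JDBC-escape branch and assumes the offset, when present, is always in `+hh:mm` form):

```scala
import java.text.SimpleDateFormat
import java.util.Date

// Normalize an ISO 8601 timestamp so "yyyy-MM-dd'T'HH:mm:ss.SSSz" can parse it.
def parseIso8601(s: String): Date = {
  // Split into the local date-time part and the zone designator.
  val (body, zone) =
    if (s.endsWith("Z")) (s.dropRight(1), "-00:00")
    else (s.dropRight("+00:00".length), s.takeRight("+00:00".length))
  // The pattern below always expects fractional seconds, so add ".0" if absent.
  val withMillis = if (body.contains('.')) body else body + ".0"
  new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz").parse(withMillis + "GMT" + zone)
}

parseIso8601("2015-04-08T13:56:01Z")       // parsed as 13:56:01 UTC
parseIso8601("2015-04-08T13:56:01+08:00")  // the +08:00 offset is honored
```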
*/ -protected[sql] class DataTypeException(message: String) extends Exception(message) +private[sql] class DataTypeException(message: String) extends Exception(message) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala index 8a1a3b81b3d2c..d36a49159b87f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.types import java.sql.Date +import java.text.SimpleDateFormat import java.util.{Calendar, TimeZone} import org.apache.spark.sql.catalyst.expressions.Cast @@ -39,6 +40,7 @@ object DateUtils { millisToDays(d.getTime) } + // we should use the exact day as Int, for example, (year, month, day) -> day def millisToDays(millisLocal: Long): Int = { ((millisLocal + LOCAL_TIMEZONE.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt } @@ -57,4 +59,32 @@ object DateUtils { } def toString(days: Int): String = Cast.threadLocalDateFormat.get.format(toJavaDate(days)) + + def stringToTime(s: String): java.util.Date = { + if (!s.contains('T')) { + // JDBC escape string + if (s.contains(' ')) { + java.sql.Timestamp.valueOf(s) + } else { + java.sql.Date.valueOf(s) + } + } else if (s.endsWith("Z")) { + // this is zero timezone of ISO8601 + stringToTime(s.substring(0, s.length - 1) + "GMT-00:00") + } else if (s.indexOf("GMT") == -1) { + // timezone with ISO8601 + val inset = "+00.00".length + val s0 = s.substring(0, s.length - inset) + val s1 = s.substring(s.length - inset, s.length) + if (s0.substring(s0.lastIndexOf(':')).contains('.')) { + stringToTime(s0 + "GMT" + s1) + } else { + stringToTime(s0 + ".0GMT" + s1) + } + } else { + // ISO8601 with GMT insert + val ISO8601GMT: SimpleDateFormat = new SimpleDateFormat( "yyyy-MM-dd'T'HH:mm:ss.SSSz" ) + ISO8601GMT.parse(s) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala new file mode 100644 index 0000000000000..fc02ba6c9c43e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala @@ -0,0 +1,214 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.types + +import java.util.Arrays + +/** + * A UTF-8 String, as internal representation of StringType in SparkSQL + * + * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison, + * search, see http://en.wikipedia.org/wiki/UTF-8 for details. + * + * Note: This is not designed for general use cases, should not be used outside SQL. 
+ */ + +final class UTF8String extends Ordered[UTF8String] with Serializable { + + private[this] var bytes: Array[Byte] = _ + + /** + * Update the UTF8String with a String. + */ + def set(str: String): UTF8String = { + bytes = str.getBytes("utf-8") + this + } + + /** + * Update the UTF8String with Array[Byte], which should be encoded in UTF-8 + */ + def set(bytes: Array[Byte]): UTF8String = { + this.bytes = bytes + this + } + + /** + * Return the number of bytes for a code point with the first byte as `b` + * @param b The first byte of a code point + */ + @inline + private[this] def numOfBytes(b: Byte): Int = { + val offset = (b & 0xFF) - 192 + if (offset >= 0) UTF8String.bytesOfCodePointInUTF8(offset) else 1 + } + + /** + * Return the number of code points in this string. + * + * This is only used by Substring() when `start` is negative. + */ + def length(): Int = { + var len = 0 + var i: Int = 0 + while (i < bytes.length) { + i += numOfBytes(bytes(i)) + len += 1 + } + len + } + + def getBytes: Array[Byte] = { + bytes + } + + /** + * Return a substring of this string, sliced by code point positions. + * @param start the position of the first code point + * @param until the position after the last code point + */ + def slice(start: Int, until: Int): UTF8String = { + if (bytes == null || until <= start || start >= bytes.length) { + return new UTF8String + } + + var c = 0 + var i: Int = 0 + while (c < start && i < bytes.length) { + i += numOfBytes(bytes(i)) + c += 1 + } + var j = i + while (c < until && j < bytes.length) { + j += numOfBytes(bytes(j)) + c += 1 + } + UTF8String(Arrays.copyOfRange(bytes, i, j)) + } + + def contains(sub: UTF8String): Boolean = { + val b = sub.getBytes + if (b.length == 0) { + return true + } + var i: Int = 0 + while (i <= bytes.length - b.length) { + // In the worst case this is O(N*K), but it should work fine for SQL + if (bytes(i) == b(0) && Arrays.equals(Arrays.copyOfRange(bytes, i, i + b.length), b)) { + return true + } + i += 1 + } + false + } + + def startsWith(prefix: UTF8String): Boolean = { + val b = prefix.getBytes + if (b.length > bytes.length) { + return false + } + Arrays.equals(Arrays.copyOfRange(bytes, 0, b.length), b) + } + + def endsWith(suffix: UTF8String): Boolean = { + val b = suffix.getBytes + if (b.length > bytes.length) { + return false + } + Arrays.equals(Arrays.copyOfRange(bytes, bytes.length - b.length, bytes.length), b) + } + + def toUpperCase(): UTF8String = { + // upper case depends on locale, fallback to String. + UTF8String(toString().toUpperCase) + } + + def toLowerCase(): UTF8String = { + // lower case depends on locale, fallback to String. 
+ UTF8String(toString().toLowerCase) + } + + override def toString(): String = { + new String(bytes, "utf-8") + } + + override def clone(): UTF8String = new UTF8String().set(this.bytes) + + override def compare(other: UTF8String): Int = { + var i: Int = 0 + val b = other.getBytes + while (i < bytes.length && i < b.length) { + val res = bytes(i).compareTo(b(i)) + if (res != 0) return res + i += 1 + } + bytes.length - b.length + } + + override def compareTo(other: UTF8String): Int = { + compare(other) + } + + override def equals(other: Any): Boolean = other match { + case s: UTF8String => + Arrays.equals(bytes, s.getBytes) + case s: String => + // This is only used for Catalyst unit tests + // fail fast + bytes.length >= s.length && length() == s.length && toString() == s + case _ => + false + } + + override def hashCode(): Int = { + Arrays.hashCode(bytes) + } +} + +object UTF8String { + // number of tailing bytes in a UTF8 sequence for a code point + // see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1 + private[types] val bytesOfCodePointInUTF8: Array[Int] = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, + 5, 5, 5, 5, + 6, 6, 6, 6) + + /** + * Create a UTF-8 String from String + */ + def apply(s: String): UTF8String = { + if (s != null) { + new UTF8String().set(s) + } else{ + null + } + } + + /** + * Create a UTF-8 String from Array[Byte], which should be encoded in UTF-8 + */ + def apply(bytes: Array[Byte]): UTF8String = { + if (bytes != null) { + new UTF8String().set(bytes) + } else { + null + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index 952cf5c75688d..c6fb22c26bd3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.types import java.sql.Timestamp import scala.collection.mutable.ArrayBuffer +import scala.math._ import scala.math.Numeric.{FloatAsIfIntegral, DoubleAsIfIntegral} import scala.reflect.ClassTag import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag} @@ -349,7 +350,7 @@ class StringType private() extends NativeType with PrimitiveType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "StringType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. 
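The `UTF8String` introduced above never decodes to a `java.lang.String` for substring work: it walks the byte array one code point at a time, using the leading byte (masked with `0xFF`, since JVM bytes are signed) to look up how many bytes that code point occupies. A self-contained sketch of the same walk, with an illustrative object name:

```scala
import java.util.Arrays

// Standalone sketch of code-point-aware slicing over raw UTF-8 bytes.
object Utf8Slice {
  // Total bytes in a sequence whose leading byte is 0xC0 + index; leading
  // bytes below 0xC0 stand alone (single-byte code points or continuations).
  private val bytesForLeading = Array(
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
    3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
    4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6)

  private def numBytes(b: Byte): Int = {
    val offset = (b & 0xFF) - 192      // mask first: JVM bytes are signed
    if (offset >= 0) bytesForLeading(offset) else 1
  }

  /** Slice by code-point positions [start, until), returning the raw bytes. */
  def slice(bytes: Array[Byte], start: Int, until: Int): Array[Byte] = {
    var c = 0
    var i = 0
    while (c < start && i < bytes.length) { i += numBytes(bytes(i)); c += 1 }
    var j = i
    while (c < until && j < bytes.length) { j += numBytes(bytes(j)); c += 1 }
    Arrays.copyOfRange(bytes, i, j)
  }
}

// "héllo" is 6 bytes but 5 code points; slicing [1, 3) by code points gives "él".
val bytes = "héllo".getBytes("UTF-8")
new String(Utf8Slice.slice(bytes, 1, 3), "UTF-8")  // "él"
```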
- private[sql] type JvmType = String + private[sql] type JvmType = UTF8String @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] @@ -934,7 +935,9 @@ object StructType { case (DecimalType.Fixed(leftPrecision, leftScale), DecimalType.Fixed(rightPrecision, rightScale)) => - DecimalType(leftPrecision.max(rightPrecision), leftScale.max(rightScale)) + DecimalType( + max(leftScale, rightScale) + max(leftPrecision - leftScale, rightPrecision - rightScale), + max(leftScale, rightScale)) case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_]) if leftUdt.userClass == rightUdt.userClass => leftUdt @@ -1193,8 +1196,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { /** * Convert the user type to a SQL datum * - * TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst, - * where we need to convert Any to UserType. + * TODO: Can we make this take obj: UserType? The issue is in + * CatalystTypeConverters.convertToCatalyst, where we need to convert Any to UserType. */ def serialize(obj: Any): Any diff --git a/sql/catalyst/src/test/resources/log4j.properties b/sql/catalyst/src/test/resources/log4j.properties index 287c8e3563503..eb3b1999eb996 100644 --- a/sql/catalyst/src/test/resources/log4j.properties +++ b/sql/catalyst/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index 46b2250aab231..ea82cd2622de9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -30,7 +30,7 @@ class DistributionSuite extends FunSuite { inputPartitioning: Partitioning, requiredDistribution: Distribution, satisfied: Boolean) { - if (inputPartitioning.satisfies(requiredDistribution) != satisfied) + if (inputPartitioning.satisfies(requiredDistribution) != satisfied) { fail( s""" |== Input Partitioning == @@ -40,6 +40,7 @@ class DistributionSuite extends FunSuite { |== Does input partitioning satisfy required distribution? 
== |Expected $satisfied got ${inputPartitioning.satisfies(requiredDistribution)} """.stripMargin) + } } test("HashPartitioning is the output partitioning") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index eee00e3f7ea76..bbc0b661a0c0c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -260,7 +260,7 @@ class ScalaReflectionSuite extends FunSuite { val data = PrimitiveData(1, 1, 1, 1, 1, 1, true) val convertedData = Row(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) val dataType = schemaFor[PrimitiveData].dataType - assert(convertToCatalyst(data, dataType) === convertedData) + assert(CatalystTypeConverters.convertToCatalyst(data, dataType) === convertedData) } test("convert Option[Product] to catalyst") { @@ -270,7 +270,7 @@ class ScalaReflectionSuite extends FunSuite { val dataType = schemaFor[OptionalData].dataType val convertedData = Row(2, 2.toLong, 2.toDouble, 2.toFloat, 2.toShort, 2.toByte, true, Row(1, 1, 1, 1, 1, 1, true)) - assert(convertToCatalyst(data, dataType) === convertedData) + assert(CatalystTypeConverters.convertToCatalyst(data, dataType) === convertedData) } test("infer schema from case class with multiple constructors") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 756cd36f05c8c..e10ddfdf5127c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -40,14 +40,12 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { override val extendedResolutionRules = EliminateSubQueries :: Nil } - val checkAnalysis = new CheckAnalysis + def caseSensitiveAnalyze(plan: LogicalPlan): Unit = + caseSensitiveAnalyzer.checkAnalysis(caseSensitiveAnalyzer(plan)) - def caseSensitiveAnalyze(plan: LogicalPlan) = - checkAnalysis(caseSensitiveAnalyzer(plan)) - - def caseInsensitiveAnalyze(plan: LogicalPlan) = - checkAnalysis(caseInsensitiveAnalyzer(plan)) + def caseInsensitiveAnalyze(plan: LogicalPlan): Unit = + caseInsensitiveAnalyzer.checkAnalysis(caseInsensitiveAnalyzer(plan)) val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) val testRelation2 = LocalRelation( @@ -57,6 +55,21 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { AttributeReference("d", DecimalType.Unlimited)(), AttributeReference("e", ShortType)()) + val nestedRelation = LocalRelation( + AttributeReference("top", StructType( + StructField("duplicateField", StringType) :: + StructField("duplicateField", StringType) :: + StructField("differentCase", StringType) :: + StructField("differentcase", StringType) :: Nil + ))()) + + val nestedRelation2 = LocalRelation( + AttributeReference("top", StructType( + StructField("aField", StringType) :: + StructField("bField", StringType) :: + StructField("cField", StringType) :: Nil + ))()) + before { caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) @@ -134,7 +147,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { name: String, plan: LogicalPlan, errorMessages: 
Seq[String], - caseSensitive: Boolean = true) = { + caseSensitive: Boolean = true): Unit = { test(name) { val error = intercept[AnalysisException] { if(caseSensitive) { @@ -169,9 +182,27 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { "'b'" :: "group by" :: Nil ) + errorTest( + "ambiguous field", + nestedRelation.select($"top.duplicateField"), + "Ambiguous reference to fields" :: "duplicateField" :: Nil, + caseSensitive = false) + + errorTest( + "ambiguous field due to case insensitivity", + nestedRelation.select($"top.differentCase"), + "Ambiguous reference to fields" :: "differentCase" :: "differentcase" :: Nil, + caseSensitive = false) + + errorTest( + "missing field", + nestedRelation2.select($"top.c"), + "No such struct field" :: "aField" :: "bField" :: "cField" :: Nil, + caseSensitive = false) + case class UnresolvedTestPlan() extends LeafNode { override lazy val resolved = false - override def output = Nil + override def output: Seq[Attribute] = Nil } errorTest( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index bc2ec754d5865..67bec999dfbd1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} +import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation} import org.apache.spark.sql.types._ import org.scalatest.{BeforeAndAfter, FunSuite} @@ -31,7 +31,8 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { AttributeReference("d1", DecimalType(2, 1))(), AttributeReference("d2", DecimalType(5, 2))(), AttributeReference("u", DecimalType.Unlimited)(), - AttributeReference("f", FloatType)() + AttributeReference("f", FloatType)(), + AttributeReference("b", DoubleType)() ) val i: Expression = UnresolvedAttribute("i") @@ -39,6 +40,7 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { val d2: Expression = UnresolvedAttribute("d2") val u: Expression = UnresolvedAttribute("u") val f: Expression = UnresolvedAttribute("f") + val b: Expression = UnresolvedAttribute("b") before { catalog.registerTable(Seq("table"), relation) @@ -58,6 +60,17 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { assert(comparison.right.dataType === expectedType) } + private def checkUnion(left: Expression, right: Expression, expectedType: DataType): Unit = { + val plan = + Union(Project(Seq(Alias(left, "l")()), relation), + Project(Seq(Alias(right, "r")()), relation)) + val (l, r) = analyzer(plan).collect { + case Union(left, right) => (left.output.head, right.output.head) + }.head + assert(l.dataType === expectedType) + assert(r.dataType === expectedType) + } + test("basic operations") { checkType(Add(d1, d2), DecimalType(6, 2)) checkType(Subtract(d1, d2), DecimalType(6, 2)) @@ -82,6 +95,19 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { checkComparison(GreaterThan(d2, d2), DecimalType(5, 2)) } + test("decimal precision for union") { + checkUnion(d1, i, DecimalType(11, 1)) + checkUnion(i, d2, DecimalType(12, 2)) + checkUnion(d1, d2, DecimalType(5, 2)) + checkUnion(d2, d1, DecimalType(5, 2)) + checkUnion(d1, f, 
DecimalType(8, 7)) + checkUnion(f, d2, DecimalType(10, 7)) + checkUnion(d1, b, DecimalType(16, 15)) + checkUnion(b, d2, DecimalType(18, 15)) + checkUnion(d1, u, DecimalType.Unlimited) + checkUnion(u, d2, DecimalType.Unlimited) + } + test("bringing in primitive types") { checkType(Add(d1, i), DecimalType(12, 1)) checkType(Add(d1, f), DoubleType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index ecbb54218d457..fcd745f43cfbf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -96,7 +96,9 @@ class HiveTypeCoercionSuite extends PlanTest { widenTest(StringType, TimestampType, None) // ComplexType - widenTest(NullType, MapType(IntegerType, StringType, false), Some(MapType(IntegerType, StringType, false))) + widenTest(NullType, + MapType(IntegerType, StringType, false), + Some(MapType(IntegerType, StringType, false))) widenTest(NullType, StructType(Seq()), Some(StructType(Seq()))) widenTest(StringType, MapType(IntegerType, StringType, true), None) widenTest(ArrayType(IntegerType), StructType(Seq()), None) @@ -113,7 +115,9 @@ class HiveTypeCoercionSuite extends PlanTest { // Remove superflous boolean -> boolean casts. ruleTest(Cast(Literal(true), BooleanType), Literal(true)) // Stringify boolean when casting to string. - ruleTest(Cast(Literal(false), StringType), If(Literal(false), Literal("true"), Literal("false"))) + ruleTest( + Cast(Literal(false), StringType), + If(Literal(false), Literal("true"), Literal("false"))) } test("coalesce casts") { @@ -127,11 +131,11 @@ class HiveTypeCoercionSuite extends PlanTest { ruleTest( Coalesce(Literal(1.0) :: Literal(1) - :: Literal(1.0, FloatType) + :: Literal.create(1.0, FloatType) :: Nil), Coalesce(Cast(Literal(1.0), DoubleType) :: Cast(Literal(1), DoubleType) - :: Cast(Literal(1.0, FloatType), DoubleType) + :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) ruleTest( Coalesce(Literal(1L) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index dcfd8b28cb02a..76298f03c94ae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -25,12 +25,44 @@ import org.scalactic.TripleEqualsSupport.Spread import org.scalatest.FunSuite import org.scalatest.Matchers._ -import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ -class ExpressionEvaluationSuite extends FunSuite { +class ExpressionEvaluationBaseSuite extends FunSuite { + + def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { + expression.eval(inputRow) + } + + def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + if(actual != expected) { + val input = 
if(inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + + def checkDoubleEvaluation( + expression: Expression, + expected: Spread[Double], + inputRow: Row = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + actual.asInstanceOf[Double] shouldBe expected + } +} + +class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { + + def create_row(values: Any*): Row = { + new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray) + } test("literals") { checkEvaluation(Literal(1), 1) @@ -55,10 +87,13 @@ class ExpressionEvaluationSuite extends FunSuite { assert(BitwiseNot(1.toByte).eval(EmptyRow).isInstanceOf[Byte]) } + // scalastyle:off /** * Checks for three-valued-logic. Based on: * http://en.wikipedia.org/wiki/Null_(SQL)#Comparisons_with_NULL_and_the_three-valued_logic_.283VL.29 - * I.e. in flat cpo "False -> Unknown -> True", OR is lowest upper bound, AND is greatest lower bound. + * I.e. in flat cpo "False -> Unknown -> True", + * OR is lowest upper bound, + * AND is greatest lower bound. * p q p OR q p AND q p = q * True True True True True * True False True False False @@ -75,7 +110,7 @@ class ExpressionEvaluationSuite extends FunSuite { * False True * Unknown Unknown */ - + // scalastyle:on val notTrueTable = (true, false) :: (false, true) :: @@ -84,7 +119,7 @@ class ExpressionEvaluationSuite extends FunSuite { test("3VL Not") { notTrueTable.foreach { case (v, answer) => - checkEvaluation(!Literal(v, BooleanType), answer) + checkEvaluation(!Literal.create(v, BooleanType), answer) } } @@ -128,38 +163,19 @@ class ExpressionEvaluationSuite extends FunSuite { test(s"3VL $name") { truthTable.foreach { case (l,r,answer) => - val expr = op(Literal(l, BooleanType), Literal(r, BooleanType)) + val expr = op(Literal.create(l, BooleanType), Literal.create(r, BooleanType)) checkEvaluation(expr, answer) } } } - def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { - expression.eval(inputRow) - } - - def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { - case e: Exception => fail(s"Exception evaluating $expression", e) - } - if(actual != expected) { - val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") - } - } - - def checkDoubleEvaluation(expression: Expression, expected: Spread[Double], inputRow: Row = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { - case e: Exception => fail(s"Exception evaluating $expression", e) - } - actual.asInstanceOf[Double] shouldBe expected - } - test("IN") { checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) - checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true) + checkEvaluation( + In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), + true) } test("Divide") { @@ -169,12 +185,13 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Divide(Literal(1), Literal(0)), null) checkEvaluation(Divide(Literal(1.0), Literal(0.0)), null) 
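The truth tables above, together with the `Divide` and `Remainder` cases around them, encode SQL's three-valued logic: arithmetic and comparisons return NULL as soon as an operand is NULL, while AND and OR treat NULL as "unknown". A compact sketch of those tables using `Option[Boolean]`, with `None` standing in for NULL:

```scala
// Kleene three-valued logic: None plays the role of SQL NULL ("unknown").
def and3(p: Option[Boolean], q: Option[Boolean]): Option[Boolean] = (p, q) match {
  case (Some(false), _) | (_, Some(false)) => Some(false)  // false dominates AND
  case (Some(true), Some(true))            => Some(true)
  case _                                   => None
}

def or3(p: Option[Boolean], q: Option[Boolean]): Option[Boolean] = (p, q) match {
  case (Some(true), _) | (_, Some(true)) => Some(true)     // true dominates OR
  case (Some(false), Some(false))        => Some(false)
  case _                                 => None
}

or3(Some(true), None)    // Some(true):  TRUE OR NULL   = TRUE
and3(Some(true), None)   // None:        TRUE AND NULL  = NULL
and3(Some(false), None)  // Some(false): FALSE AND NULL = FALSE
```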
checkEvaluation(Divide(Literal(0.0), Literal(0.0)), null) - checkEvaluation(Divide(Literal(0), Literal(null, IntegerType)), null) - checkEvaluation(Divide(Literal(1), Literal(null, IntegerType)), null) - checkEvaluation(Divide(Literal(null, IntegerType), Literal(0)), null) - checkEvaluation(Divide(Literal(null, DoubleType), Literal(0.0)), null) - checkEvaluation(Divide(Literal(null, IntegerType), Literal(1)), null) - checkEvaluation(Divide(Literal(null, IntegerType), Literal(null, IntegerType)), null) + checkEvaluation(Divide(Literal(0), Literal.create(null, IntegerType)), null) + checkEvaluation(Divide(Literal(1), Literal.create(null, IntegerType)), null) + checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(0)), null) + checkEvaluation(Divide(Literal.create(null, DoubleType), Literal(0.0)), null) + checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(Divide(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), + null) } test("Remainder") { @@ -184,12 +201,13 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Remainder(Literal(1), Literal(0)), null) checkEvaluation(Remainder(Literal(1.0), Literal(0.0)), null) checkEvaluation(Remainder(Literal(0.0), Literal(0.0)), null) - checkEvaluation(Remainder(Literal(0), Literal(null, IntegerType)), null) - checkEvaluation(Remainder(Literal(1), Literal(null, IntegerType)), null) - checkEvaluation(Remainder(Literal(null, IntegerType), Literal(0)), null) - checkEvaluation(Remainder(Literal(null, DoubleType), Literal(0.0)), null) - checkEvaluation(Remainder(Literal(null, IntegerType), Literal(1)), null) - checkEvaluation(Remainder(Literal(null, IntegerType), Literal(null, IntegerType)), null) + checkEvaluation(Remainder(Literal(0), Literal.create(null, IntegerType)), null) + checkEvaluation(Remainder(Literal(1), Literal.create(null, IntegerType)), null) + checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(0)), null) + checkEvaluation(Remainder(Literal.create(null, DoubleType), Literal(0.0)), null) + checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), + null) } test("INSET") { @@ -216,14 +234,24 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(MaxOf(1L, 2L), 2L) checkEvaluation(MaxOf(2L, 1L), 2L) - checkEvaluation(MaxOf(Literal(null, IntegerType), 2), 2) - checkEvaluation(MaxOf(2, Literal(null, IntegerType)), 2) + checkEvaluation(MaxOf(Literal.create(null, IntegerType), 2), 2) + checkEvaluation(MaxOf(2, Literal.create(null, IntegerType)), 2) + } + + test("MinOf") { + checkEvaluation(MinOf(1, 2), 1) + checkEvaluation(MinOf(2, 1), 1) + checkEvaluation(MinOf(1L, 2L), 1L) + checkEvaluation(MinOf(2L, 1L), 1L) + + checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1) + checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1) } test("LIKE literal Regular Expression") { - checkEvaluation(Literal(null, StringType).like("a"), null) - checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null) - checkEvaluation(Literal(null, StringType).like(Literal(null, StringType)), null) + checkEvaluation(Literal.create(null, StringType).like("a"), null) + checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) + checkEvaluation(Literal.create(null, StringType).like(Literal.create(null, StringType)), null) checkEvaluation("abdef" like "abdef", true) 
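The LIKE tests above exercise the wildcard semantics that the `Like` expression evaluates: `%` matches any sequence of characters, `_` matches exactly one character, and a backslash escapes the wildcard that follows. A simplified standalone translation of a LIKE pattern into a Java regex illustrates those rules (a model of the semantics, not the Spark implementation; it does not handle every escape corner case):

```scala
import java.util.regex.Pattern

// Translate a SQL LIKE pattern into an equivalent (whole-string) Java regex.
def likeToRegex(pattern: String): String = {
  val sb = new StringBuilder
  var i = 0
  while (i < pattern.length) {
    pattern.charAt(i) match {
      case '\\' if i + 1 < pattern.length =>
        sb.append(Pattern.quote(pattern.charAt(i + 1).toString))  // escaped wildcard
        i += 1
      case '%' => sb.append(".*")   // any sequence of characters
      case '_' => sb.append(".")    // any single character
      case c   => sb.append(Pattern.quote(c.toString))
    }
    i += 1
  }
  sb.toString
}

"addb".matches(likeToRegex("a_%b"))    // true:  '_' then '%' then literal 'b'
"addb".matches(likeToRegex("a\\__b"))  // false: the first '_' must match literally
"a_%b".matches(likeToRegex("a\\__b"))  // true
```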
checkEvaluation("a_%b" like "a\\__b", true) checkEvaluation("addb" like "a_%b", true) @@ -242,29 +270,29 @@ class ExpressionEvaluationSuite extends FunSuite { test("LIKE Non-literal Regular Expression") { val regEx = 'a.string.at(0) - checkEvaluation("abcd" like regEx, null, new GenericRow(Array[Any](null))) - checkEvaluation("abdef" like regEx, true, new GenericRow(Array[Any]("abdef"))) - checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a\\__b"))) - checkEvaluation("addb" like regEx, true, new GenericRow(Array[Any]("a_%b"))) - checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a\\__b"))) - checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a%\\%b"))) - checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a%\\%b"))) - checkEvaluation("addb" like regEx, true, new GenericRow(Array[Any]("a%"))) - checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("**"))) - checkEvaluation("abc" like regEx, true, new GenericRow(Array[Any]("a%"))) - checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("b%"))) - checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("bc%"))) - checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a_b"))) - checkEvaluation("ab" like regEx, true, new GenericRow(Array[Any]("a%b"))) - checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a%b"))) - - checkEvaluation(Literal(null, StringType) like regEx, null, new GenericRow(Array[Any]("bc%"))) + checkEvaluation("abcd" like regEx, null, create_row(null)) + checkEvaluation("abdef" like regEx, true, create_row("abdef")) + checkEvaluation("a_%b" like regEx, true, create_row("a\\__b")) + checkEvaluation("addb" like regEx, true, create_row("a_%b")) + checkEvaluation("addb" like regEx, false, create_row("a\\__b")) + checkEvaluation("addb" like regEx, false, create_row("a%\\%b")) + checkEvaluation("a_%b" like regEx, true, create_row("a%\\%b")) + checkEvaluation("addb" like regEx, true, create_row("a%")) + checkEvaluation("addb" like regEx, false, create_row("**")) + checkEvaluation("abc" like regEx, true, create_row("a%")) + checkEvaluation("abc" like regEx, false, create_row("b%")) + checkEvaluation("abc" like regEx, false, create_row("bc%")) + checkEvaluation("a\nb" like regEx, true, create_row("a_b")) + checkEvaluation("ab" like regEx, true, create_row("a%b")) + checkEvaluation("a\nb" like regEx, true, create_row("a%b")) + + checkEvaluation(Literal.create(null, StringType) like regEx, null, create_row("bc%")) } test("RLIKE literal Regular Expression") { - checkEvaluation(Literal(null, StringType) rlike "abdef", null) - checkEvaluation("abdef" rlike Literal(null, StringType), null) - checkEvaluation(Literal(null, StringType) rlike Literal(null, StringType), null) + checkEvaluation(Literal.create(null, StringType) rlike "abdef", null) + checkEvaluation("abdef" rlike Literal.create(null, StringType), null) + checkEvaluation(Literal.create(null, StringType) rlike Literal.create(null, StringType), null) checkEvaluation("abdef" rlike "abdef", true) checkEvaluation("abbbbc" rlike "a.*c", true) @@ -289,14 +317,14 @@ class ExpressionEvaluationSuite extends FunSuite { test("RLIKE Non-literal Regular Expression") { val regEx = 'a.string.at(0) - checkEvaluation("abdef" rlike regEx, true, new GenericRow(Array[Any]("abdef"))) - checkEvaluation("abbbbc" rlike regEx, true, new GenericRow(Array[Any]("a.*c"))) - checkEvaluation("fofo" rlike regEx, true, new GenericRow(Array[Any]("^fo"))) - checkEvaluation("fo\no" 
rlike regEx, true, new GenericRow(Array[Any]("^fo\no$"))) - checkEvaluation("Bn" rlike regEx, true, new GenericRow(Array[Any]("^Ba*n"))) + checkEvaluation("abdef" rlike regEx, true, create_row("abdef")) + checkEvaluation("abbbbc" rlike regEx, true, create_row("a.*c")) + checkEvaluation("fofo" rlike regEx, true, create_row("^fo")) + checkEvaluation("fo\no" rlike regEx, true, create_row("^fo\no$")) + checkEvaluation("Bn" rlike regEx, true, create_row("^Ba*n")) intercept[java.util.regex.PatternSyntaxException] { - evaluate("abbbbc" rlike regEx, new GenericRow(Array[Any]("**"))) + evaluate("abbbbc" rlike regEx, create_row("**")) } } @@ -375,7 +403,7 @@ class ExpressionEvaluationSuite extends FunSuite { assert(("abcdef" cast DoubleType).nullable === true) assert(("abcdef" cast FloatType).nullable === true) - checkEvaluation(Cast(Literal(null, IntegerType), ShortType), null) + checkEvaluation(Cast(Literal.create(null, IntegerType), ShortType), null) } test("date") { @@ -501,8 +529,10 @@ class ExpressionEvaluationSuite extends FunSuite { } test("array casting") { - val array = Literal(Seq("123", "abc", "", null), ArrayType(StringType, containsNull = true)) - val array_notNull = Literal(Seq("123", "abc", ""), ArrayType(StringType, containsNull = false)) + val array = Literal.create(Seq("123", "abc", "", null), + ArrayType(StringType, containsNull = true)) + val array_notNull = Literal.create(Seq("123", "abc", ""), + ArrayType(StringType, containsNull = false)) { val cast = Cast(array, ArrayType(IntegerType, containsNull = true)) @@ -550,10 +580,10 @@ class ExpressionEvaluationSuite extends FunSuite { } test("map casting") { - val map = Literal( + val map = Literal.create( Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null), MapType(StringType, StringType, valueContainsNull = true)) - val map_notNull = Literal( + val map_notNull = Literal.create( Map("a" -> "123", "b" -> "abc", "c" -> ""), MapType(StringType, StringType, valueContainsNull = false)) @@ -611,14 +641,14 @@ class ExpressionEvaluationSuite extends FunSuite { } test("struct casting") { - val struct = Literal( + val struct = Literal.create( Row("123", "abc", "", null), StructType(Seq( StructField("a", StringType, nullable = true), StructField("b", StringType, nullable = true), StructField("c", StringType, nullable = true), StructField("d", StringType, nullable = true)))) - val struct_notNull = Literal( + val struct_notNull = Literal.create( Row("123", "abc", ""), StructType(Seq( StructField("a", StringType, nullable = false), @@ -706,7 +736,7 @@ class ExpressionEvaluationSuite extends FunSuite { } test("complex casting") { - val complex = Literal( + val complex = Literal.create( Row( Seq("123", "abc", ""), Map("a" -> "123", "b" -> "abc", "c" -> ""), @@ -737,7 +767,7 @@ class ExpressionEvaluationSuite extends FunSuite { } test("null checking") { - val row = new GenericRow(Array[Any]("^Ba*n", null, true, null)) + val row = create_row("^Ba*n", null, true, null) val c1 = 'a.string.at(0) val c2 = 'a.string.at(1) val c3 = 'a.boolean.at(2) @@ -749,34 +779,35 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(c2.isNull, true, row) checkEvaluation(c2.isNotNull, false, row) - checkEvaluation(Literal(1, ShortType).isNull, false) - checkEvaluation(Literal(1, ShortType).isNotNull, true) + checkEvaluation(Literal.create(1, ShortType).isNull, false) + checkEvaluation(Literal.create(1, ShortType).isNotNull, true) - checkEvaluation(Literal(null, ShortType).isNull, true) - checkEvaluation(Literal(null, ShortType).isNotNull, false) + 
checkEvaluation(Literal.create(null, ShortType).isNull, true) + checkEvaluation(Literal.create(null, ShortType).isNotNull, false) checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row) - checkEvaluation(Coalesce(Literal(null, StringType) :: Nil), null, row) - checkEvaluation(Coalesce(Literal(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row) + checkEvaluation(Coalesce(Literal.create(null, StringType) :: Nil), null, row) + checkEvaluation(Coalesce(Literal.create(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row) - checkEvaluation(If(c3, Literal("a", StringType), Literal("b", StringType)), "a", row) + checkEvaluation( + If(c3, Literal.create("a", StringType), Literal.create("b", StringType)), "a", row) checkEvaluation(If(c3, c1, c2), "^Ba*n", row) checkEvaluation(If(c4, c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal(null, BooleanType), c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal(true, BooleanType), c1, c2), "^Ba*n", row) - checkEvaluation(If(Literal(false, BooleanType), c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal(false, BooleanType), - Literal("a", StringType), Literal("b", StringType)), "b", row) + checkEvaluation(If(Literal.create(null, BooleanType), c2, c1), "^Ba*n", row) + checkEvaluation(If(Literal.create(true, BooleanType), c1, c2), "^Ba*n", row) + checkEvaluation(If(Literal.create(false, BooleanType), c2, c1), "^Ba*n", row) + checkEvaluation(If(Literal.create(false, BooleanType), + Literal.create("a", StringType), Literal.create("b", StringType)), "b", row) checkEvaluation(c1 in (c1, c2), true, row) checkEvaluation( - Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType)), true, row) + Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType)), true, row) checkEvaluation( - Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType), c2), true, row) + Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType), c2), true, row) } test("case when") { - val row = new GenericRow(Array[Any](null, false, true, "a", "b", "c")) + val row = create_row(null, false, true, "a", "b", "c") val c1 = 'a.boolean.at(0) val c2 = 'a.boolean.at(1) val c3 = 'a.boolean.at(2) @@ -787,9 +818,9 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row) checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row) checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row) - checkEvaluation(CaseWhen(Seq(Literal(null, BooleanType), c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(Literal(false, BooleanType), c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(Literal(true, BooleanType), c4, c6)), "a", row) + checkEvaluation(CaseWhen(Seq(Literal.create(null, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal.create(false, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal.create(true, BooleanType), c4, c6)), "a", row) checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row) checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row) @@ -819,13 +850,13 @@ class ExpressionEvaluationSuite extends FunSuite { } test("complex type") { - val row = new GenericRow(Array[Any]( - "^Ba*n", // 0 - null.asInstanceOf[String], // 1 - new GenericRow(Array[Any]("aa", "bb")), // 2 - Map("aa"->"bb"), // 3 - Seq("aa", "bb") // 4 - )) + val row = create_row( + "^Ba*n", // 0 + null.asInstanceOf[UTF8String], // 1 + create_row("aa", "bb"), // 2 + Map("aa"->"bb"), // 3 + Seq("aa", "bb") // 4 + ) val typeS = StructType( StructField("a", StringType, true) :: StructField("b", StringType, 
true) :: Nil @@ -835,19 +866,21 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(GetItem(BoundReference(3, typeMap, true), Literal("aa")), "bb", row) - checkEvaluation(GetItem(Literal(null, typeMap), Literal("aa")), null, row) - checkEvaluation(GetItem(Literal(null, typeMap), Literal(null, StringType)), null, row) + checkEvaluation(GetItem(Literal.create(null, typeMap), Literal("aa")), null, row) + checkEvaluation( + GetItem(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row) checkEvaluation(GetItem(BoundReference(3, typeMap, true), - Literal(null, StringType)), null, row) + Literal.create(null, StringType)), null, row) checkEvaluation(GetItem(BoundReference(4, typeArray, true), Literal(1)), "bb", row) - checkEvaluation(GetItem(Literal(null, typeArray), Literal(1)), null, row) - checkEvaluation(GetItem(Literal(null, typeArray), Literal(null, IntegerType)), null, row) + checkEvaluation(GetItem(Literal.create(null, typeArray), Literal(1)), null, row) + checkEvaluation( + GetItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row) checkEvaluation(GetItem(BoundReference(4, typeArray, true), - Literal(null, IntegerType)), null, row) + Literal.create(null, IntegerType)), null, row) - def quickBuildGetField(expr: Expression, fieldName: String) = { + def quickBuildGetField(expr: Expression, fieldName: String): StructGetField = { expr.dataType match { case StructType(fields) => val field = fields.find(_.name == fieldName).get @@ -855,10 +888,12 @@ class ExpressionEvaluationSuite extends FunSuite { } } - def quickResolve(u: UnresolvedGetField) = quickBuildGetField(u.child, u.fieldName) + def quickResolve(u: UnresolvedGetField): StructGetField = { + quickBuildGetField(u.child, u.fieldName) + } checkEvaluation(quickBuildGetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row) - checkEvaluation(quickBuildGetField(Literal(null, typeS), "a"), null, row) + checkEvaluation(quickBuildGetField(Literal.create(null, typeS), "a"), null, row) val typeS_notNullable = StructType( StructField("a", StringType, nullable = false) @@ -866,10 +901,11 @@ class ExpressionEvaluationSuite extends FunSuite { ) assert(quickBuildGetField(BoundReference(2,typeS, nullable = true), "a").nullable === true) - assert(quickBuildGetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false) + assert(quickBuildGetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable + === false) - assert(quickBuildGetField(Literal(null, typeS), "a").nullable === true) - assert(quickBuildGetField(Literal(null, typeS_notNullable), "a").nullable === true) + assert(quickBuildGetField(Literal.create(null, typeS), "a").nullable === true) + assert(quickBuildGetField(Literal.create(null, typeS_notNullable), "a").nullable === true) checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row) checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row) @@ -877,20 +913,21 @@ class ExpressionEvaluationSuite extends FunSuite { } test("arithmetic") { - val row = new GenericRow(Array[Any](1, 2, 3, null)) + val row = create_row(1, 2, 3, null) val c1 = 'a.int.at(0) val c2 = 'a.int.at(1) val c3 = 'a.int.at(2) val c4 = 'a.int.at(3) checkEvaluation(UnaryMinus(c1), -1, row) - checkEvaluation(UnaryMinus(Literal(100, IntegerType)), -100) + checkEvaluation(UnaryMinus(Literal.create(100, IntegerType)), -100) checkEvaluation(Add(c1, c4), null, row) checkEvaluation(Add(c1, c2), 3, row) - checkEvaluation(Add(c1, Literal(null, 
IntegerType)), null, row) - checkEvaluation(Add(Literal(null, IntegerType), c2), null, row) - checkEvaluation(Add(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(Add(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation(Add(Literal.create(null, IntegerType), c2), null, row) + checkEvaluation( + Add(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(-c1, -1, row) checkEvaluation(c1 + c2, 3, row) @@ -901,19 +938,20 @@ class ExpressionEvaluationSuite extends FunSuite { } test("fractional arithmetic") { - val row = new GenericRow(Array[Any](1.1, 2.0, 3.1, null)) + val row = create_row(1.1, 2.0, 3.1, null) val c1 = 'a.double.at(0) val c2 = 'a.double.at(1) val c3 = 'a.double.at(2) val c4 = 'a.double.at(3) checkEvaluation(UnaryMinus(c1), -1.1, row) - checkEvaluation(UnaryMinus(Literal(100.0, DoubleType)), -100.0) + checkEvaluation(UnaryMinus(Literal.create(100.0, DoubleType)), -100.0) checkEvaluation(Add(c1, c4), null, row) checkEvaluation(Add(c1, c2), 3.1, row) - checkEvaluation(Add(c1, Literal(null, DoubleType)), null, row) - checkEvaluation(Add(Literal(null, DoubleType), c2), null, row) - checkEvaluation(Add(Literal(null, DoubleType), Literal(null, DoubleType)), null, row) + checkEvaluation(Add(c1, Literal.create(null, DoubleType)), null, row) + checkEvaluation(Add(Literal.create(null, DoubleType), c2), null, row) + checkEvaluation( + Add(Literal.create(null, DoubleType), Literal.create(null, DoubleType)), null, row) checkEvaluation(-c1, -1.1, row) checkEvaluation(c1 + c2, 3.1, row) @@ -924,7 +962,7 @@ class ExpressionEvaluationSuite extends FunSuite { } test("BinaryComparison") { - val row = new GenericRow(Array[Any](1, 2, 3, null, 3, null)) + val row = create_row(1, 2, 3, null, 3, null) val c1 = 'a.int.at(0) val c2 = 'a.int.at(1) val c3 = 'a.int.at(2) @@ -934,9 +972,10 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(LessThan(c1, c4), null, row) checkEvaluation(LessThan(c1, c2), true, row) - checkEvaluation(LessThan(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(LessThan(Literal(null, IntegerType), c2), null, row) - checkEvaluation(LessThan(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(LessThan(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation(LessThan(Literal.create(null, IntegerType), c2), null, row) + checkEvaluation( + LessThan(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(c1 < c2, true, row) checkEvaluation(c1 <= c2, true, row) @@ -948,85 +987,115 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(c1 <=> c4, false, row) checkEvaluation(c4 <=> c6, true, row) checkEvaluation(c3 <=> c5, true, row) - checkEvaluation(Literal(true) <=> Literal(null, BooleanType), false, row) - checkEvaluation(Literal(null, BooleanType) <=> Literal(true), false, row) + checkEvaluation(Literal(true) <=> Literal.create(null, BooleanType), false, row) + checkEvaluation(Literal.create(null, BooleanType) <=> Literal(true), false, row) } test("StringComparison") { - val row = new GenericRow(Array[Any]("abc", null)) + val row = create_row("abc", null) val c1 = 'a.string.at(0) val c2 = 'a.string.at(1) checkEvaluation(c1 contains "b", true, row) checkEvaluation(c1 contains "x", false, row) checkEvaluation(c2 contains "b", null, row) - checkEvaluation(c1 contains Literal(null, StringType), null, row) + checkEvaluation(c1 contains Literal.create(null, 
StringType), null, row) checkEvaluation(c1 startsWith "a", true, row) checkEvaluation(c1 startsWith "b", false, row) checkEvaluation(c2 startsWith "a", null, row) - checkEvaluation(c1 startsWith Literal(null, StringType), null, row) + checkEvaluation(c1 startsWith Literal.create(null, StringType), null, row) checkEvaluation(c1 endsWith "c", true, row) checkEvaluation(c1 endsWith "b", false, row) checkEvaluation(c2 endsWith "b", null, row) - checkEvaluation(c1 endsWith Literal(null, StringType), null, row) + checkEvaluation(c1 endsWith Literal.create(null, StringType), null, row) } test("Substring") { - val row = new GenericRow(Array[Any]("example", "example".toArray.map(_.toByte))) + val row = create_row("example", "example".toArray.map(_.toByte)) val s = 'a.string.at(0) // substring from zero position with less-than-full length - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(2, IntegerType)), "ex", row) - checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(2, IntegerType)), "ex", row) + checkEvaluation( + Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)), "ex", row) + checkEvaluation( + Substring(s, Literal.create(1, IntegerType), Literal.create(2, IntegerType)), "ex", row) // substring from zero position with full length - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(7, IntegerType)), "example", row) - checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(7, IntegerType)), "example", row) + checkEvaluation( + Substring(s, Literal.create(0, IntegerType), Literal.create(7, IntegerType)), "example", row) + checkEvaluation( + Substring(s, Literal.create(1, IntegerType), Literal.create(7, IntegerType)), "example", row) // substring from zero position with greater-than-full length - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(100, IntegerType)), "example", row) - checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(100, IntegerType)), "example", row) + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(100, IntegerType)), + "example", row) + checkEvaluation(Substring(s, Literal.create(1, IntegerType), Literal.create(100, IntegerType)), + "example", row) // substring from nonzero position with less-than-full length - checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(2, IntegerType)), "xa", row) + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(2, IntegerType)), + "xa", row) // substring from nonzero position with full length - checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(6, IntegerType)), "xample", row) + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(6, IntegerType)), + "xample", row) // substring from nonzero position with greater-than-full length - checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(100, IntegerType)), "xample", row) + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(100, IntegerType)), + "xample", row) // zero-length substring (within string bounds) - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(0, IntegerType)), "", row) + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(0, IntegerType)), + "", row) // zero-length substring (beyond string bounds) - checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(4, IntegerType)), "", row) + checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), + "", row) // substring(null, _, _) -> null - 
checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(4, IntegerType)), null, new GenericRow(Array[Any](null))) + checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), + null, create_row(null)) // substring(_, null, _) -> null - checkEvaluation(Substring(s, Literal(null, IntegerType), Literal(4, IntegerType)), null, row) + checkEvaluation(Substring(s, Literal.create(null, IntegerType), Literal.create(4, IntegerType)), + null, row) // substring(_, _, null) -> null - checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation( + Substring(s, Literal.create(100, IntegerType), Literal.create(null, IntegerType)), + null, + row) // 2-arg substring from zero position - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "example", row) - checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "example", row) + checkEvaluation( + Substring(s, Literal.create(0, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), + "example", + row) + checkEvaluation( + Substring(s, Literal.create(1, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), + "example", + row) // 2-arg substring from nonzero position - checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "xample", row) + checkEvaluation( + Substring(s, Literal.create(2, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), + "xample", + row) val s_notNull = 'a.string.notNull.at(0) - assert(Substring(s, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === true) - assert(Substring(s_notNull, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === false) - assert(Substring(s_notNull, Literal(null, IntegerType), Literal(2, IntegerType)).nullable === true) - assert(Substring(s_notNull, Literal(0, IntegerType), Literal(null, IntegerType)).nullable === true) + assert(Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable + === true) + assert( + Substring(s_notNull, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable + === false) + assert(Substring(s_notNull, + Literal.create(null, IntegerType), Literal.create(2, IntegerType)).nullable === true) + assert(Substring(s_notNull, + Literal.create(0, IntegerType), Literal.create(null, IntegerType)).nullable === true) checkEvaluation(s.substr(0, 2), "ex", row) checkEvaluation(s.substr(0), "example", row) @@ -1037,20 +1106,20 @@ class ExpressionEvaluationSuite extends FunSuite { test("SQRT") { val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24)) val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble)) - val rowSequence = inputSequence.map(l => new GenericRow(Array[Any](l.toDouble))) + val rowSequence = inputSequence.map(l => create_row(l.toDouble)) val d = 'a.double.at(0) for ((row, expected) <- rowSequence zip expectedResults) { checkEvaluation(Sqrt(d), expected, row) } - checkEvaluation(Sqrt(Literal(null, DoubleType)), null, new GenericRow(Array[Any](null))) + checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) checkEvaluation(Sqrt(-1), null, EmptyRow) checkEvaluation(Sqrt(-1.5), null, EmptyRow) } test("Bitwise operations") { - val row = new GenericRow(Array[Any](1, 2, 3, null)) + val row = create_row(1, 2, 3, null) val c1 = 'a.int.at(0) val c2 = 'a.int.at(1) val c3 = 'a.int.at(2) @@ -1058,22 +1127,25 @@ class ExpressionEvaluationSuite extends 
FunSuite { checkEvaluation(BitwiseAnd(c1, c4), null, row) checkEvaluation(BitwiseAnd(c1, c2), 0, row) - checkEvaluation(BitwiseAnd(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(BitwiseAnd(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(BitwiseAnd(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation( + BitwiseAnd(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(BitwiseOr(c1, c4), null, row) checkEvaluation(BitwiseOr(c1, c2), 3, row) - checkEvaluation(BitwiseOr(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(BitwiseOr(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(BitwiseOr(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation( + BitwiseOr(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(BitwiseXor(c1, c4), null, row) checkEvaluation(BitwiseXor(c1, c2), 3, row) - checkEvaluation(BitwiseXor(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(BitwiseXor(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(BitwiseXor(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation( + BitwiseXor(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(BitwiseNot(c4), null, row) checkEvaluation(BitwiseNot(c1), -2, row) - checkEvaluation(BitwiseNot(Literal(null, IntegerType)), null, row) + checkEvaluation(BitwiseNot(Literal.create(null, IntegerType)), null, row) checkEvaluation(c1 & c2, 0, row) checkEvaluation(c1 | c2, 3, row) @@ -1081,3 +1153,14 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(~c1, -2, row) } } + +// TODO: Make the tests work with codegen. 
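A note on the pattern used throughout the rewritten assertions above: `create_row(...)` and `Literal.create(value, dataType)` replace direct `new GenericRow(Array[Any](...))` and `Literal(value, dataType)` construction, keeping the row building and the literal's data type explicit. A minimal sketch of what such a helper could look like (an assumption for illustration only, not the suite's actual base-class code):

```
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRow

object EvaluationHelpers {
  // Wraps plain Scala values into a row, mirroring the create_row(...) calls above.
  def create_row(values: Any*): Row = new GenericRow(values.toArray)
}
```

`Literal.create(null, IntegerType)` attaches the intended data type to a null literal, which is what lets the null-propagation assertions above state their expected result types.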
+class ExpressionEvaluationWithoutCodeGenSuite extends ExpressionEvaluationBaseSuite { + + test("CreateStruct") { + val row = Row(1, 2, 3) + val c1 = 'a.int.at(0).as("a") + val c3 = 'c.int.at(2).as("c") + checkEvaluation(CreateStruct(Seq(c1, c3)), Row(1, 3), row) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala index 275ea2627ebcd..bcc0c404d2cfb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.codegen._ /** @@ -43,7 +43,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { } val actual = plan(inputRow) - val expectedRow = new GenericRow(Array[Any](expected)) + val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected))) if (actual.hashCode() != expectedRow.hashCode()) { fail( s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index ef10c0aece716..4396bd0dda9a9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -176,40 +176,39 @@ class ConstantFoldingSuite extends PlanTest { } test("Constant folding test: expressions have null literals") { - val originalQuery = - testRelation - .select( - IsNull(Literal(null)) as 'c1, - IsNotNull(Literal(null)) as 'c2, + val originalQuery = testRelation.select( + IsNull(Literal(null)) as 'c1, + IsNotNull(Literal(null)) as 'c2, - GetItem(Literal(null, ArrayType(IntegerType)), 1) as 'c3, - GetItem(Literal(Seq(1), ArrayType(IntegerType)), Literal(null, IntegerType)) as 'c4, - UnresolvedGetField( - Literal(null, StructType(Seq(StructField("a", IntegerType, true)))), - "a") as 'c5, + GetItem(Literal.create(null, ArrayType(IntegerType)), 1) as 'c3, + GetItem( + Literal.create(Seq(1), ArrayType(IntegerType)), Literal.create(null, IntegerType)) as 'c4, + UnresolvedGetField( + Literal.create(null, StructType(Seq(StructField("a", IntegerType, true)))), + "a") as 'c5, - UnaryMinus(Literal(null, IntegerType)) as 'c6, - Cast(Literal(null), IntegerType) as 'c7, - Not(Literal(null, BooleanType)) as 'c8, + UnaryMinus(Literal.create(null, IntegerType)) as 'c6, + Cast(Literal(null), IntegerType) as 'c7, + Not(Literal.create(null, BooleanType)) as 'c8, - Add(Literal(null, IntegerType), 1) as 'c9, - Add(1, Literal(null, IntegerType)) as 'c10, + Add(Literal.create(null, IntegerType), 1) as 'c9, + Add(1, Literal.create(null, IntegerType)) as 'c10, - EqualTo(Literal(null, IntegerType), 1) as 'c11, - EqualTo(1, Literal(null, IntegerType)) as 'c12, + EqualTo(Literal.create(null, IntegerType), 1) as 'c11, + EqualTo(1, Literal.create(null, IntegerType)) as 'c12, - Like(Literal(null, StringType), "abc") as 'c13, - Like("abc", Literal(null, StringType)) as 'c14, + Like(Literal.create(null, StringType), "abc") as 
'c13, + Like("abc", Literal.create(null, StringType)) as 'c14, - Upper(Literal(null, StringType)) as 'c15, + Upper(Literal.create(null, StringType)) as 'c15, - Substring(Literal(null, StringType), 0, 1) as 'c16, - Substring("abc", Literal(null, IntegerType), 1) as 'c17, - Substring("abc", 0, Literal(null, IntegerType)) as 'c18, + Substring(Literal.create(null, StringType), 0, 1) as 'c16, + Substring("abc", Literal.create(null, IntegerType), 1) as 'c17, + Substring("abc", 0, Literal.create(null, IntegerType)) as 'c18, - Contains(Literal(null, StringType), "abc") as 'c19, - Contains("abc", Literal(null, StringType)) as 'c20 - ) + Contains(Literal.create(null, StringType), "abc") as 'c19, + Contains("abc", Literal.create(null, StringType)) as 'c20 + ) val optimized = Optimize(originalQuery.analyze) @@ -219,31 +218,31 @@ class ConstantFoldingSuite extends PlanTest { Literal(true) as 'c1, Literal(false) as 'c2, - Literal(null, IntegerType) as 'c3, - Literal(null, IntegerType) as 'c4, - Literal(null, IntegerType) as 'c5, + Literal.create(null, IntegerType) as 'c3, + Literal.create(null, IntegerType) as 'c4, + Literal.create(null, IntegerType) as 'c5, - Literal(null, IntegerType) as 'c6, - Literal(null, IntegerType) as 'c7, - Literal(null, BooleanType) as 'c8, + Literal.create(null, IntegerType) as 'c6, + Literal.create(null, IntegerType) as 'c7, + Literal.create(null, BooleanType) as 'c8, - Literal(null, IntegerType) as 'c9, - Literal(null, IntegerType) as 'c10, + Literal.create(null, IntegerType) as 'c9, + Literal.create(null, IntegerType) as 'c10, - Literal(null, BooleanType) as 'c11, - Literal(null, BooleanType) as 'c12, + Literal.create(null, BooleanType) as 'c11, + Literal.create(null, BooleanType) as 'c12, - Literal(null, BooleanType) as 'c13, - Literal(null, BooleanType) as 'c14, + Literal.create(null, BooleanType) as 'c13, + Literal.create(null, BooleanType) as 'c14, - Literal(null, StringType) as 'c15, + Literal.create(null, StringType) as 'c15, - Literal(null, StringType) as 'c16, - Literal(null, StringType) as 'c17, - Literal(null, StringType) as 'c18, + Literal.create(null, StringType) as 'c16, + Literal.create(null, StringType) as 'c17, + Literal.create(null, StringType) as 'c18, - Literal(null, BooleanType) as 'c19, - Literal(null, BooleanType) as 'c20 + Literal.create(null, BooleanType) as 'c19, + Literal.create(null, BooleanType) as 'c20 ).analyze comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala index ae99a3f9ba287..2f3704be59a9d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala @@ -29,7 +29,7 @@ class ExpressionOptimizationSuite extends ExpressionEvaluationSuite { expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { - val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, NoRelation) + val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) val optimizedPlan = DefaultOptimizer(plan) super.checkEvaluation(optimizedPlan.expressions.head, expected, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 55c6766520a1e..1448098c770aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -432,7 +432,8 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { z.join(x.join(y)) - .where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1) && ("z.a".attr >= 3) && ("z.a".attr === "x.b".attr)) + .where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1) && + ("z.a".attr >= 3) && ("z.a".attr === "x.b".attr)) } val optimized = Optimize(originalQuery.analyze) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 233e329cb2038..966bc9ada1e6e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -52,7 +52,7 @@ class OptimizeInSuite extends PlanTest { val optimized = Optimize(originalQuery.analyze) val correctAnswer = testRelation - .where(InSet(UnresolvedAttribute("a"), HashSet[Any]()+1+2)) + .where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2)) .analyze comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 48884040bfce7..e7cafcc96de87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{NoRelation, Filter, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Filter, LogicalPlan} import org.apache.spark.sql.catalyst.util._ /** @@ -45,16 +45,17 @@ class PlanTest extends FunSuite { protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { val normalized1 = normalizeExprIds(plan1) val normalized2 = normalizeExprIds(plan2) - if (normalized1 != normalized2) + if (normalized1 != normalized2) { fail( s""" |== FAIL: Plans do not match === |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} - """.stripMargin) + """.stripMargin) + } } /** Fails the test if the two expressions do not match */ protected def compareExpressions(e1: Expression, e2: Expression): Unit = { - comparePlans(Filter(e1, NoRelation), Filter(e2, NoRelation)) + comparePlans(Filter(e1, OneRowRelation), Filter(e2, OneRowRelation)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala index 11e6831b24768..1273921f6394c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala @@ -32,7 +32,7 @@ class SameResultSuite extends FunSuite { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int) - def assertSameResult(a: 
LogicalPlan, b: LogicalPlan, result: Boolean = true) = { + def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = { val aAnalyzed = a.analyze val bAnalyzed = b.analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index e7ce92a2160b6..4eb8708335dcf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -25,12 +25,12 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{StringType, NullType} case class Dummy(optKey: Option[Expression]) extends Expression { - def children = optKey.toSeq - def nullable = true - def dataType = NullType + def children: Seq[Expression] = optKey.toSeq + def nullable: Boolean = true + def dataType: NullType = NullType override lazy val resolved = true type EvaluatedType = Any - def eval(input: Row) = null.asInstanceOf[Any] + def eval(input: Row): Any = null.asInstanceOf[Any] } class TreeNodeSuite extends FunSuite { @@ -90,7 +90,7 @@ class TreeNodeSuite extends FunSuite { } test("transform works on nodes with Option children") { - val dummy1 = Dummy(Some(Literal("1", StringType))) + val dummy1 = Dummy(Some(Literal.create("1", StringType))) val dummy2 = Dummy(None) val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala index 1ba21b64603ac..169125264a803 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala @@ -34,10 +34,12 @@ class DataTypeParserSuite extends FunSuite { } checkDataType("int", IntegerType) + checkDataType("integer", IntegerType) checkDataType("BooLean", BooleanType) checkDataType("tinYint", ByteType) checkDataType("smallINT", ShortType) checkDataType("INT", IntegerType) + checkDataType("INTEGER", IntegerType) checkDataType("bigint", LongType) checkDataType("float", FloatType) checkDataType("dOUBle", DoubleType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala new file mode 100644 index 0000000000000..a22aa6f244c48 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala @@ -0,0 +1,70 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.types + +import org.scalatest.FunSuite + +// scalastyle:off +class UTF8StringSuite extends FunSuite { + test("basic") { + def check(str: String, len: Int) { + + assert(UTF8String(str).length == len) + assert(UTF8String(str.getBytes("utf8")).length() == len) + + assert(UTF8String(str) == str) + assert(UTF8String(str.getBytes("utf8")) == str) + assert(UTF8String(str).toString == str) + assert(UTF8String(str.getBytes("utf8")).toString == str) + assert(UTF8String(str.getBytes("utf8")) == UTF8String(str)) + + assert(UTF8String(str).hashCode() == UTF8String(str.getBytes("utf8")).hashCode()) + } + + check("hello", 5) + check("世 界", 3) + } + + test("contains") { + assert(UTF8String("hello").contains(UTF8String("ello"))) + assert(!UTF8String("hello").contains(UTF8String("vello"))) + assert(UTF8String("大千世界").contains(UTF8String("千世"))) + assert(!UTF8String("大千世界").contains(UTF8String("世千"))) + } + + test("prefix") { + assert(UTF8String("hello").startsWith(UTF8String("hell"))) + assert(!UTF8String("hello").startsWith(UTF8String("ell"))) + assert(UTF8String("大千世界").startsWith(UTF8String("大千"))) + assert(!UTF8String("大千世界").startsWith(UTF8String("千"))) + } + + test("suffix") { + assert(UTF8String("hello").endsWith(UTF8String("ello"))) + assert(!UTF8String("hello").endsWith(UTF8String("ellov"))) + assert(UTF8String("大千世界").endsWith(UTF8String("世界"))) + assert(!UTF8String("大千世界").endsWith(UTF8String("世"))) + } + + test("slice") { + assert(UTF8String("hello").slice(1, 3) == UTF8String("el")) + assert(UTF8String("大千世界").slice(0, 1) == UTF8String("大")) + assert(UTF8String("大千世界").slice(1, 3) == UTF8String("千世")) + assert(UTF8String("大千世界").slice(3, 5) == UTF8String("界")) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala index ca4a127120b37..18584c2dcf797 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -112,7 +112,7 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) require(dataIndex >= 0, s"Table $query is not cached.") - cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking) + cachedData(dataIndex).cachedRepresentation.uncache(blocking) cachedData.remove(dataIndex) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 4c80359cf07af..3235f85d5bbd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -33,7 +33,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.sql.catalyst.{ScalaReflection, SqlParser} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} @@ -146,7 +146,7 @@ class DataFrame private[sql]( _: WriteToFile => LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) case _ => - queryExecution.logical + 
queryExecution.analyzed } /** @@ -158,7 +158,7 @@ class DataFrame private[sql]( } protected[sql] def resolve(colName: String): NamedExpression = { - queryExecution.analyzed.resolve(colName, sqlContext.analyzer.resolver).getOrElse { + queryExecution.analyzed.resolve(colName.split("\\."), sqlContext.analyzer.resolver).getOrElse { throw new AnalysisException( s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""") } @@ -166,7 +166,7 @@ class DataFrame private[sql]( protected[sql] def numericColumns: Seq[Expression] = { schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => - queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get + queryExecution.analyzed.resolve(n.name.split("\\."), sqlContext.analyzer.resolver).get } } @@ -237,11 +237,11 @@ class DataFrame private[sql]( def toDF(colNames: String*): DataFrame = { require(schema.size == colNames.size, "The number of columns doesn't match.\n" + - "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" + - "New column names: " + colNames.mkString(", ")) + s"Old column names (${schema.size}): " + schema.fields.map(_.name).mkString(", ") + "\n" + + s"New column names (${colNames.size}): " + colNames.mkString(", ")) - val newCols = schema.fieldNames.zip(colNames).map { case (oldName, newName) => - apply(oldName).as(newName) + val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) => + Column(oldAttribute).as(newName) } select(newCols :_*) } @@ -273,7 +273,7 @@ class DataFrame private[sql]( def printSchema(): Unit = println(schema.treeString) /** - * Prints the plans (logical and physical) to the console for debugging purpose. + * Prints the plans (logical and physical) to the console for debugging purposes. * @group basic */ def explain(extended: Boolean): Unit = { @@ -285,7 +285,7 @@ class DataFrame private[sql]( } /** - * Only prints the physical plan to the console for debugging purpose. + * Only prints the physical plan to the console for debugging purposes. * @group basic */ def explain(): Unit = explain(extended = false) @@ -319,6 +319,17 @@ class DataFrame private[sql]( */ def show(): Unit = show(20) + /** + * Returns a [[DataFrameNaFunctions]] for working with missing data. + * {{{ + * // Dropping rows containing any null values. + * df.na.drop() + * }}} + * + * @group dfops + */ + def na: DataFrameNaFunctions = new DataFrameNaFunctions(this) + /** * Cartesian join with another [[DataFrame]]. 
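The `na` method added just above hands back the missing-data helpers implemented in the new `DataFrameNaFunctions` file later in this diff. A short usage sketch, assuming an existing `DataFrame` with a numeric column named `height` (both the function and the column name are illustrative):

```
import org.apache.spark.sql.DataFrame

// Drop rows containing any null value, then fill remaining nulls in "height" with 0.0.
def cleanUp(df: DataFrame): DataFrame =
  df.na.drop().na.fill(0.0, Seq("height"))
```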
* @@ -702,7 +713,7 @@ class DataFrame private[sql]( val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributes = schema.toAttributes val rowFunction = - f.andThen(_.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row])) + f.andThen(_.map(CatalystTypeConverters.convertToCatalyst(_, schema).asInstanceOf[Row])) val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr)) Generate(generator, join = true, outer = false, None, logicalPlan) @@ -723,7 +734,7 @@ class DataFrame private[sql]( val dataType = ScalaReflection.schemaFor[B].dataType val attributes = AttributeReference(outputColumn, dataType)() :: Nil def rowFunction(row: Row): TraversableOnce[Row] = { - f(row(0).asInstanceOf[A]).map(o => Row(ScalaReflection.convertToCatalyst(o, dataType))) + f(row(0).asInstanceOf[A]).map(o => Row(CatalystTypeConverters.convertToCatalyst(o, dataType))) } val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil) @@ -893,7 +904,8 @@ class DataFrame private[sql]( */ override def repartition(numPartitions: Int): DataFrame = { sqlContext.createDataFrame( - queryExecution.toRdd.map(_.copy()).repartition(numPartitions), schema) + queryExecution.toRdd.map(_.copy()).repartition(numPartitions), + schema, needsConversion = false) } /** @@ -941,13 +953,18 @@ class DataFrame private[sql]( ///////////////////////////////////////////////////////////////////////////// /** - * Returns the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s. + * Represents the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s. Note that the RDD is + * memoized. Once called, it won't change even if you change any query planning related Spark SQL + * configurations (e.g. `spark.sql.shuffle.partitions`). * @group rdd */ - def rdd: RDD[Row] = { + lazy val rdd: RDD[Row] = { // use a local variable to make sure the map closure doesn't capture the whole DataFrame val schema = this.schema - queryExecution.executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema)) + queryExecution.executedPlan.execute().mapPartitions { rows => + val converter = CatalystTypeConverters.createToScalaConverter(schema) + rows.map(converter(_).asInstanceOf[Row]) + } } /** @@ -963,8 +980,8 @@ class DataFrame private[sql]( def javaRDD: JavaRDD[Row] = toJavaRDD /** - * Registers this RDD as a temporary table using the given name. The lifetime of this temporary - * table is tied to the [[SQLContext]] that was used to create this DataFrame. + * Registers this [[DataFrame]] as a temporary table using the given name. The lifetime of this + * temporary table is tied to the [[SQLContext]] that was used to create this DataFrame. * * @group basic */ @@ -1192,7 +1209,7 @@ class DataFrame private[sql]( @Experimental def insertInto(tableName: String, overwrite: Boolean): Unit = { sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)), - Map.empty, logicalPlan, overwrite)).toRdd + Map.empty, logicalPlan, overwrite, ifNotExists = false)).toRdd } /** @@ -1239,7 +1256,7 @@ class DataFrame private[sql]( //////////////////////////////////////////////////////////////////////////// /** - * Save this RDD to a JDBC database at `url` under the table name `table`. + * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements. 
* If you pass `true` for `allowExisting`, it will drop any table with the * given name; if you pass `false`, it will throw if the table already @@ -1263,7 +1280,7 @@ class DataFrame private[sql]( } /** - * Save this RDD to a JDBC database at `url` under the table name `table`. + * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. * Assumes the table already exists and has a compatible schema. If you * pass `true` for `overwrite`, it will `TRUNCATE` the table before * performing the `INSERT`s. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala new file mode 100644 index 0000000000000..481ed4924857e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -0,0 +1,375 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import java.{lang => jl} + +import scala.collection.JavaConversions._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + + +/** + * :: Experimental :: + * Functionality for working with missing data in [[DataFrame]]s. + */ +@Experimental +final class DataFrameNaFunctions private[sql](df: DataFrame) { + + /** + * Returns a new [[DataFrame]] that drops rows containing any null values. + */ + def drop(): DataFrame = drop("any", df.columns) + + /** + * Returns a new [[DataFrame]] that drops rows containing null values. + * + * If `how` is "any", then drop rows containing any null values. + * If `how` is "all", then drop rows only if every column is null for that row. + */ + def drop(how: String): DataFrame = drop(how, df.columns) + + /** + * Returns a new [[DataFrame]] that drops rows containing any null values + * in the specified columns. + */ + def drop(cols: Array[String]): DataFrame = drop(cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame ]] that drops rows containing any null values + * in the specified columns. + */ + def drop(cols: Seq[String]): DataFrame = drop(cols.size, cols) + + /** + * Returns a new [[DataFrame]] that drops rows containing null values + * in the specified columns. + * + * If `how` is "any", then drop rows containing any null values in the specified columns. + * If `how` is "all", then drop rows only if every specified column is null for that row. + */ + def drop(how: String, cols: Array[String]): DataFrame = drop(how, cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing null values + * in the specified columns. + * + * If `how` is "any", then drop rows containing any null values in the specified columns. 
+ * If `how` is "all", then drop rows only if every specified column is null for that row. + */ + def drop(how: String, cols: Seq[String]): DataFrame = { + how.toLowerCase match { + case "any" => drop(cols.size, cols) + case "all" => drop(1, cols) + case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'") + } + } + + /** + * Returns a new [[DataFrame]] that drops rows containing less than `minNonNulls` non-null values. + */ + def drop(minNonNulls: Int): DataFrame = drop(minNonNulls, df.columns) + + /** + * Returns a new [[DataFrame]] that drops rows containing less than `minNonNulls` non-null + * values in the specified columns. + */ + def drop(minNonNulls: Int, cols: Array[String]): DataFrame = drop(minNonNulls, cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing less than + * `minNonNulls` non-null values in the specified columns. + */ + def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = { + // Filtering condition -- only keep the row if it has at least `minNonNulls` non-null values. + val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name))) + df.filter(Column(predicate)) + } + + /** + * Returns a new [[DataFrame]] that replaces null values in numeric columns with `value`. + */ + def fill(value: Double): DataFrame = fill(value, df.columns) + + /** + * Returns a new [[DataFrame ]] that replaces null values in string columns with `value`. + */ + def fill(value: String): DataFrame = fill(value, df.columns) + + /** + * Returns a new [[DataFrame]] that replaces null values in specified numeric columns. + * If a specified column is not a numeric column, it is ignored. + */ + def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that replaces null values in specified + * numeric columns. If a specified column is not a numeric column, it is ignored. + */ + def fill(value: Double, cols: Seq[String]): DataFrame = { + val columnEquals = df.sqlContext.analyzer.resolver + val projections = df.schema.fields.map { f => + // Only fill if the column is part of the cols list. + if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => columnEquals(f.name, col))) { + fillCol[Double](f, value) + } else { + df.col(f.name) + } + } + df.select(projections : _*) + } + + /** + * Returns a new [[DataFrame]] that replaces null values in specified string columns. + * If a specified column is not a string column, it is ignored. + */ + def fill(value: String, cols: Array[String]): DataFrame = fill(value, cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that replaces null values in + * specified string columns. If a specified column is not a string column, it is ignored. + */ + def fill(value: String, cols: Seq[String]): DataFrame = { + val columnEquals = df.sqlContext.analyzer.resolver + val projections = df.schema.fields.map { f => + // Only fill if the column is part of the cols list. + if (f.dataType.isInstanceOf[StringType] && cols.exists(col => columnEquals(f.name, col))) { + fillCol[String](f, value) + } else { + df.col(f.name) + } + } + df.select(projections : _*) + } + + /** + * Returns a new [[DataFrame]] that replaces null values. + * + * The key of the map is the column name, and the value of the map is the replacement value. + * The value must be of the following type: `Integer`, `Long`, `Float`, `Double`, `String`. 
+ * + * For example, the following replaces null values in column "A" with string "unknown", and + * null values in column "B" with numeric value 1.0. + * {{{ + * import com.google.common.collect.ImmutableMap; + * df.na.fill(ImmutableMap.of("A", "unknown", "B", 1.0)); + * }}} + */ + def fill(valueMap: java.util.Map[String, Any]): DataFrame = fill0(valueMap.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that replaces null values. + * + * The key of the map is the column name, and the value of the map is the replacement value. + * The value must be of the following type: `Int`, `Long`, `Float`, `Double`, `String`. + * + * For example, the following replaces null values in column "A" with string "unknown", and + * null values in column "B" with numeric value 1.0. + * {{{ + * df.na.fill(Map( + * "A" -> "unknown", + * "B" -> 1.0 + * )) + * }}} + */ + def fill(valueMap: Map[String, Any]): DataFrame = fill0(valueMap.toSeq) + + /** + * Replaces values matching keys in `replacement` map with the corresponding values. + * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * If `col` is "*", then the replacement is applied on all string columns or numeric columns. + * + * {{{ + * import com.google.common.collect.ImmutableMap; + * + * // Replaces all occurrences of 1.0 with 2.0 in column "height". + * df.replace("height", ImmutableMap.of(1.0, 2.0)); + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name". + * df.replace("name", ImmutableMap.of("UNKNOWN", "unnamed")); + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns. + * df.replace("*", ImmutableMap.of("UNKNOWN", "unnamed")); + * }}} + * + * @param col name of the column to apply the value replacement + * @param replacement value replacement map, as explained above + */ + def replace[T](col: String, replacement: java.util.Map[T, T]): DataFrame = { + replace[T](col, replacement.toMap : Map[T, T]) + } + + /** + * Replaces values matching keys in `replacement` map with the corresponding values. + * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * + * {{{ + * import com.google.common.collect.ImmutableMap; + * + * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". + * df.replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0, 2.0)); + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname". + * df.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed")); + * }}} + * + * @param cols list of columns to apply the value replacement + * @param replacement value replacement map, as explained above + */ + def replace[T](cols: Array[String], replacement: java.util.Map[T, T]): DataFrame = { + replace(cols.toSeq, replacement.toMap) + } + + /** + * (Scala-specific) Replaces values matching keys in `replacement` map. + * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * If `col` is "*", then the replacement is applied on all string columns or numeric columns. + * + * {{{ + * // Replaces all occurrences of 1.0 with 2.0 in column "height". + * df.replace("height", Map(1.0 -> 2.0)) + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name". + * df.replace("name", Map("UNKNOWN" -> "unnamed") + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns. 
+ * df.replace("*", Map("UNKNOWN" -> "unnamed") + * }}} + * + * @param col name of the column to apply the value replacement + * @param replacement value replacement map, as explained above + */ + def replace[T](col: String, replacement: Map[T, T]): DataFrame = { + if (col == "*") { + replace0(df.columns, replacement) + } else { + replace0(Seq(col), replacement) + } + } + + /** + * (Scala-specific) Replaces values matching keys in `replacement` map. + * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * + * {{{ + * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". + * df.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0)); + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname". + * df.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed"); + * }}} + * + * @param cols list of columns to apply the value replacement + * @param replacement value replacement map, as explained above + */ + def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = replace0(cols, replacement) + + private def replace0[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = { + if (replacement.isEmpty || cols.isEmpty) { + return df + } + + // replacementMap is either Map[String, String] or Map[Double, Double] + val replacementMap: Map[_, _] = replacement.head._2 match { + case v: String => replacement + case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) } + } + + // targetColumnType is either DoubleType or StringType + val targetColumnType = replacement.head._1 match { + case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long => DoubleType + case _: String => StringType + } + + val columnEquals = df.sqlContext.analyzer.resolver + val projections = df.schema.fields.map { f => + val shouldReplace = cols.exists(colName => columnEquals(colName, f.name)) + if (f.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType && shouldReplace) { + replaceCol(f, replacementMap) + } else if (f.dataType == targetColumnType && shouldReplace) { + replaceCol(f, replacementMap) + } else { + df.col(f.name) + } + } + df.select(projections : _*) + } + + private def fill0(values: Seq[(String, Any)]): DataFrame = { + // Error handling + values.foreach { case (colName, replaceValue) => + // Check column name exists + df.resolve(colName) + + // Check data type + replaceValue match { + case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long | _: String => + // This is good + case _ => throw new IllegalArgumentException( + s"Unsupported value type ${replaceValue.getClass.getName} ($replaceValue).") + } + } + + val columnEquals = df.sqlContext.analyzer.resolver + val projections = df.schema.fields.map { f => + values.find { case (k, _) => columnEquals(k, f.name) }.map { case (_, v) => + v match { + case v: jl.Float => fillCol[Double](f, v.toDouble) + case v: jl.Double => fillCol[Double](f, v) + case v: jl.Long => fillCol[Double](f, v.toDouble) + case v: jl.Integer => fillCol[Double](f, v.toDouble) + case v: String => fillCol[String](f, v) + } + }.getOrElse(df.col(f.name)) + } + df.select(projections : _*) + } + + /** + * Returns a [[Column]] expression that replaces null value in `col` with `replacement`. 
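A self-contained sketch of the `replace` overloads documented above (the `DataFrame`, column names, and values are assumptions for illustration):

```
import org.apache.spark.sql.DataFrame

def normalize(df: DataFrame): DataFrame = {
  // Replace 1.0 with 2.0 wherever it occurs in the "height" and "weight" columns.
  val fixedUnits = df.na.replace(Seq("height", "weight"), Map(1.0 -> 2.0))
  // Replace the sentinel "UNKNOWN" with "unnamed" in every string column.
  fixedUnits.na.replace("*", Map("UNKNOWN" -> "unnamed"))
}
```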
+ */ + private def fillCol[T](col: StructField, replacement: T): Column = { + coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name) + } + + /** + * Returns a [[Column]] expression that replaces value matching key in `replacementMap` with + * value in `replacementMap`, using [[CaseWhen]]. + * + * TODO: This can be optimized to use broadcast join when replacementMap is large. + */ + private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = { + val branches: Seq[Expression] = replacementMap.flatMap { case (source, target) => + df.col(col.name).equalTo(lit(source).cast(col.dataType)).expr :: + lit(target).cast(col.dataType).expr :: Nil + }.toSeq + new Column(CaseWhen(branches ++ Seq(df.col(col.name).expr))).as(col.name) + } + + private def convertToDouble(v: Any): Double = v match { + case v: Float => v.toDouble + case v: Double => v + case v: Long => v.toDouble + case v: Int => v.toDouble + case v => throw new IllegalArgumentException( + s"Unsupported value type ${v.getClass.getName} ($v).") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 45a63ae26ed71..53ad67372e024 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.NumericType @Experimental class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) { - private[this] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { + private[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { val namedGroupingExprs = groupingExprs.map { case expr: NamedExpression => expr case expr: Expression => Alias(expr, expr.prettyString)() @@ -127,10 +127,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) * {{{ * // Selects the age of the oldest employee and the aggregate expense for each department * import com.google.common.collect.ImmutableMap; - * df.groupBy("department").agg(ImmutableMap.builder() - * .put("age", "max") - * .put("expense", "sum") - * .build()); + * df.groupBy("department").agg(ImmutableMap.of("age", "max", "expense", "sum")); * }}} */ def agg(exprs: java.util.Map[String, String]): DataFrame = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 4815620c6fe57..5c65f04ee8497 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -39,12 +39,15 @@ private[spark] object SQLConf { val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.parquet.filterPushdown" val PARQUET_USE_DATA_SOURCE_API = "spark.sql.parquet.useDataSourceApi" + val HIVE_VERIFY_PARTITIONPATH = "spark.sql.hive.verifyPartitionPath" + val COLUMN_NAME_OF_CORRUPT_RECORD = "spark.sql.columnNameOfCorruptRecord" val BROADCAST_TIMEOUT = "spark.sql.broadcastTimeout" // Options that control which operators can be chosen by the query planner. These should be // considered hints and may be ignored by future versions of Spark SQL. 
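For reference, the `fillCol` helper shown earlier in this hunk builds an ordinary column expression out of `coalesce` and `lit`; the same effect can be sketched against the public API (the column name and fill value are assumptions for illustration):

```
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.{coalesce, lit}
import org.apache.spark.sql.types.DoubleType

// Roughly the expression fillCol produces for a double column "age" filled with 0.0:
// coalesce returns its first non-null argument, so nulls in "age" become 0.0.
def filledAge(df: DataFrame): Column =
  coalesce(df.col("age"), lit(0.0).cast(DoubleType)).as("age")

// fill(...) then selects one such projection per column, e.g.:
// df.select(filledAge(df), df.col("name"))
```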
val EXTERNAL_SORT = "spark.sql.planner.externalSort" + val SORTMERGE_JOIN = "spark.sql.planner.sortMergeJoin" // This is only used for the thriftserver val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" @@ -119,9 +122,20 @@ private[sql] class SQLConf extends Serializable { private[spark] def parquetUseDataSourceApi = getConf(PARQUET_USE_DATA_SOURCE_API, "true").toBoolean + /** When true uses verifyPartitionPath to prune the path which is not exists. */ + private[spark] def verifyPartitionPath = + getConf(HIVE_VERIFY_PARTITIONPATH, "true").toBoolean + /** When true the planner will use the external sort, which may spill to disk. */ private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, "false").toBoolean + /** + * Sort merge join would sort the two side of join first, and then iterate both sides together + * only once to get all matches. Using sort merge join can save a lot of memory usage compared + * to HashJoin. + */ + private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN, "false").toBoolean + /** * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode * that evaluates expressions found in queries. In general this custom code runs much faster diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index e59cf9b9e037b..f9f3eb2e03817 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -31,9 +31,9 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, NoRelation} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.{ScalaReflection, expressions} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, expressions} import org.apache.spark.sql.execution.{Filter, _} import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} import org.apache.spark.sql.json._ @@ -120,6 +120,10 @@ class SQLContext(@transient val sparkContext: SparkContext) ExtractPythonUdfs :: sources.PreInsertCastAndRename :: Nil + + override val extendedCheckRules = Seq( + sources.PreWriteCheck(catalog) + ) } @transient @@ -177,7 +181,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @Experimental @transient - lazy val emptyDataFrame = DataFrame(this, NoRelation) + lazy val emptyDataFrame: DataFrame = createDataFrame(sparkContext.emptyRDD[Row], StructType(Nil)) /** * A collection of methods for registering user-defined functions (UDF). @@ -388,9 +392,24 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @DeveloperApi def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD, schema, needsConversion = true) + } + + /** + * Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be + * converted to Catalyst rows. + */ + private[sql] + def createDataFrame(rowRDD: RDD[Row], schema: StructType, needsConversion: Boolean) = { // TODO: use MutableProjection when rowRDD is another DataFrame and the applied // schema differs from the existing schema on any field data type. 
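The `spark.sql.planner.sortMergeJoin` flag described above is read through `SQLConf`, so it can be toggled like any other SQL option. A minimal sketch, assuming a local `SparkContext` (the app name and variable names are illustrative):

```
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

val sc = new SparkContext("local[2]", "sortMergeJoinDemo")
val sqlContext = new SQLContext(sc)

// Opt in to the sort-merge join strategy; it defaults to "false".
sqlContext.setConf("spark.sql.planner.sortMergeJoin", "true")
// The external (spill-to-disk) sort remains a separate, independent flag.
sqlContext.setConf("spark.sql.planner.externalSort", "true")
```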
- val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self) + val catalystRows = if (needsConversion) { + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + rowRDD.map(converter(_).asInstanceOf[Row]) + } else { + rowRDD + } + val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self) DataFrame(this, logicalPlan) } @@ -441,7 +460,7 @@ class SQLContext(@transient val sparkContext: SparkContext) iter.map { row => new GenericRow( extractors.zip(attributeSeq).map { case (e, attr) => - DataTypeConversions.convertJavaToCatalyst(e.invoke(row), attr.dataType) + CatalystTypeConverters.convertToCatalyst(e.invoke(row), attr.dataType) }.toArray[Any] ) : Row } @@ -600,7 +619,7 @@ class SQLContext(@transient val sparkContext: SparkContext) JsonRDD.nullTypeToStringType( JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord))) val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) - createDataFrame(rowRDD, appliedSchema) + createDataFrame(rowRDD, appliedSchema, needsConversion = false) } /** @@ -629,7 +648,7 @@ class SQLContext(@transient val sparkContext: SparkContext) JsonRDD.nullTypeToStringType( JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord)) val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) - createDataFrame(rowRDD, appliedSchema) + createDataFrame(rowRDD, appliedSchema, needsConversion = false) } /** @@ -854,8 +873,8 @@ class SQLContext(@transient val sparkContext: SparkContext) * passed to this function. * * @param columnName the name of a column of integral type that will be used for partitioning. - * @param lowerBound the minimum value of `columnName` to retrieve - * @param upperBound the maximum value of `columnName` to retrieve + * @param lowerBound the minimum value of `columnName` used to decide partition stride + * @param upperBound the maximum value of `columnName` used to decide partition stride * @param numPartitions the number of partitions. 
the range `minValue`-`maxValue` will be split * evenly into this many partitions * @@ -1062,17 +1081,9 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = - Batch("Add exchange", Once, AddExchange(self)) :: Nil + Batch("Add exchange", Once, EnsureRequirements(self)) :: Nil } - @transient - protected[sql] lazy val checkAnalysis = new CheckAnalysis { - override val extendedCheckRules = Seq( - sources.PreWriteCheck(catalog) - ) - } - - protected[sql] def openSession(): SQLSession = { detachSession() val session = createSession() @@ -1105,7 +1116,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @DeveloperApi protected[sql] class QueryExecution(val logical: LogicalPlan) { - def assertAnalyzed(): Unit = checkAnalysis(analyzed) + def assertAnalyzed(): Unit = analyzer.checkAnalysis(analyzed) lazy val analyzed: LogicalPlan = analyzer(logical) lazy val withCachedData: LogicalPlan = { @@ -1184,6 +1195,7 @@ class SQLContext(@transient val sparkContext: SparkContext) case FloatType => true case DateType => true case TimestampType => true + case StringType => true case ArrayType(_, _) => true case MapType(_, _, _) => true case StructType(_) => true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala new file mode 100644 index 0000000000000..d1ea7cc3e9162 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.api.r + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.api.r.SerDe +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode} + +private[r] object SQLUtils { + def createSQLContext(jsc: JavaSparkContext): SQLContext = { + new SQLContext(jsc) + } + + def getJavaSparkContext(sqlCtx: SQLContext): JavaSparkContext = { + new JavaSparkContext(sqlCtx.sparkContext) + } + + def toSeq[T](arr: Array[T]): Seq[T] = { + arr.toSeq + } + + def createDF(rdd: RDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { + val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + val num = schema.fields.size + val rowRDD = rdd.map(bytesToRow) + sqlContext.createDataFrame(rowRDD, schema) + } + + // A helper to include grouping columns in Agg() + def aggWithGrouping(gd: GroupedData, exprs: Column*): DataFrame = { + val aggExprs = exprs.map { col => + col.expr match { + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.simpleString)() + } + } + gd.toDF(aggExprs) + } + + def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = { + df.map(r => rowToRBytes(r)) + } + + private[this] def bytesToRow(bytes: Array[Byte]): Row = { + val bis = new ByteArrayInputStream(bytes) + val dis = new DataInputStream(bis) + val num = SerDe.readInt(dis) + Row.fromSeq((0 until num).map { i => + SerDe.readObject(dis) + }.toSeq) + } + + private[this] def rowToRBytes(row: Row): Array[Byte] = { + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + SerDe.writeInt(dos, row.length) + (0 until row.length).map { idx => + val obj: Object = row(idx).asInstanceOf[Object] + SerDe.writeObject(dos, obj) + } + bos.toByteArray() + } + + def dfToCols(df: DataFrame): Array[Array[Byte]] = { + // localDF is Array[Row] + val localDF = df.collect() + val numCols = df.columns.length + // dfCols is Array[Array[Any]] + val dfCols = convertRowsToColumns(localDF, numCols) + + dfCols.map { col => + colToRBytes(col) + } + } + + def convertRowsToColumns(localDF: Array[Row], numCols: Int): Array[Array[Any]] = { + (0 until numCols).map { colIdx => + localDF.map { row => + row(colIdx) + } + }.toArray + } + + def colToRBytes(col: Array[Any]): Array[Byte] = { + val numRows = col.length + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + SerDe.writeInt(dos, numRows) + + col.map { item => + val obj: Object = item.asInstanceOf[Object] + SerDe.writeObject(dos, obj) + } + bos.toByteArray() + } + + def saveMode(mode: String): SaveMode = { + mode match { + case "append" => SaveMode.Append + case "overwrite" => SaveMode.Overwrite + case "error" => SaveMode.ErrorIfExists + case "ignore" => SaveMode.Ignore + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index c881747751520..00ed70430b84d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -153,6 +153,7 @@ private[sql] object ColumnBuilder { val builder: ColumnBuilder = dataType match { case 
IntegerType => new IntColumnBuilder case LongType => new LongColumnBuilder + case FloatType => new FloatColumnBuilder case DoubleType => new DoubleColumnBuilder case BooleanType => new BooleanColumnBuilder case ByteType => new ByteColumnBuilder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 87a6631da8300..b0f983c180673 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -216,13 +216,13 @@ private[sql] class IntColumnStats extends ColumnStats { } private[sql] class StringColumnStats extends ColumnStats { - protected var upper: String = null - protected var lower: String = null + protected var upper: UTF8String = null + protected var lower: UTF8String = null override def gatherStats(row: Row, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getString(ordinal) + val value = row(ordinal).asInstanceOf[UTF8String] if (upper == null || value.compareTo(upper) > 0) upper = value if (lower == null || value.compareTo(lower) < 0) lower = value sizeInBytes += STRING.actualSize(row, ordinal) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index c47497e0662d9..1b9e0df2dcb5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import scala.reflect.runtime.universe.TypeTag @@ -312,26 +312,28 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { row.getString(ordinal).getBytes("utf-8").length + 4 } - override def append(v: String, buffer: ByteBuffer): Unit = { - val stringBytes = v.getBytes("utf-8") + override def append(v: UTF8String, buffer: ByteBuffer): Unit = { + val stringBytes = v.getBytes buffer.putInt(stringBytes.length).put(stringBytes, 0, stringBytes.length) } - override def extract(buffer: ByteBuffer): String = { + override def extract(buffer: ByteBuffer): UTF8String = { val length = buffer.getInt() val stringBytes = new Array[Byte](length) buffer.get(stringBytes, 0, length) - new String(stringBytes, "utf-8") + UTF8String(stringBytes) } - override def setField(row: MutableRow, ordinal: Int, value: String): Unit = { - row.setString(ordinal, value) + override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = { + row.update(ordinal, value) } - override def getField(row: Row, ordinal: Int): String = row.getString(ordinal) + override def getField(row: Row, ordinal: Int): UTF8String = { + row(ordinal).asInstanceOf[UTF8String] + } override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { - to.setString(toOrdinal, from.getString(fromOrdinal)) + to.update(toOrdinal, from(fromOrdinal)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 6eee0c86d6a1c..d9b6fb43ab83d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import org.apache.spark.Accumulator +import org.apache.spark.{Accumulable, Accumulator, Accumulators} import org.apache.spark.sql.catalyst.expressions import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row +import org.apache.spark.SparkContext import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ @@ -53,11 +55,16 @@ private[sql] case class InMemoryRelation( child: SparkPlan, tableName: Option[String])( private var _cachedColumnBuffers: RDD[CachedBatch] = null, - private var _statistics: Statistics = null) + private var _statistics: Statistics = null, + private var _batchStats: Accumulable[ArrayBuffer[Row], Row] = null) extends LogicalPlan with MultiInstanceRelation { - private val batchStats = - child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row]) + private val batchStats: Accumulable[ArrayBuffer[Row], Row] = + if (_batchStats == null) { + child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row]) + } else { + _batchStats + } val partitionStatistics = new PartitionStatistics(output) @@ -161,7 +168,7 @@ private[sql] case class InMemoryRelation( def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { InMemoryRelation( newOutput, useCompression, batchSize, storageLevel, child, tableName)( - _cachedColumnBuffers, statisticsToBePropagated) + _cachedColumnBuffers, statisticsToBePropagated, batchStats) } override def children: Seq[LogicalPlan] = Seq.empty @@ -175,13 +182,20 @@ private[sql] case class InMemoryRelation( child, tableName)( _cachedColumnBuffers, - statisticsToBePropagated).asInstanceOf[this.type] + statisticsToBePropagated, + batchStats).asInstanceOf[this.type] } def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers override protected def otherCopyArgs: Seq[AnyRef] = - Seq(_cachedColumnBuffers, statisticsToBePropagated) + Seq(_cachedColumnBuffers, statisticsToBePropagated, batchStats) + + private[sql] def uncache(blocking: Boolean): Unit = { + Accumulators.remove(batchStats.id) + cachedColumnBuffers.unpersist(blocking) + _cachedColumnBuffers = null + } } private[sql] case class InMemoryColumnarTableScan( @@ -244,15 +258,20 @@ private[sql] case class InMemoryColumnarTableScan( } } + lazy val enableAccumulators: Boolean = + sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean + // Accumulators used for testing purposes - val readPartitions: Accumulator[Int] = sparkContext.accumulator(0) - val readBatches: Accumulator[Int] = sparkContext.accumulator(0) + lazy val readPartitions: Accumulator[Int] = sparkContext.accumulator(0) + lazy val readBatches: Accumulator[Int] = sparkContext.accumulator(0) private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning override def execute(): RDD[Row] = { - readPartitions.setValue(0) - readBatches.setValue(0) + if (enableAccumulators) { + readPartitions.setValue(0) + readBatches.setValue(0) + } relation.cachedColumnBuffers.mapPartitions { cachedBatchIterator => val partitionFilter = newPredicate( @@ -302,7 +321,7 @@ private[sql] case class InMemoryColumnarTableScan( } } - if (rows.hasNext) { + if (rows.hasNext && enableAccumulators) { readPartitions += 1 
} @@ -321,7 +340,9 @@ private[sql] case class InMemoryColumnarTableScan( logInfo(s"Skipping partition based on stats $statsString") false } else { - readBatches += 1 + if (enableAccumulators) { + readBatches += 1 + } true } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 437408d30bfd2..69a620e1ec929 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -19,24 +19,42 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.sql.catalyst.expressions import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf} import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.{Attribute, RowOrdering} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair +object Exchange { + /** + * Returns true when the ordering expressions are a subset of the key. + * if true, ShuffledRDD can use `setKeyOrdering(orderingKey)` to sort within [[Exchange]]. + */ + def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = { + desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet) + } +} + /** * :: DeveloperApi :: + * Performs a shuffle that will result in the desired `newPartitioning`. Optionally sorts each + * resulting partition based on expressions from the partition key. It is invalid to construct an + * exchange operator with a `newOrdering` that cannot be calculated using the partitioning key. */ @DeveloperApi -case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { +case class Exchange( + newPartitioning: Partitioning, + newOrdering: Seq[SortOrder], + child: SparkPlan) + extends UnaryNode { override def outputPartitioning: Partitioning = newPartitioning + override def outputOrdering: Seq[SortOrder] = newOrdering + override def output: Seq[Attribute] = child.output /** We must copy rows when sort based shuffle is on */ @@ -45,7 +63,23 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una private val bypassMergeThreshold = child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + private val keyOrdering = { + if (newOrdering.nonEmpty) { + val key = newPartitioning.keyExpressions + val boundOrdering = newOrdering.map { o => + val ordinal = key.indexOf(o.child) + if (ordinal == -1) sys.error(s"Invalid ordering on $o requested for $newPartitioning") + o.copy(child = BoundReference(ordinal, o.child.dataType, o.child.nullable)) + } + new RowOrdering(boundOrdering) + } else { + null // Ordering will not be used + } + } + override def execute(): RDD[Row] = attachTree(this , "execute") { + lazy val sparkConf = child.sqlContext.sparkContext.getConf + newPartitioning match { case HashPartitioning(expressions, numPartitions) => // TODO: Eliminate redundant expressions in grouping key and value. 
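`Exchange.canSortWithShuffle` above is just a subset check: if every ordering expression is already one of the partitioning keys, the shuffle itself can produce sorted partitions via `setKeyOrdering`, and no separate Sort operator is needed. A standalone sketch of that check, with expressions simplified to strings (the real code compares Catalyst `SortOrder` children against `Partitioning.keyExpressions`):

```
object CanSortWithShuffleSketch {
  // Simplified stand-in for Exchange.canSortWithShuffle in the diff above.
  def canSortWithShuffle(partitionKeys: Set[String], desiredOrdering: Seq[String]): Boolean =
    desiredOrdering.toSet.subsetOf(partitionKeys)

  def main(args: Array[String]): Unit = {
    // An ordering on "a" can be produced by a shuffle that partitions on ("a", "b").
    assert(canSortWithShuffle(Set("a", "b"), Seq("a")))
    // An ordering on "c" cannot, so a separate Sort operator is required.
    assert(!canSortWithShuffle(Set("a", "b"), Seq("c")))
  }
}
```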
@@ -56,7 +90,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una // we can avoid the defensive copies to improve performance. In the long run, we probably // want to include information in shuffle dependencies to indicate whether elements in the // source RDD should be copied. - val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) { + val willMergeSort = sortBasedShuffleOn && numPartitions > bypassMergeThreshold + + val rdd = if (willMergeSort || newOrdering.nonEmpty) { child.execute().mapPartitions { iter => val hashExpressions = newMutableProjection(expressions, child.output)() iter.map(r => (hashExpressions(r).copy(), r.copy())) @@ -69,12 +105,17 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } } val part = new HashPartitioner(numPartitions) - val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + val shuffled = + if (newOrdering.nonEmpty) { + new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(keyOrdering) + } else { + new ShuffledRDD[Row, Row, Row](rdd, part) + } + shuffled.setSerializer(new SparkSqlSerializer(sparkConf)) shuffled.map(_._2) case RangePartitioning(sortingExpressions, numPartitions) => - val rdd = if (sortBasedShuffleOn) { + val rdd = if (sortBasedShuffleOn || newOrdering.nonEmpty) { child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))} } else { child.execute().mapPartitions { iter => @@ -87,9 +128,13 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una implicit val ordering = new RowOrdering(sortingExpressions, child.output) val part = new RangePartitioner(numPartitions, rdd, ascending = true) - val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) - + val shuffled = + if (newOrdering.nonEmpty) { + new ShuffledRDD[Row, Null, Null](rdd, part).setKeyOrdering(keyOrdering) + } else { + new ShuffledRDD[Row, Null, Null](rdd, part) + } + shuffled.setSerializer(new SparkSqlSerializer(sparkConf)) shuffled.map(_._1) case SinglePartition => @@ -107,7 +152,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } val partitioner = new HashPartitioner(1) val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + shuffled.setSerializer(new SparkSqlSerializer(sparkConf)) shuffled.map(_._2) case _ => sys.error(s"Exchange not implemented for $newPartitioning") @@ -120,27 +165,34 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]] * of input data meets the * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for - * each operator by inserting [[Exchange]] Operators where required. + * each operator by inserting [[Exchange]] Operators where required. Also ensure that the + * required input partition ordering requirements are met. */ -private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] { +private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] { // TODO: Determine the number of partitions. 
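When `newOrdering` is non-empty, the hash and range branches above call `setKeyOrdering` on the `ShuffledRDD` so that rows leave the shuffle already sorted within each partition. A small self-contained sketch of that RDD-level mechanism, using plain `(Int, String)` pairs rather than Catalyst rows:

```
import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.rdd.ShuffledRDD

object KeyOrderingShuffleSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("key-ordering").setMaster("local[2]"))
    val pairs = sc.parallelize(Seq(3 -> "c", 1 -> "a", 2 -> "b", 4 -> "d"))

    // Ask the shuffle itself to sort by key within each output partition,
    // which is what the modified Exchange does instead of appending a Sort operator.
    val shuffled = new ShuffledRDD[Int, String, String](pairs, new HashPartitioner(2))
      .setKeyOrdering(Ordering.Int)

    shuffled.glom().collect().foreach(partition => println(partition.toSeq))
    sc.stop()
  }
}
```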
def numPartitions: Int = sqlContext.conf.numShufflePartitions def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator: SparkPlan => - // Check if every child's outputPartitioning satisfies the corresponding + // True iff every child's outputPartitioning satisfies the corresponding // required data distribution. def meetsRequirements: Boolean = - !operator.requiredChildDistribution.zip(operator.children).map { + operator.requiredChildDistribution.zip(operator.children).forall { case (required, child) => val valid = child.outputPartitioning.satisfies(required) logDebug( s"${if (valid) "Valid" else "Invalid"} distribution," + s"required: $required current: ${child.outputPartitioning}") valid - }.exists(!_) + } - // Check if outputPartitionings of children are compatible with each other. + // True iff any of the children are incorrectly sorted. + def needsAnySort: Boolean = + operator.requiredChildOrdering.zip(operator.children).exists { + case (required, child) => required.nonEmpty && required != child.outputOrdering + } + + // True iff outputPartitionings of children are compatible with each other. // It is possible that every child satisfies its required data distribution // but two children have incompatible outputPartitionings. For example, // A dataset is range partitioned by "a.asc" (RangePartitioning) and another @@ -157,28 +209,69 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl case Seq(a,b) => a compatibleWith b }.exists(!_) - // Check if the partitioning we want to ensure is the same as the child's output - // partitioning. If so, we do not need to add the Exchange operator. - def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan): SparkPlan = - if (child.outputPartitioning != partitioning) Exchange(partitioning, child) else child + // Adds Exchange or Sort operators as required + def addOperatorsIfNecessary( + partitioning: Partitioning, + rowOrdering: Seq[SortOrder], + child: SparkPlan): SparkPlan = { + val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering + val needsShuffle = child.outputPartitioning != partitioning + val canSortWithShuffle = Exchange.canSortWithShuffle(partitioning, rowOrdering) + + if (needSort && needsShuffle && canSortWithShuffle) { + Exchange(partitioning, rowOrdering, child) + } else { + val withShuffle = if (needsShuffle) { + Exchange(partitioning, Nil, child) + } else { + child + } + + val withSort = if (needSort) { + if (sqlContext.conf.externalSortEnabled) { + ExternalSort(rowOrdering, global = false, withShuffle) + } else { + Sort(rowOrdering, global = false, withShuffle) + } + } else { + withShuffle + } + + withSort + } + } - if (meetsRequirements && compatible) { + if (meetsRequirements && compatible && !needsAnySort) { operator } else { // At least one child does not satisfies its required data distribution or // at least one child's outputPartitioning is not compatible with another child's // outputPartitioning. In this case, we need to add Exchange operators. 
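The `addOperatorsIfNecessary` helper above boils down to a small decision table over "does this child need a shuffle?" and "does it need a sort?", with the combined Exchange-with-ordering used only when `canSortWithShuffle` holds. A toy encoding of that table; the case-object names are illustrative, not names from the patch:

```
object EnsureRequirementsSketch {
  sealed trait Fix
  case object NoOp extends Fix
  case object ShuffleOnly extends Fix
  case object SortOnly extends Fix
  case object ShuffleThenSort extends Fix
  case object SortInsideShuffle extends Fix // Exchange(partitioning, rowOrdering, child)

  def planFix(needsShuffle: Boolean, needsSort: Boolean, canSortWithShuffle: Boolean): Fix =
    (needsShuffle, needsSort) match {
      case (false, false)                     => NoOp
      case (true, false)                      => ShuffleOnly
      case (false, true)                      => SortOnly
      case (true, true) if canSortWithShuffle => SortInsideShuffle
      case (true, true)                       => ShuffleThenSort
    }

  def main(args: Array[String]): Unit = {
    assert(planFix(needsShuffle = true, needsSort = true, canSortWithShuffle = true) == SortInsideShuffle)
    assert(planFix(needsShuffle = true, needsSort = true, canSortWithShuffle = false) == ShuffleThenSort)
    assert(planFix(needsShuffle = false, needsSort = true, canSortWithShuffle = false) == SortOnly)
  }
}
```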
- val repartitionedChildren = operator.requiredChildDistribution.zip(operator.children).map { - case (AllTuples, child) => - addExchangeIfNecessary(SinglePartition, child) - case (ClusteredDistribution(clustering), child) => - addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child) - case (OrderedDistribution(ordering), child) => - addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child) - case (UnspecifiedDistribution, child) => child - case (dist, _) => sys.error(s"Don't know how to ensure $dist") + val requirements = + (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children) + + val fixedChildren = requirements.zipped.map { + case (AllTuples, rowOrdering, child) => + addOperatorsIfNecessary(SinglePartition, rowOrdering, child) + case (ClusteredDistribution(clustering), rowOrdering, child) => + addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child) + case (OrderedDistribution(ordering), rowOrdering, child) => + addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child) + + case (UnspecifiedDistribution, Seq(), child) => + child + case (UnspecifiedDistribution, rowOrdering, child) => + if (sqlContext.conf.externalSortEnabled) { + ExternalSort(rowOrdering, global = false, child) + } else { + Sort(rowOrdering, global = false, child) + } + + case (dist, ordering, _) => + sys.error(s"Don't know how to ensure $dist with ordering $ordering") } - operator.withNewChildren(repartitionedChildren) + + operator.withNewChildren(fixedChildren) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index d8955725e59b1..1fd387eec7e57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -19,14 +19,12 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, SpecificMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.types.StructType - -import scala.collection.immutable +import org.apache.spark.sql.{Row, SQLContext} /** * :: DeveloperApi :: @@ -39,13 +37,42 @@ object RDDConversions { Iterator.empty } else { val bufferedIterator = iterator.buffered - val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity) + val mutableRow = new SpecificMutableRow(schema.fields.map(_.dataType)) val schemaFields = schema.fields.toArray + val converters = schemaFields.map { + f => CatalystTypeConverters.createToCatalystConverter(f.dataType) + } + bufferedIterator.map { r => + var i = 0 + while (i < mutableRow.length) { + mutableRow(i) = converters(i)(r.productElement(i)) + i += 1 + } + + mutableRow + } + } + } + } + + /** + * Convert the objects inside Row into the types Catalyst expected. 
+ */ + def rowToRowRdd(data: RDD[Row], schema: StructType): RDD[Row] = { + data.mapPartitions { iterator => + if (iterator.isEmpty) { + Iterator.empty + } else { + val bufferedIterator = iterator.buffered + val mutableRow = new GenericMutableRow(bufferedIterator.head.toSeq.toArray) + val schemaFields = schema.fields.toArray + val converters = schemaFields.map { + f => CatalystTypeConverters.createToCatalystConverter(f.dataType) + } bufferedIterator.map { r => var i = 0 while (i < mutableRow.length) { - mutableRow(i) = - ScalaReflection.convertToCatalyst(r.productElement(i), schemaFields(i).dataType) + mutableRow(i) = converters(i)(r(i)) i += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 89682d25ca7dc..b1ef6556de1e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -68,6 +68,8 @@ case class GeneratedAggregate( a.collect { case agg: AggregateExpression => agg} } + // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite + // (in test "aggregation with codegen"). val computeFunctions = aggregatesToCompute.map { case c @ Count(expr) => // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its @@ -93,13 +95,16 @@ case class GeneratedAggregate( } val currentSum = AttributeReference("currentSum", calcType, nullable = true)() - val initialValue = Literal(null, calcType) + val initialValue = Literal.create(null, calcType) - // Coalasce avoids double calculation... + // Coalesce avoids double calculation... // but really, common sub expression elimination would be better.... val zero = Cast(Literal(0), calcType) val updateFunction = Coalesce( - Add(Coalesce(currentSum :: zero :: Nil), Cast(expr, calcType)) :: currentSum :: Nil) + Add( + Coalesce(currentSum :: zero :: Nil), + Cast(expr, calcType) + ) :: currentSum :: zero :: Nil) val result = expr.dataType match { case DecimalType.Fixed(_, _) => @@ -109,8 +114,8 @@ case class GeneratedAggregate( AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - case a @ Average(expr) => - val calcType = + case cs @ CombineSum(expr) => + val calcType = expr.dataType expr.dataType match { case DecimalType.Fixed(_, _) => DecimalType.Unlimited @@ -118,45 +123,39 @@ case class GeneratedAggregate( expr.dataType } - val currentCount = AttributeReference("currentCount", LongType, nullable = false)() - val currentSum = AttributeReference("currentSum", calcType, nullable = false)() - val initialCount = Literal(0L) - val initialSum = Cast(Literal(0L), calcType) + val currentSum = AttributeReference("currentSum", calcType, nullable = true)() + val initialValue = Literal.create(null, calcType) + // Coalasce avoids double calculation... + // but really, common sub expression elimination would be better.... 
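For `CombineSum` the update function (completed just after this point in the hunk) is wrapped in `If(IsNotNull(...))`, so a null input leaves the running sum untouched and the partial sum itself stays null until the first non-null value arrives. A plain-Scala model of that null handling, using `Option` in place of nullable Catalyst values:

```
object CombineSumNullSketch {
  // Mirrors the CombineSum update: nulls are skipped, and the partial sum is
  // null (None) only when no non-null input has been seen yet.
  def updateSum(currentSum: Option[Long], value: Option[Long]): Option[Long] =
    value match {
      case Some(v) => Some(currentSum.getOrElse(0L) + v)
      case None    => currentSum
    }

  def main(args: Array[String]): Unit = {
    assert(Seq(Some(1L), None, Some(4L)).foldLeft(Option.empty[Long])(updateSum) == Some(5L))
    assert(Seq[Option[Long]](None, None).foldLeft(Option.empty[Long])(updateSum).isEmpty)
  }
}
```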
+ val zero = Cast(Literal(0), calcType) // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its // UnscaledValue will be null if and only if x is null; helps with Average on decimals - val toCount = expr match { + val actualExpr = expr match { case UnscaledValue(e) => e case _ => expr } - - val updateCount = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount) - val updateSum = Coalesce(Add(Cast(expr, calcType), currentSum) :: currentSum :: Nil) - + // partial sum result can be null only when no input rows present + val updateFunction = If( + IsNotNull(actualExpr), + Coalesce( + Add( + Coalesce(currentSum :: zero :: Nil), + Cast(expr, calcType)) :: currentSum :: zero :: Nil), + currentSum) + val result = expr.dataType match { case DecimalType.Fixed(_, _) => - If(EqualTo(currentCount, Literal(0L)), - Literal(null, a.dataType), - Cast(Divide( - Cast(currentSum, DecimalType.Unlimited), - Cast(currentCount, DecimalType.Unlimited)), a.dataType)) - case _ => - If(EqualTo(currentCount, Literal(0L)), - Literal(null, a.dataType), - Divide(Cast(currentSum, a.dataType), Cast(currentCount, a.dataType))) + Cast(currentSum, cs.dataType) + case _ => currentSum } - AggregateEvaluation( - currentCount :: currentSum :: Nil, - initialCount :: initialSum :: Nil, - updateCount :: updateSum :: Nil, - result - ) - + AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) + case m @ Max(expr) => val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)() - val initialValue = Literal(null, expr.dataType) + val initialValue = Literal.create(null, expr.dataType) val updateMax = MaxOf(currentMax, expr) AggregateEvaluation( @@ -165,8 +164,20 @@ case class GeneratedAggregate( updateMax :: Nil, currentMax) + case m @ Min(expr) => + val currentMin = AttributeReference("currentMin", expr.dataType, nullable = true)() + val initialValue = Literal.create(null, expr.dataType) + val updateMin = MinOf(currentMin, expr) + + AggregateEvaluation( + currentMin :: Nil, + initialValue :: Nil, + updateMin :: Nil, + currentMin) + case CollectHashSet(Seq(expr)) => - val set = AttributeReference("hashSet", ArrayType(expr.dataType), nullable = false)() + val set = + AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)() val initialValue = NewSet(expr.dataType) val addToSet = AddItemToSet(expr, set) @@ -177,9 +188,10 @@ case class GeneratedAggregate( set) case CombineSetsAndCount(inputSet) => - val ArrayType(inputType, _) = inputSet.dataType - val set = AttributeReference("hashSet", inputSet.dataType, nullable = false)() - val initialValue = NewSet(inputType) + val elementType = inputSet.dataType.asInstanceOf[OpenHashSetUDT].elementType + val set = + AttributeReference("hashSet", new OpenHashSetUDT(elementType), nullable = false)() + val initialValue = NewSet(elementType) val collectSets = CombineSets(set, inputSet) AggregateEvaluation( @@ -187,6 +199,8 @@ case class GeneratedAggregate( initialValue :: Nil, collectSets :: Nil, CountSet(set)) + + case o => sys.error(s"$o can't be codegened.") } val computationSchema = computeFunctions.flatMap(_.schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index 5bd699a2fa949..8a8c3a404323a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.Attribute @@ -32,9 +32,15 @@ case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNo override def execute(): RDD[Row] = rdd - override def executeCollect(): Array[Row] = - rows.map(ScalaReflection.convertRowToScala(_, schema)).toArray - override def executeTake(limit: Int): Array[Row] = - rows.map(ScalaReflection.convertRowToScala(_, schema)).take(limit).toArray + override def executeCollect(): Array[Row] = { + val converter = CatalystTypeConverters.createToScalaConverter(schema) + rows.map(converter(_).asInstanceOf[Row]).toArray + } + + + override def executeTake(limit: Int): Array[Row] = { + val converter = CatalystTypeConverters.createToScalaConverter(schema) + rows.map(converter(_).asInstanceOf[Row]).take(limit).toArray + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index d239637cd4b4e..e159ffe66cb24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.{ScalaReflection, trees} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan @@ -72,6 +72,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ def requiredChildDistribution: Seq[Distribution] = Seq.fill(children.size)(UnspecifiedDistribution) + /** Specifies how data is ordered in each partition. */ + def outputOrdering: Seq[SortOrder] = Nil + + /** Specifies sort order for each partition requirements on the input data for this operator. */ + def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) + /** * Runs this query returning the result as an RDD. */ @@ -80,8 +86,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** * Runs this query returning the result as an array. 
*/ + def executeCollect(): Array[Row] = { - execute().map(ScalaReflection.convertRowToScala(_, schema)).collect() + execute().mapPartitions { iter => + val converter = CatalystTypeConverters.createToScalaConverter(schema) + iter.map(converter(_).asInstanceOf[Row]) + }.collect() } /** @@ -125,7 +135,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ partsScanned += numPartsToTry } - buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema)) + val converter = CatalystTypeConverters.createToScalaConverter(schema) + buf.toArray.map(converter(_).asInstanceOf[Row]) } protected def newProjection( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 967bd76b302d8..eea15aff5dbcf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import java.nio.ByteBuffer +import java.util.{HashMap => JavaHashMap} import org.apache.spark.sql.types.Decimal @@ -26,14 +27,13 @@ import scala.reflect.ClassTag import com.clearspring.analytics.stream.cardinality.HyperLogLog import com.esotericsoftware.kryo.io.{Input, Output} import com.esotericsoftware.kryo.{Serializer, Kryo} -import com.twitter.chill.{AllScalaRegistrar, ResourcePool} +import com.twitter.chill.ResourcePool import org.apache.spark.{SparkEnv, SparkConf} import org.apache.spark.serializer.{SerializerInstance, KryoSerializer} import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.util.collection.OpenHashSet import org.apache.spark.util.MutablePair -import org.apache.spark.util.Utils import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHashSet} @@ -55,6 +55,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]], new OpenHashSetSerializer) kryo.register(classOf[Decimal]) + kryo.register(classOf[JavaHashMap[_, _]]) kryo.setReferences(false) kryo @@ -64,12 +65,9 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co private[execution] class KryoResourcePool(size: Int) extends ResourcePool[SerializerInstance](size) { - val ser: KryoSerializer = { + val ser: SparkSqlSerializer = { val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) - // TODO (lian) Using KryoSerializer here is workaround, needs further investigation - // Using SparkSqlSerializer here makes BasicQuerySuite to fail because of Kryo serialization - // related error. 
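With `HashedRelation` becoming `Externalizable` later in this patch, its backing `java.util.HashMap` is serialized through `SparkSqlSerializer`, which is presumably why the class gets registered with Kryo above; registration lets Kryo write a small class id instead of the full class name. Since `SparkSqlSerializer` is `private[sql]`, the sketch below shows the equivalent round trip with the Kryo API directly:

```
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.util.{HashMap => JavaHashMap}

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}

object KryoHashMapSketch {
  def main(args: Array[String]): Unit = {
    val kryo = new Kryo()
    kryo.register(classOf[JavaHashMap[_, _]]) // mirrors the registration in the diff

    val table = new JavaHashMap[String, Int]()
    table.put("key1", 1)
    table.put("key2", 2)

    val bos = new ByteArrayOutputStream()
    val out = new Output(bos)
    kryo.writeClassAndObject(out, table)
    out.close()

    val in = new Input(new ByteArrayInputStream(bos.toByteArray))
    val copy = kryo.readClassAndObject(in).asInstanceOf[JavaHashMap[String, Int]]
    in.close()
    assert(copy.get("key1") == 1 && copy.get("key2") == 2)
  }
}
```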
- new KryoSerializer(sparkConf) + new SparkSqlSerializer(sparkConf) } def newInstance(): SerializerInstance = ser.newInstance() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 2b581152e5f77..e687d01f57520 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -90,6 +90,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) + // If the sort merge join option is set, we want to use sort merge join prior to hashjoin + // for now let's support inner join first, then add outer join + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) + if sqlContext.conf.sortMergeJoinEnabled => + val mergeJoin = + joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right)) + condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => val buildSide = if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { @@ -155,7 +163,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists { - case _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false + case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => false // The generated set implementation is pretty limited ATM. case CollectHashSet(exprs) if exprs.size == 1 && Seq(IntegerType, LongType).contains(exprs.head.dataType) => false @@ -211,9 +219,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { ParquetRelation.create(path, child, sparkContext.hadoopConfiguration, sqlContext) // Note: overwrite=false because otherwise the metadata we just created will be deleted InsertIntoParquetTable(relation, planLater(child), overwrite = false) :: Nil - case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) => + case logical.InsertIntoTable( + table: ParquetRelation, partition, child, overwrite, ifNotExists) => InsertIntoParquetTable(table, planLater(child), overwrite) :: Nil case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => + val partitionColNames = relation.partitioningAttributes.map(_.name).toSet + val filtersToPush = filters.filter { pred => + val referencedColNames = pred.references.map(_.name).toSet + referencedColNames.intersect(partitionColNames).isEmpty + } val prunePushedDownFilters = if (sqlContext.conf.parquetFilterPushDown) { (predicates: Seq[Expression]) => { @@ -225,6 +239,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // "A AND B" in the higher-level filter, not just "B". 
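The new strategy case above plans `joins.SortMergeJoin` for inner equi-joins whenever the flag is on, relying on both children being sorted on their join keys so that a single forward pass over each side finds every match. A toy two-pointer merge join over pre-sorted Scala sequences that captures the idea; the real operator works on Catalyst rows and buffers runs of equal keys from the right side:

```
object MergeJoinSketch {
  def mergeJoin[K: Ordering, A, B](left: Seq[(K, A)], right: Seq[(K, B)]): Seq[(K, A, B)] = {
    val ord = implicitly[Ordering[K]]
    val out = Seq.newBuilder[(K, A, B)]
    var i = 0
    var j = 0
    while (i < left.length && j < right.length) {
      val c = ord.compare(left(i)._1, right(j)._1)
      if (c < 0) i += 1
      else if (c > 0) j += 1
      else {
        // Buffer the run of equal keys on the right, then join it with every matching left row.
        val key = left(i)._1
        val run = right.drop(j).takeWhile(r => ord.equiv(r._1, key))
        while (i < left.length && ord.equiv(left(i)._1, key)) {
          run.foreach { case (_, b) => out += ((key, left(i)._2, b)) }
          i += 1
        }
        j += run.length
      }
    }
    out.result()
  }

  def main(args: Array[String]): Unit = {
    val left = Seq(1 -> "a", 2 -> "b", 2 -> "c", 4 -> "d")
    val right = Seq(2 -> "x", 2 -> "y", 3 -> "z")
    assert(mergeJoin(left, right) ==
      Seq((2, "b", "x"), (2, "b", "y"), (2, "c", "x"), (2, "c", "y")))
  }
}
```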
predicates.map(p => p -> ParquetFilters.createFilter(p)).collect { case (predicate, None) => predicate + // Filter needs to be applied above when it contains partitioning + // columns + case (predicate, _) if(!predicate.references.map(_.name).toSet + .intersect (partitionColNames).isEmpty) => predicate } } } else { @@ -237,7 +255,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { ParquetTableScan( _, relation, - if (sqlContext.conf.parquetFilterPushDown) filters else Nil)) :: Nil + if (sqlContext.conf.parquetFilterPushDown) filtersToPush else Nil)) :: Nil case _ => Nil } @@ -296,10 +314,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Intersect(planLater(left), planLater(right)) :: Nil case logical.Generate(generator, join, outer, _, child) => execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil - case logical.NoRelation => + case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => - execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil + execution.Exchange( + HashPartitioning(expressions, numPartitions), Nil, planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 1f5251a20376f..d286fe81bee5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -21,7 +21,7 @@ import org.apache.spark.{SparkEnv, HashPartitioner, SparkConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -41,6 +41,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends val resuableProjection = buildProjection() iter.map(resuableProjection) } + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering } /** @@ -55,6 +57,8 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { override def execute(): RDD[Row] = child.execute().mapPartitions { iter => iter.filter(conditionEvaluator) } + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering } /** @@ -117,7 +121,7 @@ case class Limit(limit: Int, child: SparkPlan) } val part = new HashPartitioner(1) val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf)) shuffled.mapPartitions(_.take(limit).map(_._2)) } } @@ -139,13 +143,16 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) private def collectData(): Array[Row] = child.execute().map(_.copy()).takeOrdered(limit)(ord) - // TODO: Is this copying for no reason? 
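Besides the planner changes, the physical operators now advertise ordering: `Project` and `Filter` above simply forward `child.outputOrdering` because they never reorder rows, while `Sort`, `ExternalSort` and `TakeOrdered` (just below) report their own `sortOrder`. A toy model of that propagation; these are not the real `SparkPlan` classes:

```
object OrderingPropagationSketch {
  sealed trait Op { def outputOrdering: Seq[String] }
  case object Scan extends Op { val outputOrdering: Seq[String] = Nil }
  final case class Sort(order: Seq[String], child: Op) extends Op {
    val outputOrdering: Seq[String] = order
  }
  final case class Filter(child: Op) extends Op {
    val outputOrdering: Seq[String] = child.outputOrdering
  }
  final case class Project(child: Op) extends Op {
    val outputOrdering: Seq[String] = child.outputOrdering
  }

  def main(args: Array[String]): Unit = {
    val plan = Project(Filter(Sort(Seq("key ASC"), Scan)))
    // The ordering established by Sort survives the row-preserving operators above it,
    // so the planner would not need to add a redundant sort on top.
    assert(plan.outputOrdering == Seq("key ASC"))
  }
}
```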
- override def executeCollect(): Array[Row] = - collectData().map(ScalaReflection.convertRowToScala(_, this.schema)) + override def executeCollect(): Array[Row] = { + val converter = CatalystTypeConverters.createToScalaConverter(schema) + collectData().map(converter(_).asInstanceOf[Row]) + } // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. override def execute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1) + + override def outputOrdering: Seq[SortOrder] = sortOrder } /** @@ -171,6 +178,8 @@ case class Sort( } override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder } /** @@ -193,7 +202,7 @@ case class ExternalSort( child.execute().mapPartitions( { iterator => val ordering = newOrdering(sortOrder, child.output) val sorter = new ExternalSorter[Row, Null, Row](ordering = Some(ordering)) - sorter.insertAll(iterator.map(r => (r, null))) + sorter.insertAll(iterator.map(r => (r.copy, null))) val baseIterator = sorter.iterator.map(_._1) // TODO(marmbrus): The complex type signature below thwarts inference for no reason. CompletionIterator[Row, Iterator[Row]](baseIterator, sorter.stop()) @@ -201,6 +210,8 @@ case class ExternalSort( } override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index fad7a281dc1e2..99f24910fd61f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -20,12 +20,13 @@ package org.apache.spark.sql.execution import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types.{BooleanType, StructField, StructType, StringType} -import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} /** * A logical command that is executed for its side-effects. 
`RunnableCommand`s are @@ -61,7 +62,11 @@ case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { override def executeTake(limit: Int): Array[Row] = sideEffectResult.take(limit).toArray - override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1) + override def execute(): RDD[Row] = { + val converted = sideEffectResult.map(r => + CatalystTypeConverters.convertToCatalyst(r, schema).asInstanceOf[Row]) + sqlContext.sparkContext.parallelize(converted, 1) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index e916e68e58b5d..710787096e6cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -164,7 +164,7 @@ package object debug { case (_: Long, LongType) => case (_: Int, IntegerType) => - case (_: String, StringType) => + case (_: UTF8String, StringType) => case (_: Float, FloatType) => case (_: Byte, ByteType) => case (_: Short, ShortType) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 2fa1cf5add3b5..ab84c123e0c0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.execution.joins +import java.io.{ObjectInput, ObjectOutput, Externalizable} import java.util.{HashMap => JavaHashMap} import org.apache.spark.sql.catalyst.expressions.{Projection, Row} +import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.util.collection.CompactBuffer @@ -29,16 +31,43 @@ import org.apache.spark.util.collection.CompactBuffer */ private[joins] sealed trait HashedRelation { def get(key: Row): CompactBuffer[Row] + + // This is a helper method to implement Externalizable, and is used by + // GeneralHashedRelation and UniqueKeyHashedRelation + protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = { + out.writeInt(serialized.length) // Write the length of serialized bytes first + out.write(serialized) + } + + // This is a helper method to implement Externalizable, and is used by + // GeneralHashedRelation and UniqueKeyHashedRelation + protected def readBytes(in: ObjectInput): Array[Byte] = { + val serializedSize = in.readInt() // Read the length of serialized bytes first + val bytes = new Array[Byte](serializedSize) + in.readFully(bytes) + bytes + } } /** * A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values. 
*/ -private[joins] final class GeneralHashedRelation(hashTable: JavaHashMap[Row, CompactBuffer[Row]]) - extends HashedRelation with Serializable { +private[joins] final class GeneralHashedRelation( + private var hashTable: JavaHashMap[Row, CompactBuffer[Row]]) + extends HashedRelation with Externalizable { + + def this() = this(null) // Needed for serialization override def get(key: Row): CompactBuffer[Row] = hashTable.get(key) + + override def writeExternal(out: ObjectOutput): Unit = { + writeBytes(out, SparkSqlSerializer.serialize(hashTable)) + } + + override def readExternal(in: ObjectInput): Unit = { + hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + } } @@ -46,8 +75,10 @@ private[joins] final class GeneralHashedRelation(hashTable: JavaHashMap[Row, Com * A specialized [[HashedRelation]] that maps key into a single value. This implementation * assumes the key is unique. */ -private[joins] final class UniqueKeyHashedRelation(hashTable: JavaHashMap[Row, Row]) - extends HashedRelation with Serializable { +private[joins] final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[Row, Row]) + extends HashedRelation with Externalizable { + + def this() = this(null) // Needed for serialization override def get(key: Row): CompactBuffer[Row] = { val v = hashTable.get(key) @@ -55,6 +86,14 @@ private[joins] final class UniqueKeyHashedRelation(hashTable: JavaHashMap[Row, R } def getValue(key: Row): Row = hashTable.get(key) + + override def writeExternal(out: ObjectOutput): Unit = { + writeBytes(out, SparkSqlSerializer.serialize(hashTable)) + } + + override def readExternal(in: ObjectInput): Unit = { + hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala new file mode 100644 index 0000000000000..b5123668ba11e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import java.util.NoSuchElementException + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.collection.CompactBuffer + +/** + * :: DeveloperApi :: + * Performs an sort merge join of two child relations. 
+ */ +@DeveloperApi +case class SortMergeJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def output: Seq[Attribute] = left.output ++ right.output + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + // this is to manually construct an ordering that can be used to compare keys from both sides + private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType)) + + override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys) + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil + + @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) + @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) + + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = + keys.map(SortOrder(_, Ascending)) + + override def execute(): RDD[Row] = { + val leftResults = left.execute().map(_.copy()) + val rightResults = right.execute().map(_.copy()) + + leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => + new Iterator[Row] { + // Mutable per row objects. + private[this] val joinRow = new JoinedRow5 + private[this] var leftElement: Row = _ + private[this] var rightElement: Row = _ + private[this] var leftKey: Row = _ + private[this] var rightKey: Row = _ + private[this] var rightMatches: CompactBuffer[Row] = _ + private[this] var rightPosition: Int = -1 + private[this] var stop: Boolean = false + private[this] var matchKey: Row = _ + + // initialize iterator + initialize() + + override final def hasNext: Boolean = nextMatchingPair() + + override final def next(): Row = { + if (hasNext) { + // we are using the buffered right rows and run down left iterator + val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) + rightPosition += 1 + if (rightPosition >= rightMatches.size) { + rightPosition = 0 + fetchLeft() + if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) { + stop = false + rightMatches = null + } + } + joinedRow + } else { + // no more result + throw new NoSuchElementException + } + } + + private def fetchLeft() = { + if (leftIter.hasNext) { + leftElement = leftIter.next() + leftKey = leftKeyGenerator(leftElement) + } else { + leftElement = null + } + } + + private def fetchRight() = { + if (rightIter.hasNext) { + rightElement = rightIter.next() + rightKey = rightKeyGenerator(rightElement) + } else { + rightElement = null + } + } + + private def initialize() = { + fetchLeft() + fetchRight() + } + + /** + * Searches the right iterator for the next rows that have matches in left side, and store + * them in a buffer. + * + * @return true if the search is successful, and false if the right iterator runs out of + * tuples. 
+ */ + private def nextMatchingPair(): Boolean = { + if (!stop && rightElement != null) { + // run both sides to get the first matching pair + while (!stop && leftElement != null && rightElement != null) { + val comparing = keyOrdering.compare(leftKey, rightKey) + // for inner join, we need to filter out null keys + stop = comparing == 0 && !leftKey.anyNull + if (comparing > 0 || rightKey.anyNull) { + fetchRight() + } else if (comparing < 0 || leftKey.anyNull) { + fetchLeft() + } + } + rightMatches = new CompactBuffer[Row]() + if (stop) { + stop = false + // iterate the right side to buffer all rows that match + // as the records should be ordered, exit when we meet the first one that does not match + while (!stop && rightElement != null) { + rightMatches += rightElement + fetchRight() + stop = keyOrdering.compare(leftKey, rightKey) != 0 + } + if (rightMatches.size > 0) { + rightPosition = 0 + matchKey = leftKey + } + } + } + rightMatches != null && rightMatches.size > 0 + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 5b308d88d4cdf..7a43bfd8bc8d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -140,6 +140,7 @@ object EvaluatePython { case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType) case (date: Int, DateType) => DateUtils.toJavaDate(date) + case (s: UTF8String, StringType) => s.toString // Pyrolite can handle Timestamp and Decimal case (other, _) => other @@ -192,7 +193,8 @@ object EvaluatePython { case (c: Long, IntegerType) => c.toInt case (c: Int, LongType) => c.toLong case (c: Double, FloatType) => c.toFloat - case (c, StringType) if !c.isInstanceOf[String] => c.toString + case (c: String, StringType) => UTF8String(c) + case (c, StringType) if !c.isInstanceOf[String] => UTF8String(c.toString) case (c, _) => c } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 111e751588a8b..ff91e1d74bc2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -22,7 +22,7 @@ import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -605,4 +605,23 @@ object functions { } // scalastyle:on + + /** + * Call a user-defined function.
+ * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + * val sqlContext = df.sqlContext + * sqlContext.udf.register("simpleUdf", (v: Int) => v * v) + * df.select($"id", callUdf("simpleUdf", $"value")) + * }}} + * + * @group udf_funcs + */ + def callUdf(udfName: String, cols: Column*): Column = { + UnresolvedFunction(udfName, cols.map(_.expr)) + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala index 1704be7fcbd30..0feabc4282f4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala @@ -49,9 +49,9 @@ private[sql] object DriverQuirks { * Fetch the DriverQuirks class corresponding to a given database url. */ def get(url: String): DriverQuirks = { - if (url.substring(0, 10).equals("jdbc:mysql")) { + if (url.startsWith("jdbc:mysql")) { new MySQLQuirks() - } else if (url.substring(0, 15).equals("jdbc:postgresql")) { + } else if (url.startsWith("jdbc:postgresql")) { new PostgresQuirks() } else { new NoQuirks() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 463e1dcc268bc..b9022fcd9e3ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -233,7 +233,7 @@ private[sql] class JDBCRDD( * Converts value to SQL expression. */ private def compileValue(value: Any): Any = value match { - case stringValue: String => s"'${escapeSql(stringValue)}'" + case stringValue: UTF8String => s"'${escapeSql(stringValue.toString)}'" case _ => value } @@ -349,12 +349,14 @@ private[sql] class JDBCRDD( val pos = i + 1 conversions(i) match { case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos)) + // TODO(davies): convert Date into Int case DateConversion => mutableRow.update(i, rs.getDate(pos)) case DecimalConversion => mutableRow.update(i, rs.getBigDecimal(pos)) case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos)) case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos)) case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos)) case LongConversion => mutableRow.setLong(i, rs.getLong(pos)) + // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 case StringConversion => mutableRow.setString(i, rs.getString(pos)) case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos)) case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index 4fa84dc076f7e..99b755c9f25d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -130,6 +130,8 @@ private[sql] case class JDBCRelation( extends BaseRelation with PrunedFilteredScan { + override val needConversion: Boolean = false + override val schema: StructType = JDBCRDD.resolveTable(url, table, properties) override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index 34f864f5fda7a..d4e0abc040bc6 100644 
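A side note on the `DriverQuirks` change above: `url.substring(0, 15)` throws for URLs shorter than the tested prefix, while `startsWith` simply returns `false`. A tiny sketch with a made-up short URL:

```
// Illustrative only: substring-based prefix checks throw on short inputs,
// startsWith does not.
object PrefixCheck {
  def main(args: Array[String]): Unit = {
    val shortUrl = "jdbc:h2" // only 7 characters
    println(shortUrl.startsWith("jdbc:postgresql")) // false, no exception
    try {
      println(shortUrl.substring(0, 15) == "jdbc:postgresql")
    } catch {
      case e: StringIndexOutOfBoundsException =>
        println(s"substring(0, 15) failed: $e")
    }
  }
}
```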
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -18,11 +18,8 @@ package org.apache.spark.sql import java.sql.{Connection, DriverManager, PreparedStatement} -import org.apache.spark.{Logging, Partition} -import org.apache.spark.sql._ -import org.apache.spark.sql.sources.LogicalRelation -import org.apache.spark.sql.jdbc.{JDBCPartitioningInfo, JDBCRelation, JDBCPartition} +import org.apache.spark.Logging import org.apache.spark.sql.types._ package object jdbc { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index f4c99b4b56606..e3352d02787fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.json import java.io.IOException import org.apache.hadoop.fs.Path + import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.Row - -import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} private[sql] class DefaultSource @@ -113,6 +113,8 @@ private[sql] case class JSONRelation( // TODO: Support partitioned JSON relation. private def baseRDD = sqlContext.sparkContext.textFile(path) + override val needConversion: Boolean = false + override val schema = userSpecifiedSchema.getOrElse( JsonRDD.nullTypeToStringType( JsonRDD.inferSchema( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 2b0358c4e2a1e..29de7401dda71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -49,7 +49,7 @@ private[sql] object JsonRDD extends Logging { val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1) val allKeys = if (schemaData.isEmpty()) { - Set.empty[(String,DataType)] + Set.empty[(String, DataType)] } else { parseJson(schemaData, columnNameOfCorruptRecords).map(allKeysWithValueTypes).reduce(_ ++ _) } @@ -391,7 +391,7 @@ private[sql] object JsonRDD extends Logging { value match { // only support string as date case value: java.lang.String => - DateUtils.millisToDays(DataTypeConversions.stringToTime(value).getTime) + DateUtils.millisToDays(DateUtils.stringToTime(value).getTime) case value: java.sql.Date => DateUtils.fromJavaDate(value) } } @@ -400,7 +400,7 @@ private[sql] object JsonRDD extends Logging { value match { case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong) case value: java.lang.Long => new Timestamp(value) - case value: java.lang.String => toTimestamp(DataTypeConversions.stringToTime(value).getTime) + case value: java.lang.String => toTimestamp(DateUtils.stringToTime(value).getTime) } } @@ -409,7 +409,7 @@ private[sql] object JsonRDD extends Logging { null } else { desiredType match { - case StringType => toString(value) + case StringType => UTF8String(toString(value)) case _ if value == null || value == "" => null // guard the non string type case IntegerType => value.asInstanceOf[IntegerType.JvmType] case LongType => toLong(value) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 43ca359b51735..bc108e37dfb0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -219,8 +219,8 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = updateField(fieldIndex, value.getBytes) - protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = - updateField(fieldIndex, value) + protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = + updateField(fieldIndex, UTF8String(value)) protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = updateField(fieldIndex, readTimestamp(value)) @@ -418,8 +418,8 @@ private[parquet] class CatalystPrimitiveRowConverter( override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = current.update(fieldIndex, value.getBytes) - override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = - current.setString(fieldIndex, value) + override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = + current.update(fieldIndex, UTF8String(value)) override protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = current.update(fieldIndex, readTimestamp(value)) @@ -475,19 +475,18 @@ private[parquet] class CatalystPrimitiveConverter( private[parquet] class CatalystPrimitiveStringConverter(parent: CatalystConverter, fieldIndex: Int) extends CatalystPrimitiveConverter(parent, fieldIndex) { - private[this] var dict: Array[String] = null + private[this] var dict: Array[Array[Byte]] = null override def hasDictionarySupport: Boolean = true override def setDictionary(dictionary: Dictionary):Unit = - dict = Array.tabulate(dictionary.getMaxId + 1) {dictionary.decodeToBinary(_).toStringUsingUTF8} - + dict = Array.tabulate(dictionary.getMaxId + 1) { dictionary.decodeToBinary(_).getBytes } override def addValueFromDictionary(dictionaryId: Int): Unit = parent.updateString(fieldIndex, dict(dictionaryId)) override def addBinary(value: Binary): Unit = - parent.updateString(fieldIndex, value.toStringUsingUTF8) + parent.updateString(fieldIndex, value.getBytes) } private[parquet] object CatalystArrayConverter { @@ -714,9 +713,9 @@ private[parquet] class CatalystNativeArrayConverter( elements += 1 } - override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = { + override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = { checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] + buffer(elements) = UTF8String(value).asInstanceOf[NativeType] elements += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index 0357dcc4688be..5eb1c6abc2432 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -55,7 +55,7 @@ private[sql] object ParquetFilters { case StringType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + Option(v).map(s => 
Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) case BinaryType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), @@ -76,7 +76,7 @@ private[sql] object ParquetFilters { case StringType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) case BinaryType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), @@ -94,7 +94,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -111,7 +111,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -128,7 +128,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -145,7 +145,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 5130d8ad5e003..1c868da23e060 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -19,10 +19,9 @@ package org.apache.spark.sql.parquet import java.io.IOException import java.lang.{Long => JLong} -import java.text.SimpleDateFormat -import java.text.NumberFormat +import java.text.{NumberFormat, SimpleDateFormat} import java.util.concurrent.{Callable, TimeUnit} -import java.util.{ArrayList, Collections, Date, List => JList} +import java.util.{Date, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable @@ -43,12 +42,13 @@ import parquet.io.ParquetDecodingException import parquet.schema.MessageType import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import 
org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row, _} import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.StructType import org.apache.spark.{Logging, SerializableWritable, TaskContext} /** @@ -356,7 +356,7 @@ private[sql] case class InsertIntoParquetTable( } finally { writer.close(hadoopContext) } - committer.commitTask(hadoopContext) + SparkHadoopMapRedUtil.commitTask(committer, hadoopContext, context) 1 } val jobFormat = new AppendingParquetOutputFormat(taskIdOffset) @@ -512,6 +512,7 @@ private[parquet] class FilteringParquetRowInputFormat import parquet.filter2.compat.FilterCompat.Filter import parquet.filter2.compat.RowGroupFilter + import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.blockLocationCache val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 5a1b15490d273..e05a4c20b0d41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -198,10 +198,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { if (value != null) { schema match { case StringType => writer.addBinary( - Binary.fromByteArray( - value.asInstanceOf[String].getBytes("utf-8") - ) - ) + Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes)) case BinaryType => writer.addBinary( Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) case IntegerType => writer.addInteger(value.asInstanceOf[Int]) @@ -349,7 +346,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { index: Int): Unit = { ctype match { case StringType => writer.addBinary( - Binary.fromByteArray(record(index).asInstanceOf[String].getBytes("utf-8"))) + Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes)) case BinaryType => writer.addBinary( Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) case IntegerType => writer.addInteger(record.getInt(index)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index da668f068613b..60e1bec4db8e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -390,6 +390,7 @@ private[parquet] object ParquetTypesConverter extends Logging { def convertFromAttributes(attributes: Seq[Attribute], toThriftSchemaNames: Boolean = false): MessageType = { + checkSpecialCharacters(attributes) val fields = attributes.map( attribute => fromDataType(attribute.dataType, attribute.name, attribute.nullable, @@ -404,7 +405,20 @@ private[parquet] object ParquetTypesConverter extends Logging { } } + private def checkSpecialCharacters(schema: Seq[Attribute]) = { + // ,;{}()\n\t= and space character are special characters in Parquet schema + schema.map(_.name).foreach { name => + if (name.matches(".*[ ,;{}()\n\t=].*")) { + sys.error( + s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\n\t=". + |Please use alias to rename it. 
+ """.stripMargin.split("\n").mkString(" ")) + } + } + } + def convertToString(schema: Seq[Attribute]): String = { + checkSpecialCharacters(schema) StructType.fromAttributes(schema).json } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 0d68810ec6043..af7b3c81ae7b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -33,7 +33,6 @@ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.hadoop.mapreduce.{InputSplit, Job, JobContext} - import parquet.filter2.predicate.FilterApi import parquet.format.converter.ParquetMetadataConverter import parquet.hadoop.metadata.CompressionCodecName @@ -42,15 +41,16 @@ import parquet.hadoop.{ParquetInputFormat, _} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.{NewHadoopPartition, NewHadoopRDD, RDD} -import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, expressions} import org.apache.spark.sql.parquet.ParquetTypesConverter._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{IntegerType, StructField, StructType, _} import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext, SaveMode} -import org.apache.spark.{Logging, Partition => SparkPartition, SerializableWritable, SparkException, TaskContext} +import org.apache.spark.{Logging, SerializableWritable, SparkException, TaskContext, Partition => SparkPartition} /** * Allows creation of Parquet based tables using the syntax: @@ -121,7 +121,8 @@ private[sql] class DefaultSource val df = sqlContext.createDataFrame( data.queryExecution.toRdd, - data.schema.asNullable) + data.schema.asNullable, + needsConversion = false) val createdRelation = createRelation(sqlContext, parameters, df.schema).asInstanceOf[ParquetRelation2] createdRelation.insert(df, overwrite = mode == SaveMode.Overwrite) @@ -266,7 +267,8 @@ private[sql] case class ParquetRelation2( // containing Parquet files (e.g. partitioned Parquet table). val baseStatuses = paths.distinct.map { p => val fs = FileSystem.get(URI.create(p), sparkContext.hadoopConfiguration) - val qualified = fs.makeQualified(new Path(p)) + val path = new Path(p) + val qualified = path.makeQualified(fs.getUri, fs.getWorkingDirectory) if (!fs.exists(qualified) && maybeSchema.isDefined) { fs.mkdirs(qualified) @@ -406,6 +408,9 @@ private[sql] case class ParquetRelation2( file.getName == ParquetFileWriter.PARQUET_METADATA_FILE } + // Skip type conversion + override val needConversion: Boolean = false + // TODO Should calculate per scan size // It's common that a query only scans a fraction of a large Parquet file. Returning size of the // whole Parquet file disables some optimizations in this case (e.g. broadcast join). @@ -429,20 +434,24 @@ private[sql] case class ParquetRelation2( // FileInputFormat cannot handle empty lists. 
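The `checkSpecialCharacters` guard added to `ParquetTypesConverter` above rejects column names that Parquet schemas cannot represent. A standalone sketch of the same check; the object and method names are made up, and only the character class mirrors the patch:

```
// Standalone sketch of the Parquet column-name validation shown above.
object ParquetNameCheck {
  private val invalid = ".*[ ,;{}()\n\t=].*"

  def check(names: Seq[String]): Unit =
    names.filter(_.matches(invalid)).foreach { name =>
      sys.error(s"Attribute name '$name' contains invalid character(s). Please use alias to rename it.")
    }

  def main(args: Array[String]): Unit = {
    check(Seq("id", "user_name")) // passes silently
    // check(Seq("total count"))  // would fail: the name contains a space
    println("ok")
  }
}
```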
if (selectedFiles.nonEmpty) { - FileInputFormat.setInputPaths(job, selectedFiles.map(_.getPath): _*) + // In order to encode the authority of a Path containing special characters such as /, + // we need to use the string returned by the URI of the path to create a new Path. + val selectedPaths = selectedFiles.map(status => new Path(status.getPath.toUri.toString)) + FileInputFormat.setInputPaths(job, selectedPaths: _*) } - // Push down filters when possible. Notice that not all filters can be converted to Parquet - // filter predicate. Here we try to convert each individual predicate and only collect those - // convertible ones. + // Try to push down filters when filter push-down is enabled. if (sqlContext.conf.parquetFilterPushDown) { + val partitionColNames = partitionColumns.map(_.name).toSet predicates // Don't push down predicates which reference partition columns .filter { pred => - val partitionColNames = partitionColumns.map(_.name).toSet val referencedColNames = pred.references.map(_.name).toSet referencedColNames.intersect(partitionColNames).isEmpty } + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. .flatMap(ParquetFilters.createFilter) .reduceOption(FilterApi.and) .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) @@ -480,10 +489,31 @@ private[sql] case class ParquetRelation2( val cacheMetadata = useCache @transient - val cachedStatus = selectedFiles + val cachedStatus = selectedFiles.map { st => + // In order to encode the authority of a Path containing special characters such as /, + // we need to use the string returned by the URI of the path to create a new Path. + val newPath = new Path(st.getPath.toUri.toString) + + new FileStatus( + st.getLen, + st.isDir, + st.getReplication, + st.getBlockSize, + st.getModificationTime, + st.getAccessTime, + st.getPermission, + st.getOwner, + st.getGroup, + newPath) + } @transient - val cachedFooters = selectedFooters + val cachedFooters = selectedFooters.map { f => + // In order to encode the authority of a Path containing special characters such as /, + // we need to use the string returned by the URI of the path to create a new Path. + new Footer(new Path(f.getFile.toUri.toString), f.getParquetMetadata) + } + // Overridden so we can inject our own cached file statuses.
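The push-down block above converts each predicate with a function that returns an `Option` and then ANDs whatever survived. The same `flatMap`/`reduceOption` pattern is sketched below with a made-up predicate type standing in for the Parquet filter predicates:

```
// Sketch of the flatMap + reduceOption pattern used above for filter push-down:
// convert what you can, drop the rest, and AND the survivors together.
object PushdownSketch {
  sealed trait Pred
  final case class Ge(column: String, value: Int) extends Pred
  final case class AndPred(left: Pred, right: Pred) extends Pred

  // Stand-in for ParquetFilters.createFilter: only some inputs are convertible.
  def convert(raw: (String, Int)): Option[Pred] = raw match {
    case (column, value) if column.nonEmpty => Some(Ge(column, value))
    case _ => None
  }

  def main(args: Array[String]): Unit = {
    val raw = Seq(("age", 21), ("", 0), ("score", 90))
    val combined = raw.flatMap(convert).reduceOption[Pred](AndPred(_, _))
    println(combined) // Some(AndPred(Ge(age,21),Ge(score,90)))
  }
}
```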
override def getPartitions: Array[SparkPartition] = { @@ -522,7 +552,8 @@ private[sql] case class ParquetRelation2( baseRDD.mapPartitionsWithInputSplit { case (split: ParquetInputSplit, iterator) => val partValues = selectedPartitions.collectFirst { - case p if split.getPath.getParent.toString == p.path => p.values + case p if split.getPath.getParent.toString == p.path => + CatalystTypeConverters.convertToCatalyst(p.values).asInstanceOf[Row] }.get val requiredPartOrdinal = partitionKeyLocations.keys.toSeq @@ -669,7 +700,8 @@ private[sql] case class ParquetRelation2( } finally { writer.close(hadoopContext) } - committer.commitTask(hadoopContext) + + SparkHadoopMapRedUtil.commitTask(committer, hadoopContext, context) } val jobFormat = new AppendingParquetOutputFormat(taskIdOffset) /* apparently we need a TaskAttemptID to construct an OutputCommitter; @@ -765,12 +797,14 @@ private[sql] object ParquetRelation2 extends Logging { |${parquetSchema.prettyJson} """.stripMargin - assert(metastoreSchema.size <= parquetSchema.size, schemaConflictMessage) + val mergedParquetSchema = mergeMissingNullableFields(metastoreSchema, parquetSchema) + + assert(metastoreSchema.size <= mergedParquetSchema.size, schemaConflictMessage) val ordinalMap = metastoreSchema.zipWithIndex.map { case (field, index) => field.name.toLowerCase -> index }.toMap - val reorderedParquetSchema = parquetSchema.sortBy(f => + val reorderedParquetSchema = mergedParquetSchema.sortBy(f => ordinalMap.getOrElse(f.name.toLowerCase, metastoreSchema.size + 1)) StructType(metastoreSchema.zip(reorderedParquetSchema).map { @@ -782,6 +816,32 @@ private[sql] object ParquetRelation2 extends Logging { }) } + /** + * Returns the original schema from the Parquet file with any missing nullable fields from the + * Hive Metastore schema merged in. + * + * When constructing a DataFrame from a collection of structured data, the resulting object has + * a schema corresponding to the union of the fields present in each element of the collection. + * Spark SQL simply assigns a null value to any field that isn't present for a particular row. + * In some cases, it is possible that a given table partition stored as a Parquet file doesn't + * contain a particular nullable field in its schema despite that field being present in the + * table schema obtained from the Hive Metastore. This method returns a schema representing the + * Parquet file schema along with any additional nullable fields from the Metastore schema + * merged in. + */ + private[parquet] def mergeMissingNullableFields( + metastoreSchema: StructType, + parquetSchema: StructType): StructType = { + val fieldMap = metastoreSchema.map(f => f.name.toLowerCase -> f).toMap + val missingFields = metastoreSchema + .map(_.name.toLowerCase) + .diff(parquetSchema.map(_.name.toLowerCase)) + .map(fieldMap(_)) + .filter(_.nullable) + StructType(parquetSchema ++ missingFields) + } + + // TODO Data source implementations shouldn't touch Catalyst types (`Literal`). // However, we are already using Catalyst expressions for partition pruning and predicate // push-down here... 
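Because `mergeMissingNullableFields` relies only on the public `StructType`/`StructField` API, its behaviour can be reproduced outside the relation. A self-contained sketch with made-up schemas follows; the merge body mirrors the method documented above and requires spark-sql on the classpath:

```
import org.apache.spark.sql.types._

// Illustration of merging missing nullable Metastore fields into a Parquet schema.
object MergeSchemasSketch {
  def merge(metastoreSchema: StructType, parquetSchema: StructType): StructType = {
    val fieldMap = metastoreSchema.map(f => f.name.toLowerCase -> f).toMap
    val missingFields = metastoreSchema
      .map(_.name.toLowerCase)
      .diff(parquetSchema.map(_.name.toLowerCase))
      .map(fieldMap(_))
      .filter(_.nullable)
    StructType(parquetSchema ++ missingFields)
  }

  def main(args: Array[String]): Unit = {
    val metastore = StructType(Seq(
      StructField("id", LongType, nullable = false),
      StructField("name", StringType, nullable = true),
      StructField("extra", StringType, nullable = true))) // not present in the file
    val parquetFile = StructType(Seq(
      StructField("id", LongType, nullable = false),
      StructField("name", StringType, nullable = true)))

    // "extra" is nullable and missing from the file, so it is appended.
    println(merge(metastore, parquetFile).fieldNames.mkString(", ")) // id, name, extra
  }
}
```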
@@ -842,9 +902,9 @@ private[sql] object ParquetRelation2 extends Logging { * PartitionValues( * Seq("a", "b", "c"), * Seq( - * Literal(42, IntegerType), - * Literal("hello", StringType), - * Literal(3.14, FloatType))) + * Literal.create(42, IntegerType), + * Literal.create("hello", StringType), + * Literal.create(3.14, FloatType))) * }}} */ private[parquet] def parsePartition( @@ -923,15 +983,16 @@ private[sql] object ParquetRelation2 extends Logging { raw: String, defaultPartitionName: String): Literal = { // First tries integral types - Try(Literal(Integer.parseInt(raw), IntegerType)) - .orElse(Try(Literal(JLong.parseLong(raw), LongType))) + Try(Literal.create(Integer.parseInt(raw), IntegerType)) + .orElse(Try(Literal.create(JLong.parseLong(raw), LongType))) // Then falls back to fractional types - .orElse(Try(Literal(JFloat.parseFloat(raw), FloatType))) - .orElse(Try(Literal(JDouble.parseDouble(raw), DoubleType))) - .orElse(Try(Literal(new JBigDecimal(raw), DecimalType.Unlimited))) + .orElse(Try(Literal.create(JFloat.parseFloat(raw), FloatType))) + .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) + .orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited))) // Then falls back to string .getOrElse { - if (raw == defaultPartitionName) Literal(null, NullType) else Literal(raw, StringType) + if (raw == defaultPartitionName) Literal.create(null, NullType) + else Literal.create(raw, StringType) } } @@ -950,7 +1011,7 @@ private[sql] object ParquetRelation2 extends Logging { } literals.map { case l @ Literal(_, dataType) => - Literal(Cast(l, desiredType).eval(), desiredType) + Literal.create(Cast(l, desiredType).eval(), desiredType) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index 67f3507c61ab6..b3d71f687a60a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -23,6 +23,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.{UTF8String, StringType} import org.apache.spark.sql.{Row, Strategy, execution, sources} /** @@ -52,10 +54,10 @@ private[sql] object DataSourceStrategy extends Strategy { (a, _) => t.buildScan(a)) :: Nil case l @ LogicalRelation(t: TableScan) => - execution.PhysicalRDD(l.output, t.buildScan()) :: Nil + createPhysicalRDD(l.relation, l.output, t.buildScan()) :: Nil case i @ logical.InsertIntoTable( - l @ LogicalRelation(t: InsertableRelation), part, query, overwrite) if part.isEmpty => + l @ LogicalRelation(t: InsertableRelation), part, query, overwrite, false) if part.isEmpty => execution.ExecutedCommand(InsertIntoDataSource(l, query, overwrite)) :: Nil case _ => Nil @@ -101,20 +103,30 @@ private[sql] object DataSourceStrategy extends Strategy { projectList.asInstanceOf[Seq[Attribute]] // Safe due to if above. .map(relation.attributeMap) // Match original case of attributes. 
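The partition-value inference above tries successively wider numeric types via `Try(...).orElse(...)` before falling back to a string (or null for the default partition). A simplified sketch without Catalyst `Literal`s, skipping the Float step; the object name and the default-partition string used in `main` are made up:

```
import scala.util.Try

// Hedged sketch of the try-integral-then-fractional-then-string inference
// used above for partition values, returning plain Scala values.
object PartitionValueSketch {
  def infer(raw: String, defaultPartitionName: String): Any = {
    Try(raw.toInt: Any)
      .orElse(Try(raw.toLong))
      .orElse(Try(raw.toDouble))
      .orElse(Try(BigDecimal(raw)))
      .getOrElse(if (raw == defaultPartitionName) null else raw)
  }

  def main(args: Array[String]): Unit = {
    println(infer("42", "__DEFAULT__"))         // 42 (Int)
    println(infer("3000000000", "__DEFAULT__")) // 3000000000 (Long)
    println(infer("3.14", "__DEFAULT__"))       // 3.14 (Double)
    println(infer("hello", "__DEFAULT__"))      // hello (String)
  }
}
```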
- val scan = - execution.PhysicalRDD( - projectList.map(_.toAttribute), + val scan = createPhysicalRDD(relation.relation, projectList.map(_.toAttribute), scanBuilder(requestedColumns, pushedFilters)) filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) } else { val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq - val scan = - execution.PhysicalRDD(requestedColumns, scanBuilder(requestedColumns, pushedFilters)) + val scan = createPhysicalRDD(relation.relation, requestedColumns, + scanBuilder(requestedColumns, pushedFilters)) execution.Project(projectList, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) } } + private[this] def createPhysicalRDD( + relation: BaseRelation, + output: Seq[Attribute], + rdd: RDD[Row]): SparkPlan = { + val converted = if (relation.needConversion) { + execution.RDDConversions.rowToRowRdd(rdd, relation.schema) + } else { + rdd + } + execution.PhysicalRDD(output, converted) + } + /** * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s, * and convert them. @@ -166,6 +178,15 @@ private[sql] object DataSourceStrategy extends Strategy { case expressions.Not(child) => translate(child).map(sources.Not) + case expressions.StartsWith(a: Attribute, Literal(v: UTF8String, StringType)) => + Some(sources.StringStartsWith(a.name, v.toString)) + + case expressions.EndsWith(a: Attribute, Literal(v: UTF8String, StringType)) => + Some(sources.StringEndsWith(a.name, v.toString)) + + case expressions.Contains(a: Attribute, Literal(v: UTF8String, StringType)) => + Some(sources.StringContains(a.name, v.toString)) + case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 9bbe06e59ba30..dbdb0d39c26a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -31,7 +31,8 @@ private[sql] case class InsertIntoDataSource( val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] val data = DataFrame(sqlContext, query) // Apply the schema of the existing table to the new data. - val df = sqlContext.createDataFrame(data.queryExecution.toRdd, logicalRelation.schema) + val df = sqlContext.createDataFrame( + data.queryExecution.toRdd, logicalRelation.schema, needsConversion = false) relation.insert(df, overwrite) // Invalidate the cache. 
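The new `StartsWith`/`EndsWith`/`Contains` cases in `translate` above map Catalyst string predicates onto the public source filters. Below is a toy version of that mapping with a made-up mini expression ADT and local filter classes standing in for the Catalyst and `sources` types; only the shape of the match mirrors the patch:

```
// Toy translation from a made-up expression ADT to source-style string filters.
object TranslateSketch {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Lit(value: String) extends Expr
  case class StartsWith(left: Expr, right: Expr) extends Expr
  case class EndsWith(left: Expr, right: Expr) extends Expr
  case class Contains(left: Expr, right: Expr) extends Expr

  sealed trait Filter
  case class StringStartsWith(attribute: String, value: String) extends Filter
  case class StringEndsWith(attribute: String, value: String) extends Filter
  case class StringContains(attribute: String, value: String) extends Filter

  def translate(e: Expr): Option[Filter] = e match {
    case StartsWith(Attr(a), Lit(v)) => Some(StringStartsWith(a, v))
    case EndsWith(Attr(a), Lit(v))   => Some(StringEndsWith(a, v))
    case Contains(Attr(a), Lit(v))   => Some(StringContains(a, v))
    case _                           => None // not convertible; evaluated after the scan
  }

  def main(args: Array[String]): Unit = {
    println(translate(StartsWith(Attr("name"), Lit("Al")))) // Some(StringStartsWith(name,Al))
    println(translate(Contains(Lit("x"), Lit("y"))))        // None
  }
}
```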
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index eb46b46ca5bf4..2e861b84b7133 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import scala.language.existentials +import scala.util.matching.Regex import scala.language.implicitConversions import org.apache.spark.Logging @@ -155,7 +156,19 @@ private[sql] class DDLParser( protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} - protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k,v) } + override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( + s"identifier matching regex ${regex}", { + case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str + case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str + } + ) + + protected lazy val optionName: Parser[String] = "[_a-zA-Z][a-zA-Z0-9]*".r ^^ { + case name => name + } + + protected lazy val pair: Parser[(String, String)] = + optionName ~ stringLit ^^ { case k ~ v => (k,v) } protected lazy val column: Parser[StructField] = ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => @@ -204,7 +217,7 @@ private[sql] object ResolvedDataSource { provider: String, options: Map[String, String]): ResolvedDataSource = { val clazz: Class[_] = lookupDataSource(provider) - def className = clazz.getCanonicalName + def className: String = clazz.getCanonicalName val relation = userSpecifiedSchema match { case Some(schema: StructType) => clazz.newInstance() match { case dataSource: SchemaRelationProvider => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 1e4505e36d2f0..791046e0079d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -17,16 +17,85 @@ package org.apache.spark.sql.sources +/** + * A filter predicate for data sources. + */ abstract class Filter +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * equal to `value`. + */ case class EqualTo(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * greater than `value`. + */ case class GreaterThan(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * greater than or equal to `value`. + */ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * less than `value`. + */ case class LessThan(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * less than or equal to `value`. + */ case class LessThanOrEqual(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to one of the values in the array. + */ case class In(attribute: String, values: Array[Any]) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to null. 
+ */ case class IsNull(attribute: String) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a non-null value. + */ case class IsNotNull(attribute: String) extends Filter + +/** + * A filter that evaluates to `true` iff both `left` and `right` evaluate to `true`. + */ case class And(left: Filter, right: Filter) extends Filter + +/** + * A filter that evaluates to `true` iff at least one of `left` or `right` evaluates to `true`. + */ case class Or(left: Filter, right: Filter) extends Filter + +/** + * A filter that evaluates to `true` iff `child` evaluates to `false`. + */ case class Not(child: Filter) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to + * a string that starts with `value`. + */ +case class StringStartsWith(attribute: String, value: String) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to + * a string that ends with `value`. + */ +case class StringEndsWith(attribute: String, value: String) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to + * a string that contains the string `value`. + */ +case class StringContains(attribute: String, value: String) extends Filter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index a046a48c1733d..ca53dcdb92c52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -126,6 +126,16 @@ abstract class BaseRelation { * could lead to execution plans that are suboptimal (i.e. broadcasting a very large table). */ def sizeInBytes: Long = sqlContext.conf.defaultSizeInBytes + + /** + * Whether it needs to convert the objects in Row to the internal representation, for example: + * java.lang.String -> UTF8String + * java.lang.Decimal -> Decimal + * + * Note: The internal representation is not stable across releases and thus data sources outside + * of Spark SQL should leave this as true. + */ + def needConversion: Boolean = true } /** @@ -152,6 +162,9 @@ trait PrunedScan { * A BaseRelation that can eliminate unneeded columns and filter using selected * predicates before producing an RDD containing all matching tuples as Row objects. * + * The actual filter should be the conjunction of all `filters`, + * i.e. they should be ANDed together. + * * The pushed down filters are currently purely an optimization as they will all be evaluated * again. This means it is safe to use them with methods that produce false positives such * as filtering partitions based on a bloom filter. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala index 5a78001117d1b..6ed68d179edc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala @@ -37,7 +37,7 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] { // We are inserting into an InsertableRelation. case i @ InsertIntoTable( - l @ LogicalRelation(r: InsertableRelation), partition, child, overwrite) => { + l @ LogicalRelation(r: InsertableRelation), partition, child, overwrite, ifNotExists) => { // First, make sure the data to be inserted has the same number of fields as the // schema of the relation.
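Going back to the filter and `PrunedFilteredScan` documentation above: the pushed `filters` form a conjunction, and Spark re-evaluates the predicates after the scan, so a source is allowed to return false positives. A hedged sketch of evaluating such a conjunction against a simple `Map`-based row; the row representation and helper are made up, and the `sources` filter classes require spark-sql on the classpath:

```
import org.apache.spark.sql.sources._

// Sketch: evaluating the conjunction of pushed-down filters against a Map row.
// Real sources would prune files or partitions instead; keeping extra rows is
// fine because Spark re-applies the predicates after the scan.
object ConjunctionSketch {
  def satisfies(row: Map[String, Any], filters: Seq[Filter]): Boolean =
    filters.forall {
      case EqualTo(a, v)          => row.get(a).exists(_ == v)
      case GreaterThan(a, v: Int) => row.get(a).exists(_.asInstanceOf[Int] > v)
      case StringStartsWith(a, p) => row.get(a).exists(_.toString.startsWith(p))
      case StringEndsWith(a, s)   => row.get(a).exists(_.toString.endsWith(s))
      case StringContains(a, s)   => row.get(a).exists(_.toString.contains(s))
      case _                      => true // unknown filter: keep the row (false positive is OK)
    }

  def main(args: Array[String]): Unit = {
    val row = Map[String, Any]("name" -> "Alice", "age" -> 30)
    val pushed = Seq(StringStartsWith("name", "Al"), GreaterThan("age", 21))
    println(satisfies(row, pushed)) // true
  }
}
```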
if (l.output.size != child.output.size) { @@ -84,7 +84,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => def apply(plan: LogicalPlan): Unit = { plan.foreach { case i @ logical.InsertIntoTable( - l @ LogicalRelation(t: InsertableRelation), partition, query, overwrite) => + l @ LogicalRelation(t: InsertableRelation), partition, query, overwrite, ifNotExists) => // Right now, we do not support insert into a data source table with partition specs. if (partition.nonEmpty) { failAnalysis(s"Insert into a partition is not allowed because $l is not partitioned.") @@ -102,7 +102,8 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => } case i @ logical.InsertIntoTable( - l: LogicalRelation, partition, query, overwrite) if !l.isInstanceOf[InsertableRelation] => + l: LogicalRelation, partition, query, overwrite, ifNotExists) + if !l.isInstanceOf[InsertableRelation] => // The relation in l is not an InsertableRelation. failAnalysis(s"$l does not allow insertion.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index c11d0ae5bf1cc..2fdd798b44bb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types._ * @param y y coordinate */ @SQLUserDefinedType(udt = classOf[ExamplePointUDT]) -private[sql] class ExamplePoint(val x: Double, val y: Double) +private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable /** * User-defined type for [[ExamplePoint]]. diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 1ff2d5a190521..6d0fbe83c2f36 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -20,6 +20,8 @@ import java.io.Serializable; import java.util.Arrays; +import scala.collection.Seq; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -127,6 +129,12 @@ public void testCreateDataFrameFromJavaBeans() { schema.apply("b")); Row first = df.select("a", "b").first(); Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); - Assert.assertArrayEquals(bean.getB(), first.getAs(1)); + // Now Java lists and maps are converetd to Scala Seq's and Map's. Once we get a Seq below, + // verify that it has the expected length, and contains expected elements. 
+ Seq result = first.getAs(1); + Assert.assertEquals(bean.getB().length, result.length()); + for (int i = 0; i < result.length(); i++) { + Assert.assertEquals(bean.getB()[i], result.apply(i)); + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index c240f2be955ca..01e3b8671071e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -22,6 +22,7 @@ import scala.language.{implicitConversions, postfixOps} import org.scalatest.concurrent.Eventually._ +import org.apache.spark.Accumulators import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ import org.apache.spark.sql.test.TestSQLContext._ @@ -92,7 +93,8 @@ class CachedTableSuite extends QueryTest { test("too big for memory") { val data = "*" * 10000 - sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF().registerTempTable("bigData") + sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() + .registerTempTable("bigData") table("bigData").persist(StorageLevel.MEMORY_AND_DISK) assert(table("bigData").count() === 200000L) table("bigData").unpersist(blocking = true) @@ -296,4 +298,21 @@ class CachedTableSuite extends QueryTest { sql("Clear CACHE") assert(cacheManager.isEmpty) } + + test("Clear accumulators when uncacheTable to prevent memory leaking") { + val accsSize = Accumulators.originals.size + + sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") + sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") + cacheTable("t1") + cacheTable("t2") + sql("SELECT * FROM t1").count() + sql("SELECT * FROM t2").count() + sql("SELECT * FROM t1").count() + sql("SELECT * FROM t2").count() + uncacheTable("t1") + uncacheTable("t2") + + assert(accsSize >= Accumulators.originals.size) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala new file mode 100644 index 0000000000000..41b4f02e6a294 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import scala.collection.JavaConversions._ + +import org.apache.spark.sql.test.TestSQLContext.implicits._ + + +class DataFrameNaFunctionsSuite extends QueryTest { + + def createDF(): DataFrame = { + Seq[(String, java.lang.Integer, java.lang.Double)]( + ("Bob", 16, 176.5), + ("Alice", null, 164.3), + ("David", 60, null), + ("Amy", null, null), + (null, null, null)).toDF("name", "age", "height") + } + + test("drop") { + val input = createDF() + val rows = input.collect() + + checkAnswer( + input.na.drop("name" :: Nil), + rows(0) :: rows(1) :: rows(2) :: rows(3) :: Nil) + + checkAnswer( + input.na.drop("age" :: Nil), + rows(0) :: rows(2) :: Nil) + + checkAnswer( + input.na.drop("age" :: "height" :: Nil), + rows(0) :: Nil) + + checkAnswer( + input.na.drop(), + rows(0)) + + // dropna on an a dataframe with no column should return an empty data frame. + val empty = input.sqlContext.emptyDataFrame.select() + assert(empty.na.drop().count() === 0L) + + // Make sure the columns are properly named. + assert(input.na.drop().columns.toSeq === input.columns.toSeq) + } + + test("drop with how") { + val input = createDF() + val rows = input.collect() + + checkAnswer( + input.na.drop("all"), + rows(0) :: rows(1) :: rows(2) :: rows(3) :: Nil) + + checkAnswer( + input.na.drop("any"), + rows(0) :: Nil) + + checkAnswer( + input.na.drop("any", Seq("age", "height")), + rows(0) :: Nil) + + checkAnswer( + input.na.drop("all", Seq("age", "height")), + rows(0) :: rows(1) :: rows(2) :: Nil) + } + + test("drop with threshold") { + val input = createDF() + val rows = input.collect() + + checkAnswer( + input.na.drop(2, Seq("age", "height")), + rows(0) :: Nil) + + checkAnswer( + input.na.drop(3, Seq("name", "age", "height")), + rows(0)) + + // Make sure the columns are properly named. + assert(input.na.drop(2, Seq("age", "height")).columns.toSeq === input.columns.toSeq) + } + + test("fill") { + val input = createDF() + + val fillNumeric = input.na.fill(50.6) + checkAnswer( + fillNumeric, + Row("Bob", 16, 176.5) :: + Row("Alice", 50, 164.3) :: + Row("David", 60, 50.6) :: + Row("Amy", 50, 50.6) :: + Row(null, 50, 50.6) :: Nil) + + // Make sure the columns are properly named. 
+ assert(fillNumeric.columns.toSeq === input.columns.toSeq) + + // string + checkAnswer( + input.na.fill("unknown").select("name"), + Row("Bob") :: Row("Alice") :: Row("David") :: Row("Amy") :: Row("unknown") :: Nil) + assert(input.na.fill("unknown").columns.toSeq === input.columns.toSeq) + + // fill double with subset columns + checkAnswer( + input.na.fill(50.6, "age" :: Nil), + Row("Bob", 16, 176.5) :: + Row("Alice", 50, 164.3) :: + Row("David", 60, null) :: + Row("Amy", 50, null) :: + Row(null, 50, null) :: Nil) + + // fill string with subset columns + checkAnswer( + Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil), + Row("test", null)) + } + + test("fill with map") { + val df = Seq[(String, String, java.lang.Long, java.lang.Double)]( + (null, null, null, null)).toDF("a", "b", "c", "d") + checkAnswer( + df.na.fill(Map( + "a" -> "test", + "c" -> 1, + "d" -> 2.2 + )), + Row("test", null, 1, 2.2)) + + // Test Java version + checkAnswer( + df.na.fill(mapAsJavaMap(Map( + "a" -> "test", + "c" -> 1, + "d" -> 2.2 + ))), + Row("test", null, 1, 2.2)) + } + + test("replace") { + val input = createDF() + + // Replace two numeric columns: age and height + val out = input.na.replace(Seq("age", "height"), Map( + 16 -> 61, + 60 -> 6, + 164.3 -> 461.3 // Alice is really tall + )) + + checkAnswer( + out, + Row("Bob", 61, 176.5) :: + Row("Alice", null, 461.3) :: + Row("David", 6, null) :: + Row("Amy", null, null) :: + Row(null, null, null) :: Nil) + + // Replace only the age column + val out1 = input.na.replace("age", Map( + 16 -> 61, + 60 -> 6, + 164.3 -> 461.3 // Alice is really tall + )) + + checkAnswer( + out1, + Row("Bob", 61, 176.5) :: + Row("Alice", null, 164.3) :: + Row("David", 6, null) :: + Row("Amy", null, null) :: + Row(null, null, null) :: Nil) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index fbc4065a9666c..b26e22f6229fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -21,7 +21,7 @@ import scala.language.postfixOps import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext} import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.test.TestSQLContext.sql @@ -60,6 +60,14 @@ class DataFrameSuite extends QueryTest { assert($"test".toString === "test") } + test("rename nested groupby") { + val df = Seq((1,(1,1))).toDF() + + checkAnswer( + df.groupBy("_1").agg(col("_1"), sum("_2._1")).toDF("key", "total"), + Row(1, 1) :: Nil) + } + test("invalid plan toString, debug mode") { val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") @@ -84,6 +92,11 @@ class DataFrameSuite extends QueryTest { testData.collect().toSeq) } + test("empty data frame") { + assert(TestSQLContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) + assert(TestSQLContext.emptyDataFrame.count() === 0) + } + test("head and take") { assert(testData.take(2) === testData.collect().take(2)) assert(testData.head(2) === testData.collect().take(2)) @@ -113,6 +126,10 @@ class DataFrameSuite extends QueryTest { checkAnswer( df.as('x).join(df.as('y), $"x.str" === 
$"y.str").groupBy("x.str").count(), Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) + + checkAnswer( + df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").count(), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) } test("explode") { @@ -312,8 +329,9 @@ class DataFrameSuite extends QueryTest { checkAnswer( decimalData.agg(avg('a cast DecimalType(10, 2))), Row(new java.math.BigDecimal(2.0))) + // non-partial checkAnswer( - decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial + decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) } @@ -422,6 +440,15 @@ class DataFrameSuite extends QueryTest { ) } + test("call udf in SQLContext") { + val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + val sqlctx = df.sqlContext + sqlctx.udf.register("simpleUdf", (v: Int) => v * v) + checkAnswer( + df.select($"id", callUdf("simpleUdf", $"value")), + Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil) + } + test("withColumn") { val df = testData.toDF().withColumn("newCol", col("key") + 1) checkAnswer( @@ -497,4 +524,11 @@ class DataFrameSuite extends QueryTest { testData.select($"*").show() testData.select($"*").show(1000) } + + test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { + val rowRDD = TestSQLContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) + val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) + val df = TestSQLContext.createDataFrame(rowRDD, schema) + df.rdd.collect() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index e4dee87849fd4..037d392c1f929 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -51,6 +51,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j case j: BroadcastLeftSemiJoinHash => j + case j: SortMergeJoin => j } assert(operators.size === 1) @@ -62,6 +63,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("join operator selection") { cacheManager.clearCache() + val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), @@ -91,17 +93,41 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + try { + conf.setConf("spark.sql.planner.sortMergeJoin", "true") + Seq( + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } finally { + conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + } } test("broadcasted hash join operator selection") { cacheManager.clearCache() sql("CACHE TABLE testData") + val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled Seq( 
("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin]) + ("SELECT * FROM testData join testData2 ON key = a where key = 2", + classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + try { + conf.setConf("spark.sql.planner.sortMergeJoin", "true") + Seq( + ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key = 2", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key = 2", + classOf[BroadcastHashJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } finally { + conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + } sql("UNCACHE TABLE testData") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 9b4dd6c620fec..59f9508444f25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -67,7 +67,7 @@ class QueryTest extends PlanTest { checkAnswer(df, Seq(expectedAnswer)) } - def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { + def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext) { test(sqlString) { checkAnswer(sqlContext.sql(sqlString), expectedAnswer) } @@ -104,9 +104,12 @@ object QueryTest { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. 
val converted: Seq[Row] = answer.map { s => Row.fromSeq(s.toSeq.map { case d: java.math.BigDecimal => BigDecimal(d) + case b: Array[Byte] => b.toSeq case o => o }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 36465cc2fa11a..bf6cf1321a056 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -30,7 +30,7 @@ class RowSuite extends FunSuite { test("create row") { val expected = new GenericMutableRow(4) expected.update(0, 2147483647) - expected.update(1, "this is a string") + expected.setString(1, "this is a string") expected.update(2, false) expected.update(3, null) val actual1 = Row(2147483647, "this is a string", false, null) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a3c0076e16d6c..d739e550f3e56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,17 +17,14 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.TestSQLContext import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types._ - import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.{udf => _, _} - +import org.apache.spark.sql.types._ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { // Make sure the tables are loaded. @@ -102,11 +99,105 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT ABS(2.5)"), Row(2.5)) } - + test("aggregation with codegen") { val originalValue = conf.codegenEnabled setConf(SQLConf.CODEGEN_ENABLED, "true") - sql("SELECT key FROM testData GROUP BY key").collect() + // Prepare a table that we can group some rows. + table("testData") + .unionAll(table("testData")) + .unionAll(table("testData")) + .registerTempTable("testData3x") + + def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { + val df = sql(sqlText) + // First, check if we have GeneratedAggregate. + var hasGeneratedAgg = false + df.queryExecution.executedPlan.foreach { + case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true + case _ => + } + if (!hasGeneratedAgg) { + fail( + s""" + |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan. + |${df.queryExecution.simpleString} + """.stripMargin) + } + // Then, check results. + checkAnswer(df, expectedResults) + } + + // Just to group rows. 
+ testCodeGen( + "SELECT key FROM testData3x GROUP BY key", + (1 to 100).map(Row(_))) + // COUNT + testCodeGen( + "SELECT key, count(value) FROM testData3x GROUP BY key", + (1 to 100).map(i => Row(i, 3))) + testCodeGen( + "SELECT count(key) FROM testData3x", + Row(300) :: Nil) + // COUNT DISTINCT ON int + testCodeGen( + "SELECT value, count(distinct key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 1))) + testCodeGen( + "SELECT count(distinct key) FROM testData3x", + Row(100) :: Nil) + // SUM + testCodeGen( + "SELECT value, sum(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 3 * i))) + testCodeGen( + "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", + Row(5050 * 3, 5050 * 3.0) :: Nil) + // AVERAGE + testCodeGen( + "SELECT value, avg(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT avg(key) FROM testData3x", + Row(50.5) :: Nil) + // MAX + testCodeGen( + "SELECT value, max(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT max(key) FROM testData3x", + Row(100) :: Nil) + // MIN + testCodeGen( + "SELECT value, min(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT min(key) FROM testData3x", + Row(1) :: Nil) + // Some combinations. + testCodeGen( + """ + |SELECT + | value, + | sum(key), + | max(key), + | min(key), + | avg(key), + | count(key), + | count(distinct key) + |FROM testData3x + |GROUP BY value + """.stripMargin, + (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1))) + testCodeGen( + "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x", + Row(100, 1, 50.5, 300, 100) :: Nil) + // Aggregate with Code generation handling all null values + testCodeGen( + "SELECT sum('a'), avg('a'), count(null) FROM testData", + Row(0, null, 0) :: Nil) + + dropTempTable("testData3x") setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) } @@ -182,7 +273,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.002"))) checkAnswer(sql( - "SELECT time FROM timestamps WHERE time IN ('1969-12-31 16:00:00.001','1969-12-31 16:00:00.002')"), + """ + |SELECT time FROM timestamps + |WHERE time IN ('1969-12-31 16:00:00.001','1969-12-31 16:00:00.002') + """.stripMargin), Seq(Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001")), Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.002")))) @@ -248,7 +342,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row("1")) } - def sortTest() = { + def sortTest(): Unit = { checkAnswer( sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"), Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) @@ -318,6 +412,26 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { mapData.collect().take(1).map(Row.fromTuple).toSeq) } + test("CTE feature") { + checkAnswer( + sql("with q1 as (select * from testData limit 10) select * from q1"), + testData.take(10).toSeq) + + checkAnswer( + sql(""" + |with q1 as (select * from testData where key= '5'), + |q2 as (select * from testData where key = '4') + |select * from q1 union all select * from q2""".stripMargin), + Row(5, "5") :: Row(4, "4") :: Nil) + + } + + test("Allow only a single WITH clause per query") { + intercept[RuntimeException] { + sql("with q1 as (select * from testData) with q2 as (select * from q1) select * from q2") + } + } + test("date row") { 
checkAnswer(sql( """select cast("2015-01-28" as date) from testData limit 1"""), @@ -327,7 +441,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("from follow multiple brackets") { checkAnswer(sql( - "select key from ((select * from testData limit 1) union all (select * from testData limit 1)) x limit 1"), + """ + |select key from ((select * from testData limit 1) + | union all (select * from testData limit 1)) x limit 1 + """.stripMargin), Row(1) ) @@ -337,7 +454,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { ) checkAnswer(sql( - "select key from (select * from testData limit 1 union all select * from testData limit 1) x limit 1"), + """ + |select key from + | (select * from testData limit 1 union all select * from testData limit 1) x + | limit 1 + """.stripMargin), Row(1) ) } @@ -384,7 +505,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Seq(Row(1, 0), Row(2, 1))) checkAnswer( - sql("SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"), + sql( + """ + |SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3 + """.stripMargin), Row(2, 1, 2, 2, 1)) } @@ -997,9 +1121,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-3483 Special chars in column names") { - val data = sparkContext.parallelize(Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) + val data = sparkContext.parallelize( + Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) jsonRDD(data).registerTempTable("records") - sql("SELECT `key?number1` FROM records") + sql("SELECT `key?number1`, `key.number2` FROM records") } test("SPARK-3814 Support Bitwise & operator") { @@ -1082,12 +1207,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-6145: ORDER BY test for nested fields") { + jsonRDD(sparkContext.makeRDD("""{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) + .registerTempTable("nestedOrder") + + checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1)) + checkAnswer(sql("SELECT a.b FROM nestedOrder ORDER BY a.b"), Row(1)) + checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.a.a"), Row(1)) + checkAnswer(sql("SELECT a.a.a FROM nestedOrder ORDER BY a.a.a"), Row(1)) + checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY c[0].d"), Row(1)) + checkAnswer(sql("SELECT c[0].d FROM nestedOrder ORDER BY c[0].d"), Row(1)) + } + + test("SPARK-6145: special cases") { jsonRDD(sparkContext.makeRDD( - """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)).registerTempTable("nestedOrder") - // These should be successfully analyzed - sql("SELECT 1 FROM nestedOrder ORDER BY a.b").queryExecution.analyzed - sql("SELECT a.b FROM nestedOrder ORDER BY a.b").queryExecution.analyzed - sql("SELECT 1 FROM nestedOrder ORDER BY a.a.a").queryExecution.analyzed - sql("SELECT 1 FROM nestedOrder ORDER BY c[0].d").queryExecution.analyzed + """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t") + checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1)) + checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1)) + } + + test("SPARK-6898: complete support for special chars in column names") { + jsonRDD(sparkContext.makeRDD( + """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) + .registerTempTable("t") + + checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 17e923ca48502..3fa00fd9d0ccb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -80,7 +80,7 @@ class ScalaReflectionRelationSuite extends FunSuite { test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, - new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)) + new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)) val rdd = sparkContext.parallelize(data :: Nil) rdd.toDF().registerTempTable("reflectData") @@ -103,7 +103,8 @@ class ScalaReflectionRelationSuite extends FunSuite { val rdd = sparkContext.parallelize(data :: Nil) rdd.toDF().registerTempTable("reflectOptionalData") - assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) + assert(sql("SELECT * FROM reflectOptionalData").collect().head === + Row.fromSeq(Seq.fill(7)(null))) } // Equality is broken for Arrays, so we test that separately. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index fe618e0e8e767..2672e20deadc5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -23,13 +23,16 @@ import org.apache.spark.util.Utils import scala.beans.{BeanInfo, BeanProperty} +import com.clearspring.analytics.stream.cardinality.HyperLogLog + import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql} import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ - +import org.apache.spark.util.collection.OpenHashSet @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { @@ -63,7 +66,7 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { } } - override def userClass = classOf[MyDenseVector] + override def userClass: Class[MyDenseVector] = classOf[MyDenseVector] private[spark] override def asNullable: MyDenseVectorUDT = this } @@ -119,4 +122,23 @@ class UserDefinedTypeSuite extends QueryTest { df.limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) } + + test("HyperLogLogUDT") { + val hyperLogLogUDT = HyperLogLogUDT + val hyperLogLog = new HyperLogLog(0.4) + (1 to 10).foreach(i => hyperLogLog.offer(Row(i))) + + val actual = hyperLogLogUDT.deserialize(hyperLogLogUDT.serialize(hyperLogLog)) + assert(actual.cardinality() === hyperLogLog.cardinality()) + assert(java.util.Arrays.equals(actual.getBytes, hyperLogLog.getBytes)) + } + + test("OpenHashSetUDT") { + val openHashSetUDT = new OpenHashSetUDT(IntegerType) + val set = new OpenHashSet[Int] + (1 to 10).foreach(i => set.add(i)) + + val actual = openHashSetUDT.deserialize(openHashSetUDT.serialize(set)) + assert(actual.iterator.toSet === set.iterator.toSet) + } } diff 
--git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 5f08834f73c6b..b48bed1871c50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -20,9 +20,12 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer import java.sql.Timestamp +import com.esotericsoftware.kryo.{Serializer, Kryo} +import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.spark.serializer.KryoRegistrator import org.scalatest.FunSuite -import org.apache.spark.Logging +import org.apache.spark.{SparkConf, Logging} import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -65,7 +68,7 @@ class ColumnTypeSuite extends FunSuite with Logging { checkActualSize(FLOAT, Float.MaxValue, 4) checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8) checkActualSize(BOOLEAN, true, 1) - checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length) + checkActualSize(STRING, UTF8String("hello"), 4 + "hello".getBytes("utf-8").length) checkActualSize(DATE, 0, 4) checkActualSize(TIMESTAMP, new Timestamp(0L), 12) @@ -73,7 +76,7 @@ class ColumnTypeSuite extends FunSuite with Logging { checkActualSize(BINARY, binary, 4 + 4) val generic = Map(1 -> "a") - checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 11) + checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8) } testNativeColumnType[BooleanType.type]( @@ -108,8 +111,8 @@ class ColumnTypeSuite extends FunSuite with Logging { testNativeColumnType[StringType.type]( STRING, - (buffer: ByteBuffer, string: String) => { - val bytes = string.getBytes("utf-8") + (buffer: ByteBuffer, string: UTF8String) => { + val bytes = string.getBytes buffer.putInt(bytes.length) buffer.put(bytes) }, @@ -117,7 +120,7 @@ class ColumnTypeSuite extends FunSuite with Logging { val length = buffer.getInt() val bytes = new Array[Byte](length) buffer.get(bytes) - new String(bytes, "utf-8") + UTF8String(bytes) }) testColumnType[BinaryType.type, Array[Byte]]( @@ -158,6 +161,41 @@ class ColumnTypeSuite extends FunSuite with Logging { } } + test("CUSTOM") { + val conf = new SparkConf() + conf.set("spark.kryo.registrator", "org.apache.spark.sql.columnar.Registrator") + val serializer = new SparkSqlSerializer(conf).newInstance() + + val buffer = ByteBuffer.allocate(512) + val obj = CustomClass(Int.MaxValue,Long.MaxValue) + val serializedObj = serializer.serialize(obj).array() + + GENERIC.append(serializer.serialize(obj).array(), buffer) + buffer.rewind() + + val length = buffer.getInt + assert(length === serializedObj.length) + assert(13 == length) // id (1) + int (4) + long (8) + + val genericSerializedObj = SparkSqlSerializer.serialize(obj) + assert(length != genericSerializedObj.length) + assert(length < genericSerializedObj.length) + + assertResult(obj, "Custom deserialized object didn't equal the original object") { + val bytes = new Array[Byte](length) + buffer.get(bytes, 0, length) + serializer.deserialize(ByteBuffer.wrap(bytes)) + } + + buffer.rewind() + buffer.putInt(serializedObj.length).put(serializedObj) + + assertResult(obj, "Custom deserialized object didn't equal the original object") { + buffer.rewind() + serializer.deserialize(ByteBuffer.wrap(GENERIC.extract(buffer))) + } + } + def 
testNativeColumnType[T <: NativeType]( columnType: NativeColumnType[T], putter: (ByteBuffer, T#JvmType) => Unit, @@ -229,3 +267,23 @@ class ColumnTypeSuite extends FunSuite with Logging { } } } + +private[columnar] final case class CustomClass(a: Int, b: Long) + +private[columnar] object CustomerSerializer extends Serializer[CustomClass] { + override def write(kryo: Kryo, output: Output, t: CustomClass) { + output.writeInt(t.a) + output.writeLong(t.b) + } + override def read(kryo: Kryo, input: Input, aClass: Class[CustomClass]): CustomClass = { + val a = input.readInt() + val b = input.readLong() + CustomClass(a,b) + } +} + +private[columnar] final class Registrator extends KryoRegistrator { + override def registerClasses(kryo: Kryo) { + kryo.register(classOf[CustomClass], CustomerSerializer) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index c7a40845db16c..f76314b9dab5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -24,10 +24,10 @@ import scala.util.Random import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.types.{Decimal, DataType, NativeType} +import org.apache.spark.sql.types.{UTF8String, DataType, Decimal, NativeType} object ColumnarTestUtils { - def makeNullRow(length: Int) = { + def makeNullRow(length: Int): GenericMutableRow = { val row = new GenericMutableRow(length) (0 until length).foreach(row.setNullAt) row @@ -48,7 +48,7 @@ object ColumnarTestUtils { case FLOAT => Random.nextFloat() case DOUBLE => Random.nextDouble() case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) - case STRING => Random.nextString(Random.nextInt(32)) + case STRING => UTF8String(Random.nextString(Random.nextInt(32))) case BOOLEAN => Random.nextBoolean() case BINARY => randomBytes(Random.nextInt(32)) case DATE => Random.nextInt() @@ -93,7 +93,7 @@ object ColumnarTestUtils { def makeUniqueValuesAndSingleValueRows[T <: NativeType]( columnType: NativeColumnType[T], - count: Int) = { + count: Int): (Seq[T#JvmType], Seq[GenericMutableRow]) = { val values = makeUniqueRandomValues(columnType, count) val rows = values.map { value => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 27dfabca90217..56591d9dba29e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.columnar +import java.sql.{Date, Timestamp} + import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.sql.types.{DecimalType, Decimal} +import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, TestData} import org.apache.spark.storage.StorageLevel.MEMORY_ONLY @@ -42,7 +44,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { .toDF().registerTempTable("sizeTst") cacheTable("sizeTst") assert( - table("sizeTst").queryExecution.logical.statistics.sizeInBytes > + 
table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > conf.autoBroadcastJoinThreshold) } @@ -132,4 +134,59 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM test_fixed_decimal"), (1 to 10).map(i => Row(Decimal(i, 15, 10).toJavaBigDecimal))) } + + test("test different data types") { + // Create the schema. + val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + val dataTypes = + Seq(StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct) + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, true) + } + val allColumns = fields.map(_.name).mkString(",") + val schema = StructType(fields) + + // Create a RDD for the schema + val rdd = + sparkContext.parallelize((1 to 100), 10).map { i => + Row( + s"str${i}: test cache.", + s"binary${i}: test cache.".getBytes("UTF-8"), + null, + i % 2 == 0, + i.toByte, + i.toShort, + i, + Long.MaxValue - i.toLong, + (i + 0.25).toFloat, + (i + 0.75), + BigDecimal(Long.MaxValue.toString + ".12345"), + new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"), + new Date(i), + new Timestamp(i), + (1 to i).toSeq, + (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, + Row((i - 0.25).toFloat, (1 to i).toSeq)) + } + createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") + // Cache the table. + sql("cache table InMemoryCache_different_data_types") + // Make sure the table is indeed cached. + val tableScan = table("InMemoryCache_different_data_types").queryExecution.executedPlan + assert( + isCached("InMemoryCache_different_data_types"), + "InMemoryCache_different_data_types should be cached.") + // Issue a query and check the results. 
+ checkAnswer( + sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"), + table("InMemoryCache_different_data_types").collect()) + dropTempTable("InMemoryCache_different_data_types") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index bb305355276bf..a0702144f942c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -31,7 +31,8 @@ class TestNullableColumnAccessor[T <: DataType, JvmType]( with NullableColumnAccessor object TestNullableColumnAccessor { - def apply[T <: DataType, JvmType](buffer: ByteBuffer, columnType: ColumnType[T, JvmType]) = { + def apply[T <: DataType, JvmType](buffer: ByteBuffer, columnType: ColumnType[T, JvmType]) + : TestNullableColumnAccessor[T, JvmType] = { // Skips the column type ID buffer.getInt() new TestNullableColumnAccessor(buffer, columnType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index 75a47498683f4..3a5605d2335d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -27,7 +27,8 @@ class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T with NullableColumnBuilder object TestNullableColumnBuilder { - def apply[T <: DataType, JvmType](columnType: ColumnType[T, JvmType], initialSize: Int = 0) = { + def apply[T <: DataType, JvmType](columnType: ColumnType[T, JvmType], initialSize: Int = 0) + : TestNullableColumnBuilder[T, JvmType] = { val builder = new TestNullableColumnBuilder(columnType) builder.initialize(initialSize) builder diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index e57bb06e7263b..2a0b701cad7fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -39,6 +39,8 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be // Enable in-memory partition pruning setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") + // Enable in-memory table scan accumulators + setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") } override protected def afterAll(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala index 0b18b4119268f..fc8ff3b41d0e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala @@ -35,7 +35,7 @@ object TestCompressibleColumnBuilder { def apply[T <: NativeType]( columnStats: ColumnStats, columnType: NativeColumnType[T], - scheme: CompressionScheme) = { + scheme: CompressionScheme): TestCompressibleColumnBuilder[T] = { val builder = new 
TestCompressibleColumnBuilder(columnStats, columnType, Seq(scheme)) builder.initialize(0, "", useCompression = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 4e9472c60249e..358d8cf06e463 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -30,4 +30,4 @@ class DebuggingSuite extends FunSuite { test("DataFrame.typeCheck()") { testData.typeCheck() } -} \ No newline at end of file +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 592ed4b23b7d3..3596b183d4328 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -45,10 +45,12 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { conn = DriverManager.getConnection(url, properties) conn.prepareStatement("create schema test").executeUpdate() - conn.prepareStatement("create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() + conn.prepareStatement( + "create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() conn.prepareStatement("insert into test.people values ('fred', 1)").executeUpdate() conn.prepareStatement("insert into test.people values ('mary', 2)").executeUpdate() - conn.prepareStatement("insert into test.people values ('joe ''foo'' \"bar\"', 3)").executeUpdate() + conn.prepareStatement( + "insert into test.people values ('joe ''foo'' \"bar\"', 3)").executeUpdate() conn.commit() sql( @@ -132,25 +134,25 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { } test("SELECT *") { - assert(sql("SELECT * FROM foobar").collect().size == 3) + assert(sql("SELECT * FROM foobar").collect().size === 3) } test("SELECT * WHERE (simple predicates)") { - assert(sql("SELECT * FROM foobar WHERE THEID < 1").collect().size == 0) - assert(sql("SELECT * FROM foobar WHERE THEID != 2").collect().size == 2) - assert(sql("SELECT * FROM foobar WHERE THEID = 1").collect().size == 1) - assert(sql("SELECT * FROM foobar WHERE NAME = 'fred'").collect().size == 1) - assert(sql("SELECT * FROM foobar WHERE NAME > 'fred'").collect().size == 2) - assert(sql("SELECT * FROM foobar WHERE NAME != 'fred'").collect().size == 2) + assert(sql("SELECT * FROM foobar WHERE THEID < 1").collect().size === 0) + assert(sql("SELECT * FROM foobar WHERE THEID != 2").collect().size === 2) + assert(sql("SELECT * FROM foobar WHERE THEID = 1").collect().size === 1) + assert(sql("SELECT * FROM foobar WHERE NAME = 'fred'").collect().size === 1) + assert(sql("SELECT * FROM foobar WHERE NAME > 'fred'").collect().size === 2) + assert(sql("SELECT * FROM foobar WHERE NAME != 'fred'").collect().size === 2) } test("SELECT * WHERE (quoted strings)") { - assert(sql("select * from foobar").where('NAME === "joe 'foo' \"bar\"").collect().size == 1) + assert(sql("select * from foobar").where('NAME === "joe 'foo' \"bar\"").collect().size === 1) } test("SELECT first field") { val names = sql("SELECT NAME FROM foobar").collect().map(x => x.getString(0)).sortWith(_ < _) - assert(names.size == 3) + assert(names.size === 3) assert(names(0).equals("fred")) assert(names(1).equals("joe 'foo' \"bar\"")) assert(names(2).equals("mary")) @@ -158,10 +160,10 @@ class JDBCSuite 
extends FunSuite with BeforeAndAfter { test("SELECT second field") { val ids = sql("SELECT THEID FROM foobar").collect().map(x => x.getInt(0)).sortWith(_ < _) - assert(ids.size == 3) - assert(ids(0) == 1) - assert(ids(1) == 2) - assert(ids(2) == 3) + assert(ids.size === 3) + assert(ids(0) === 1) + assert(ids(1) === 2) + assert(ids(2) === 3) } test("SELECT * partitioned") { @@ -169,46 +171,46 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { } test("SELECT WHERE (simple predicates) partitioned") { - assert(sql("SELECT * FROM parts WHERE THEID < 1").collect().size == 0) - assert(sql("SELECT * FROM parts WHERE THEID != 2").collect().size == 2) - assert(sql("SELECT THEID FROM parts WHERE THEID = 1").collect().size == 1) + assert(sql("SELECT * FROM parts WHERE THEID < 1").collect().size === 0) + assert(sql("SELECT * FROM parts WHERE THEID != 2").collect().size === 2) + assert(sql("SELECT THEID FROM parts WHERE THEID = 1").collect().size === 1) } test("SELECT second field partitioned") { val ids = sql("SELECT THEID FROM parts").collect().map(x => x.getInt(0)).sortWith(_ < _) - assert(ids.size == 3) - assert(ids(0) == 1) - assert(ids(1) == 2) - assert(ids(2) == 3) + assert(ids.size === 3) + assert(ids(0) === 1) + assert(ids(1) === 2) + assert(ids(2) === 3) } test("Basic API") { - assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE").collect.size == 3) + assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE").collect().size === 3) } test("Partitioning via JDBCPartitioningInfo API") { assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3) - .collect.size == 3) + .collect.size === 3) } test("Partitioning via list-of-where-clauses API") { val parts = Array[String]("THEID < 2", "THEID >= 2") - assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts).collect.size == 3) + assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts).collect().size === 3) } test("H2 integral types") { val rows = sql("SELECT * FROM inttypes WHERE A IS NOT NULL").collect() - assert(rows.size == 1) - assert(rows(0).getInt(0) == 1) - assert(rows(0).getBoolean(1) == false) - assert(rows(0).getInt(2) == 3) - assert(rows(0).getInt(3) == 4) - assert(rows(0).getLong(4) == 1234567890123L) + assert(rows.size === 1) + assert(rows(0).getInt(0) === 1) + assert(rows(0).getBoolean(1) === false) + assert(rows(0).getInt(2) === 3) + assert(rows(0).getInt(3) === 4) + assert(rows(0).getLong(4) === 1234567890123L) } test("H2 null entries") { val rows = sql("SELECT * FROM inttypes WHERE A IS NULL").collect() - assert(rows.size == 1) + assert(rows.size === 1) assert(rows(0).isNullAt(0)) assert(rows(0).isNullAt(1)) assert(rows(0).isNullAt(2)) @@ -230,27 +232,27 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { val rows = sql("SELECT * FROM timetypes").collect() val cal = new GregorianCalendar(java.util.Locale.ROOT) cal.setTime(rows(0).getAs[java.sql.Timestamp](0)) - assert(cal.get(Calendar.HOUR_OF_DAY) == 12) - assert(cal.get(Calendar.MINUTE) == 34) - assert(cal.get(Calendar.SECOND) == 56) + assert(cal.get(Calendar.HOUR_OF_DAY) === 12) + assert(cal.get(Calendar.MINUTE) === 34) + assert(cal.get(Calendar.SECOND) === 56) cal.setTime(rows(0).getAs[java.sql.Timestamp](1)) - assert(cal.get(Calendar.YEAR) == 1996) - assert(cal.get(Calendar.MONTH) == 0) - assert(cal.get(Calendar.DAY_OF_MONTH) == 1) + assert(cal.get(Calendar.YEAR) === 1996) + assert(cal.get(Calendar.MONTH) === 0) + assert(cal.get(Calendar.DAY_OF_MONTH) === 1) cal.setTime(rows(0).getAs[java.sql.Timestamp](2)) - 
assert(cal.get(Calendar.YEAR) == 2002) - assert(cal.get(Calendar.MONTH) == 1) - assert(cal.get(Calendar.DAY_OF_MONTH) == 20) - assert(cal.get(Calendar.HOUR) == 11) - assert(cal.get(Calendar.MINUTE) == 22) - assert(cal.get(Calendar.SECOND) == 33) - assert(rows(0).getAs[java.sql.Timestamp](2).getNanos == 543543543) + assert(cal.get(Calendar.YEAR) === 2002) + assert(cal.get(Calendar.MONTH) === 1) + assert(cal.get(Calendar.DAY_OF_MONTH) === 20) + assert(cal.get(Calendar.HOUR) === 11) + assert(cal.get(Calendar.MINUTE) === 22) + assert(cal.get(Calendar.SECOND) === 33) + assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543543) } test("H2 floating-point types") { val rows = sql("SELECT * FROM flttypes").collect() - assert(rows(0).getDouble(0) == 1.00000000000000022) // Yes, I meant ==. - assert(rows(0).getDouble(1) == 1.00000011920928955) // Yes, I meant ==. + assert(rows(0).getDouble(0) === 1.00000000000000022) // Yes, I meant ==. + assert(rows(0).getDouble(1) === 1.00000011920928955) // Yes, I meant ==. assert(rows(0).getAs[BigDecimal](2) .equals(new BigDecimal("123456789012345.54321543215432100000"))) } @@ -264,7 +266,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { | user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) val rows = sql("SELECT * FROM hack").collect() - assert(rows(0).getDouble(0) == 1.00000011920928955) // Yes, I meant ==. + assert(rows(0).getDouble(0) === 1.00000011920928955) // Yes, I meant ==. // For some reason, H2 computes this square incorrectly... assert(math.abs(rows(0).getDouble(1) - 1.00000023841859331) < 1e-12) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 706c966ee05f5..fd0e2746dc045 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -380,8 +380,10 @@ class JsonSuite extends QueryTest { sql("select * from jsonTable"), Row("true", 11L, null, 1.1, "13.1", "str1") :: Row("12", null, new java.math.BigDecimal("21474836470.9"), null, null, "true") :: - Row("false", 21474836470L, new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") :: - Row(null, 21474836570L, new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil + Row("false", 21474836470L, + new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") :: + Row(null, 21474836570L, + new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil ) // Number and Boolean conflict: resolve the type as number in this query. 
@@ -404,7 +406,8 @@ class JsonSuite extends QueryTest { // Widening to DecimalType checkAnswer( sql("select num_num_2 + 1.2 from jsonTable where num_num_2 > 1.1"), - Row(new java.math.BigDecimal("21474836472.1")) :: Row(new java.math.BigDecimal("92233720368547758071.2")) :: Nil + Row(new java.math.BigDecimal("21474836472.1")) :: + Row(new java.math.BigDecimal("92233720368547758071.2")) :: Nil ) // Widening to DoubleType @@ -892,8 +895,7 @@ class JsonSuite extends QueryTest { ) } - test("SPARK-4228 DataFrame to JSON") - { + test("SPARK-4228 DataFrame to JSON") { val schema1 = StructType( StructField("f1", IntegerType, false) :: StructField("f2", StringType, false) :: @@ -913,8 +915,10 @@ class JsonSuite extends QueryTest { df1.registerTempTable("applySchema1") val df2 = df1.toDF val result = df2.toJSON.collect() + // scalastyle:off assert(result(0) === "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}") assert(result(3) === "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}") + // scalastyle:on val schema2 = StructType( StructField("f1", StructType( @@ -968,7 +972,8 @@ class JsonSuite extends QueryTest { // Access elements of a BigInteger array (we use DecimalType internally). checkAnswer( - sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from complexTable"), + sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] " + + " from complexTable"), Row(new java.math.BigDecimal("922337203685477580700"), new java.math.BigDecimal("-922337203685477580800"), null) ) @@ -1008,7 +1013,8 @@ class JsonSuite extends QueryTest { // Access elements of an array field of a struct. checkAnswer( - sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from complexTable"), + sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] " + + "from complexTable"), Row(5, null) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 6a2c2a7c4080a..10d0ede4dc0dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -22,7 +22,7 @@ import parquet.filter2.predicate.Operators._ import parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, Predicate, Row} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.test.TestSQLContext @@ -350,4 +350,26 @@ class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with Before override protected def afterAll(): Unit = { sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) } + + test("SPARK-6742: don't push down predicates which reference partition columns") { + import sqlContext.implicits._ + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/part=1" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").saveAsParquetFile(path) + + // If the "part = 1" filter gets pushed down, this query will throw an exception since + // "part" is not a valid column in the actual Parquet file + val df 
= DataFrame(sqlContext, org.apache.spark.sql.parquet.ParquetRelation( + path, + Some(sqlContext.sparkContext.hadoopConfiguration), sqlContext, + Seq(AttributeReference("part", IntegerType, false)()) )) + + checkAnswer( + df.filter("a = 1 or part = 1"), + (1 to 3).map(i => Row(1, i, i.toString))) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 203bc79f153dd..97c0f439acf13 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -218,7 +218,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } test("compression codec") { - def compressionCodecFor(path: String) = { + def compressionCodecFor(path: String): String = { val codecs = ParquetTypesConverter .readMetaData(new Path(path), Some(configuration)) .getBlocks diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index adb3c9391f6c2..b7561ce7298cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -45,11 +45,11 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { assert(inferPartitionColumnValue(raw, defaultPartitionName) === literal) } - check("10", Literal(10, IntegerType)) - check("1000000000000000", Literal(1000000000000000L, LongType)) - check("1.5", Literal(1.5, FloatType)) - check("hello", Literal("hello", StringType)) - check(defaultPartitionName, Literal(null, NullType)) + check("10", Literal.create(10, IntegerType)) + check("1000000000000000", Literal.create(1000000000000000L, LongType)) + check("1.5", Literal.create(1.5, FloatType)) + check("hello", Literal.create("hello", StringType)) + check(defaultPartitionName, Literal.create(null, NullType)) } test("parse partition") { @@ -75,22 +75,22 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { "file://path/a=10", PartitionValues( ArrayBuffer("a"), - ArrayBuffer(Literal(10, IntegerType)))) + ArrayBuffer(Literal.create(10, IntegerType)))) check( "file://path/a=10/b=hello/c=1.5", PartitionValues( ArrayBuffer("a", "b", "c"), ArrayBuffer( - Literal(10, IntegerType), - Literal("hello", StringType), - Literal(1.5, FloatType)))) + Literal.create(10, IntegerType), + Literal.create("hello", StringType), + Literal.create(1.5, FloatType)))) check( "file://path/a=10/b_hello/c=1.5", PartitionValues( ArrayBuffer("c"), - ArrayBuffer(Literal(1.5, FloatType)))) + ArrayBuffer(Literal.create(1.5, FloatType)))) checkThrows[AssertionError]("file://path/=10", "Empty partition column name") checkThrows[AssertionError]("file://path/a=", "Empty partition column value") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index 8462f9bb2d620..c964b6d984557 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -180,10 +180,12 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { val caseClassString = "StructType(List(StructField(c1,IntegerType,false), 
StructField(c2,BinaryType,true)))" + // scalastyle:off val jsonString = """ |{"type":"struct","fields":[{"name":"c1","type":"integer","nullable":false,"metadata":{}},{"name":"c2","type":"binary","nullable":true,"metadata":{}}]} """.stripMargin + // scalastyle:on val fromCaseClassString = ParquetTypesConverter.convertFromString(caseClassString) val fromJson = ParquetTypesConverter.convertFromString(jsonString) @@ -226,22 +228,54 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { StructField("UPPERCase", IntegerType, nullable = true)))) } - // Conflicting field count + // Metastore schema contains additional non-nullable fields. assert(intercept[Throwable] { ParquetRelation2.mergeMetastoreParquetSchema( StructType(Seq( StructField("uppercase", DoubleType, nullable = false), - StructField("lowerCase", BinaryType))), + StructField("lowerCase", BinaryType, nullable = false))), StructType(Seq( StructField("UPPERCase", IntegerType, nullable = true)))) }.getMessage.contains("detected conflicting schemas")) - // Conflicting field names + // Conflicting non-nullable field names intercept[Throwable] { ParquetRelation2.mergeMetastoreParquetSchema( - StructType(Seq(StructField("lower", StringType))), + StructType(Seq(StructField("lower", StringType, nullable = false))), StructType(Seq(StructField("lowerCase", BinaryType)))) } } + + test("merge missing nullable fields from Metastore schema") { + // Standard case: Metastore schema contains additional nullable fields not present + // in the Parquet file schema. + assertResult( + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = true)))) { + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("firstfield", StringType, nullable = true), + StructField("secondfield", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = true))), + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true)))) + } + + // Merge should fail if the Metastore contains any additional fields that are not + // nullable. 
+ assert(intercept[Throwable] { + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("firstfield", StringType, nullable = true), + StructField("secondfield", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = false))), + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true)))) + }.getMessage.contains("detected conflicting schemas")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 54af50c6e10ad..ca25751b9583d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.sources +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.types._ @@ -24,17 +25,17 @@ class DDLScanSource extends RelationProvider { override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { - SimpleDDLScan(parameters("from").toInt, parameters("TO").toInt)(sqlContext) + SimpleDDLScan(parameters("from").toInt, parameters("TO").toInt, parameters("Table"))(sqlContext) } } -case class SimpleDDLScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) +case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlContext: SQLContext) extends BaseRelation with TableScan { - override def schema = + override def schema: StructType = StructType(Seq( StructField("intType", IntegerType, nullable = false, - new MetadataBuilder().putString("comment", "test comment").build()), + new MetadataBuilder().putString("comment", s"test comment $table").build()), StructField("stringType", StringType, nullable = false), StructField("dateType", DateType, nullable = false), StructField("timestampType", TimestampType, nullable = false), @@ -57,8 +58,9 @@ case class SimpleDDLScan(from: Int, to: Int)(@transient val sqlContext: SQLConte )) - override def buildScan() = sqlContext.sparkContext.parallelize(from to to). 
- map(e => Row(s"people$e", e * 2)) + override def buildScan(): RDD[Row] = { + sqlContext.sparkContext.parallelize(from to to).map(e => Row(s"people$e", e * 2)) + } } class DDLTestSuite extends DataSourceTest { @@ -71,7 +73,8 @@ class DDLTestSuite extends DataSourceTest { |USING org.apache.spark.sql.sources.DDLScanSource |OPTIONS ( | From '1', - | To '10' + | To '10', + | Table 'test1' |) """.stripMargin) } @@ -79,7 +82,7 @@ class DDLTestSuite extends DataSourceTest { sqlTest( "describe ddlPeople", Seq( - Row("intType", "int", "test comment"), + Row("intType", "int", "test comment test1"), Row("stringType", "string", ""), Row("dateType", "date", ""), Row("timestampType", "timestamp", ""), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 91c6367371f15..33c67355967dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -32,6 +32,10 @@ abstract class DataSourceTest extends QueryTest with BeforeAndAfter { override val extendedResolutionRules = PreInsertCastAndRename :: Nil + + override val extendedCheckRules = Seq( + sources.PreWriteCheck(catalog) + ) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index ffeccf0b69394..cb5e5147ff189 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.sources import scala.language.existentials +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.types._ @@ -35,20 +36,25 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL extends BaseRelation with PrunedFilteredScan { - override def schema = + override def schema: StructType = StructType( StructField("a", IntegerType, nullable = false) :: - StructField("b", IntegerType, nullable = false) :: Nil) + StructField("b", IntegerType, nullable = false) :: + StructField("c", StringType, nullable = false) :: Nil) - override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = { + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { val rowBuilders = requiredColumns.map { case "a" => (i: Int) => Seq(i) case "b" => (i: Int) => Seq(i * 2) + case "c" => (i: Int) => + val c = (i - 1 + 'a').toChar.toString + Seq(c * 5 + c.toUpperCase() * 5) } FiltersPushed.list = filters - def translateFilter(filter: Filter): Int => Boolean = filter match { + // Predicate test on integer column + def translateFilterOnA(filter: Filter): Int => Boolean = filter match { case EqualTo("a", v) => (a: Int) => a == v case LessThan("a", v: Int) => (a: Int) => a < v case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v @@ -57,13 +63,27 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL case In("a", values) => (a: Int) => values.map(_.asInstanceOf[Int]).toSet.contains(a) case IsNull("a") => (a: Int) => false // Int can't be null case IsNotNull("a") => (a: Int) => true - case Not(pred) => (a: Int) => !translateFilter(pred)(a) - case And(left, right) => (a: Int) => translateFilter(left)(a) && translateFilter(right)(a) - case Or(left, right) => (a: Int) => 
translateFilter(left)(a) || translateFilter(right)(a) + case Not(pred) => (a: Int) => !translateFilterOnA(pred)(a) + case And(left, right) => (a: Int) => + translateFilterOnA(left)(a) && translateFilterOnA(right)(a) + case Or(left, right) => (a: Int) => + translateFilterOnA(left)(a) || translateFilterOnA(right)(a) case _ => (a: Int) => true } - def eval(a: Int) = !filters.map(translateFilter(_)(a)).contains(false) + // Predicate test on string column + def translateFilterOnC(filter: Filter): String => Boolean = filter match { + case StringStartsWith("c", v) => _.startsWith(v) + case StringEndsWith("c", v) => _.endsWith(v) + case StringContains("c", v) => _.contains(v) + case _ => (c: String) => true + } + + def eval(a: Int) = { + val c = (a - 1 + 'a').toChar.toString * 5 + (a - 1 + 'a').toChar.toString.toUpperCase() * 5 + !filters.map(translateFilterOnA(_)(a)).contains(false) && + !filters.map(translateFilterOnC(_)(c)).contains(false) + } sqlContext.sparkContext.parallelize(from to to).filter(eval).map(i => Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty))) @@ -93,7 +113,8 @@ class FilteredScanSuite extends DataSourceTest { sqlTest( "SELECT * FROM oneToTenFiltered", - (1 to 10).map(i => Row(i, i * 2)).toSeq) + (1 to 10).map(i => Row(i, i * 2, (i - 1 + 'a').toChar.toString * 5 + + (i - 1 + 'a').toChar.toString.toUpperCase() * 5)).toSeq) sqlTest( "SELECT a, b FROM oneToTenFiltered", @@ -128,41 +149,53 @@ class FilteredScanSuite extends DataSourceTest { (2 to 10 by 2).map(i => Row(i, i)).toSeq) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE a = 1", - Seq(1).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE a = 1", + Seq(1).map(i => Row(i, i * 2))) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", - Seq(1,3,5).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE a IN (1,3,5)", + Seq(1,3,5).map(i => Row(i, i * 2))) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE A = 1", - Seq(1).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE A = 1", + Seq(1).map(i => Row(i, i * 2))) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE b = 2", - Seq(1).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE b = 2", + Seq(1).map(i => Row(i, i * 2))) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE a IS NULL", + "SELECT a, b FROM oneToTenFiltered WHERE a IS NULL", Seq.empty[Row]) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE a IS NOT NULL", + "SELECT a, b FROM oneToTenFiltered WHERE a IS NOT NULL", (1 to 10).map(i => Row(i, i * 2)).toSeq) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE a < 5 AND a > 1", + "SELECT a, b FROM oneToTenFiltered WHERE a < 5 AND a > 1", (2 to 4).map(i => Row(i, i * 2)).toSeq) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", - Seq(1, 2, 9, 10).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE a < 3 OR a > 8", + Seq(1, 2, 9, 10).map(i => Row(i, i * 2))) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)", + "SELECT a, b FROM oneToTenFiltered WHERE NOT (a < 6)", (6 to 10).map(i => Row(i, i * 2)).toSeq) + sqlTest( + "SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'", + Seq(Row(3, 3 * 2, "c" * 5 + "C" * 5))) + + sqlTest( + "SELECT a, b, c FROM oneToTenFiltered WHERE c like '%D'", + Seq(Row(4, 4 * 2, "d" * 5 + "D" * 5))) + + sqlTest( + "SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", + Seq(Row(5, 5 * 2, "e" * 5 + "E" * 5))) + testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1) 
testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1) testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1) @@ -193,6 +226,15 @@ class FilteredScanSuite extends DataSourceTest { testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", 4) testPushDown("SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)", 5) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'", 1) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'C%'", 0) + + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%D'", 1) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%d'", 0) + + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0) + def testPushDown(sqlString: String, expectedCount: Int): Unit = { test(s"PushDown Returns $expectedCount: $sqlString") { val queryExecution = sql(sqlString).queryExecution diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index 08fb5380dc026..6a1ddf2f8e98b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.sources import scala.language.existentials +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.types._ @@ -34,12 +35,12 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo extends BaseRelation with PrunedScan { - override def schema = + override def schema: StructType = StructType( StructField("a", IntegerType, nullable = false) :: StructField("b", IntegerType, nullable = false) :: Nil) - override def buildScan(requiredColumns: Array[String]) = { + override def buildScan(requiredColumns: Array[String]): RDD[Row] = { val rowBuilders = requiredColumns.map { case "a" => (i: Int) => Seq(i) case "b" => (i: Int) => Seq(i * 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 43bc8eb2d11a7..cb287ba85c1f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -114,4 +114,4 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { message.contains("Append mode is not supported"), "We should complain that 'Append mode is not supported' for JSON source.") } -} \ No newline at end of file +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 7928600ac2fb5..3b47b8adf313b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.sources import java.sql.{Timestamp, Date} +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.types._ @@ -35,10 +36,10 @@ class SimpleScanSource extends RelationProvider { case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) extends BaseRelation with TableScan { - override def schema = + override def schema: StructType = StructType(StructField("i", IntegerType, nullable = 
false) :: Nil) - override def buildScan() = sqlContext.sparkContext.parallelize(from to to).map(Row(_)) + override def buildScan(): RDD[Row] = sqlContext.sparkContext.parallelize(from to to).map(Row(_)) } class AllDataTypesScanSource extends SchemaRelationProvider { @@ -57,9 +58,9 @@ case class AllDataTypesScan( extends BaseRelation with TableScan { - override def schema = userSpecifiedSchema + override def schema: StructType = userSpecifiedSchema - override def buildScan() = { + override def buildScan(): RDD[Row] = { sqlContext.sparkContext.parallelize(from to to).map { i => Row( s"str_$i", @@ -73,7 +74,7 @@ case class AllDataTypesScan( i.toDouble, new java.math.BigDecimal(i), new java.math.BigDecimal(i), - new Date((i + 1) * 8640000), + new Date(1970, 1, 1), new Timestamp(20000 + i), s"varchar_$i", Seq(i, i + 1), @@ -81,7 +82,7 @@ case class AllDataTypesScan( Map(i -> i.toString), Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), Row(i, i.toString), - Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000))))) + Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(1970, 1, i + 1))))) } } } @@ -102,7 +103,7 @@ class TableScanSuite extends DataSourceTest { i.toDouble, new java.math.BigDecimal(i), new java.math.BigDecimal(i), - new Date((i + 1) * 8640000), + new Date(1970, 1, 1), new Timestamp(20000 + i), s"varchar_$i", Seq(i, i + 1), @@ -110,7 +111,7 @@ class TableScanSuite extends DataSourceTest { Map(i -> i.toString), Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), Row(i, i.toString), - Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000))))) + Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(1970, 1, i + 1))))) }.toSeq before { @@ -265,7 +266,7 @@ class TableScanSuite extends DataSourceTest { sqlTest( "SELECT structFieldComplex.Value.`value_(2)` FROM tableWithSchema", - (1 to 10).map(i => Row(Seq(new Date((i + 2) * 8640000)))).toSeq) + (1 to 10).map(i => Row(Seq(new Date(1970, 1, i + 1)))).toSeq) test("Caching") { // Cached Query Execution diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index a96b1ffc26966..f38c796241df1 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -44,7 +44,6 @@ com.google.guava guava - runtime ${hive.group} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 6272cdedb3e48..85281c6d73a3b 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -33,7 +33,7 @@ import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.processors.{SetProcessor, CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor, CommandProcessorFactory} import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.shims.ShimLoader import org.apache.thrift.transport.TSocket @@ -145,6 +145,9 @@ private[hive] object SparkSQLCLIDriver { case e: UnsupportedEncodingException => System.exit(3) } + // use the specified database if specified + 
cli.processSelectDatabase(sessionState); + // Execute -i init files (always in silent mode) cli.processInitFiles(sessionState) @@ -264,7 +267,8 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { val proc: CommandProcessor = HiveShim.getCommandProcessor(Array(tokens(0)), hconf) if (proc != null) { - if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor]) { + if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor] || + proc.isInstanceOf[AddResourceProcessor]) { val driver = new SparkSQLDriver driver.init() diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 158c225159720..97b46a01ba5b4 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -22,6 +22,7 @@ import scala.collection.JavaConversions._ import org.apache.spark.scheduler.StatsReportListener import org.apache.spark.sql.hive.{HiveShim, HiveContext} import org.apache.spark.{Logging, SparkConf, SparkContext} +import org.apache.spark.util.Utils /** A singleton object for the master program. The slaves should not access this. */ private[hive] object SparkSQLEnv extends Logging { @@ -37,7 +38,7 @@ private[hive] object SparkSQLEnv extends Logging { val maybeKryoReferenceTracking = sparkConf.getOption("spark.kryo.referenceTracking") sparkConf - .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}") + .setAppName(s"SparkSQL::${Utils.localHostName()}") .set("spark.sql.hive.version", HiveShim.version) .set( "spark.serializer", diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 75738fa22b572..b070fa8eaa469 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -1,13 +1,12 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -26,22 +25,31 @@ import scala.concurrent.{Await, Promise} import scala.sys.process.{Process, ProcessLogger} import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.apache.spark.Logging import org.apache.spark.util.Utils -class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { +class CliSuite extends FunSuite with BeforeAndAfter with Logging { + val warehousePath = Utils.createTempDir() + val metastorePath = Utils.createTempDir() + + before { + warehousePath.delete() + metastorePath.delete() + } + + after { + warehousePath.delete() + metastorePath.delete() + } + def runCliWithin( timeout: FiniteDuration, extraArgs: Seq[String] = Seq.empty)( - queriesAndExpectedAnswers: (String, String)*) { + queriesAndExpectedAnswers: (String, String)*): Unit = { val (queries, expectedAnswers) = queriesAndExpectedAnswers.unzip - val warehousePath = Utils.createTempDir() - warehousePath.delete() - val metastorePath = Utils.createTempDir() - metastorePath.delete() val cliScript = "../../bin/spark-sql".split("/").mkString(File.separator) val command = { @@ -96,8 +104,6 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { """.stripMargin, cause) throw cause } finally { - warehousePath.delete() - metastorePath.delete() process.destroy() } } @@ -125,4 +131,24 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { test("Single command with -e") { runCliWithin(1.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK") } + + test("Single command with --database") { + runCliWithin(1.minute)( + "CREATE DATABASE hive_test_db;" + -> "OK", + "USE hive_test_db;" + -> "OK", + "CREATE TABLE hive_test(key INT, val STRING);" + -> "OK", + "SHOW TABLES;" + -> "Time taken: " + ) + + runCliWithin(1.minute, Seq("--database", "hive_test_db", "-e", "SHOW TABLES;"))( + "" + -> "OK", + "" + -> "hive_test" + ) + } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index bf20acecb1f32..4cf95e7bdfb2b 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.thriftserver import java.io.File +import java.net.URL import java.sql.{Date, DriverManager, Statement} import scala.collection.mutable.ArrayBuffer @@ -41,7 +42,7 @@ import org.apache.spark.sql.hive.HiveShim import org.apache.spark.util.Utils object TestData { - def getTestDataFilePath(name: String) = { + def getTestDataFilePath(name: String): URL = { Thread.currentThread().getContextClassLoader.getResource(s"data/files/$name") } @@ -50,7 +51,7 @@ object TestData { } class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { - override def mode = ServerMode.binary + override def mode: ServerMode.Value = ServerMode.binary private def withCLIServiceClient(f: ThriftCLIServiceClient => Unit): Unit = { // Transport creation logics below mimics 
HiveConnection.createBinaryTransport @@ -337,7 +338,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { - override def mode = ServerMode.http + override def mode: ServerMode.Value = ServerMode.http test("JDBC query execution") { withJdbcStatement { statement => diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 2ae9d018e1b1b..81ee48ef4152f 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -532,6 +532,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "inputddl7", "inputddl8", "insert1", + "insert1_overwrite_partitions", "insert2_overwrite_partitions", "insert_compressed", "join0", diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala new file mode 100644 index 0000000000000..65d070bd3cbde --- /dev/null +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.hive.test.TestHive + +/** + * Runs the test cases that are included in the hive distribution with sort merge join enabled. 
+ */ +class SortMergeCompatibilitySuite extends HiveCompatibilitySuite { + override def beforeAll() { + super.beforeAll() + TestHive.setConf(SQLConf.SORTMERGE_JOIN, "true") + } + + override def afterAll() { + TestHive.setConf(SQLConf.SORTMERGE_JOIN, "false") + super.afterAll() + } + + override def whiteList = Seq( + "auto_join0", + "auto_join1", + "auto_join10", + "auto_join11", + "auto_join12", + "auto_join13", + "auto_join14", + "auto_join14_hadoop20", + "auto_join15", + "auto_join17", + "auto_join18", + "auto_join19", + "auto_join2", + "auto_join20", + "auto_join21", + "auto_join22", + "auto_join23", + "auto_join24", + "auto_join25", + "auto_join26", + "auto_join27", + "auto_join28", + "auto_join3", + "auto_join30", + "auto_join31", + "auto_join32", + "auto_join4", + "auto_join5", + "auto_join6", + "auto_join7", + "auto_join8", + "auto_join9", + "auto_join_filters", + "auto_join_nulls", + "auto_join_reordering_values", + "auto_smb_mapjoin_14", + "auto_sortmerge_join_1", + "auto_sortmerge_join_10", + "auto_sortmerge_join_11", + "auto_sortmerge_join_12", + "auto_sortmerge_join_13", + "auto_sortmerge_join_14", + "auto_sortmerge_join_15", + "auto_sortmerge_join_16", + "auto_sortmerge_join_2", + "auto_sortmerge_join_3", + "auto_sortmerge_join_4", + "auto_sortmerge_join_5", + "auto_sortmerge_join_6", + "auto_sortmerge_join_7", + "auto_sortmerge_join_8", + "auto_sortmerge_join_9", + "correlationoptimizer1", + "correlationoptimizer10", + "correlationoptimizer11", + "correlationoptimizer13", + "correlationoptimizer14", + "correlationoptimizer15", + "correlationoptimizer2", + "correlationoptimizer3", + "correlationoptimizer4", + "correlationoptimizer6", + "correlationoptimizer7", + "correlationoptimizer8", + "correlationoptimizer9", + "join0", + "join1", + "join10", + "join11", + "join12", + "join13", + "join14", + "join14_hadoop20", + "join15", + "join16", + "join17", + "join18", + "join19", + "join2", + "join20", + "join21", + "join22", + "join23", + "join24", + "join25", + "join26", + "join27", + "join28", + "join29", + "join3", + "join30", + "join31", + "join32", + "join32_lessSize", + "join33", + "join34", + "join35", + "join36", + "join37", + "join38", + "join39", + "join4", + "join40", + "join41", + "join5", + "join6", + "join7", + "join8", + "join9", + "join_1to1", + "join_array", + "join_casesensitive", + "join_empty", + "join_filters", + "join_hive_626", + "join_map_ppr", + "join_nulls", + "join_nullsafe", + "join_rc", + "join_reorder2", + "join_reorder3", + "join_reorder4", + "join_star" + ) +} diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index a9816f6c38cd2..04440076a26a3 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -89,6 +89,20 @@ junit test + + org.apache.spark + spark-sql_${scala.binary.version} + test-jar + ${project.version} + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + test-jar + ${project.version} + test + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index c06c2e396bbc1..7c6a7df2bd01e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -57,6 +57,15 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected[sql] def convertMetastoreParquet: Boolean = getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true" + /** + * When true, also tries to merge possibly different but compatible Parquet schemas in 
different + * Parquet data files. + * + * This configuration is only effective when "spark.sql.hive.convertMetastoreParquet" is true. + */ + protected[sql] def convertMetastoreParquetWithSchemaMerging: Boolean = + getConf("spark.sql.hive.convertMetastoreParquet.mergeSchema", "false") == "true" + /** * When true, a table created by a Hive CTAS statement (no USING clause) will be * converted to a data source table, using the data source set by spark.sql.sources.default. @@ -172,12 +181,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val tableFullName = relation.hiveQlTable.getDbName + "." + relation.hiveQlTable.getTableName - catalog.client.alterTable(tableFullName, new Table(hiveTTable)) + catalog.synchronized { + catalog.client.alterTable(tableFullName, new Table(hiveTTable)) + } } case otherRelation => - throw new NotImplementedError( - s"Analyze has only implemented for Hive tables, " + - s"but $tableName is a ${otherRelation.nodeName}") + throw new UnsupportedOperationException( + s"Analyze only works for Hive tables, but $tableName is a ${otherRelation.nodeName}") } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 4afa2e71d77cc..74ae984f34866 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -34,7 +34,7 @@ import scala.collection.JavaConversions._ * 1. The Underlying data type in catalyst and in Hive * In catalyst: * Primitive => - * java.lang.String + * UTF8String * int / scala.Int * boolean / scala.Boolean * float / scala.Float @@ -239,9 +239,10 @@ private[hive] trait HiveInspectors { */ def unwrap(data: Any, oi: ObjectInspector): Any = oi match { case coi: ConstantObjectInspector if coi.getWritableConstantValue == null => null - case poi: WritableConstantStringObjectInspector => poi.getWritableConstantValue.toString + case poi: WritableConstantStringObjectInspector => + UTF8String(poi.getWritableConstantValue.toString) case poi: WritableConstantHiveVarcharObjectInspector => - poi.getWritableConstantValue.getHiveVarchar.getValue + UTF8String(poi.getWritableConstantValue.getHiveVarchar.getValue) case poi: WritableConstantHiveDecimalObjectInspector => HiveShim.toCatalystDecimal( PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector, @@ -284,10 +285,13 @@ private[hive] trait HiveInspectors { case pi: PrimitiveObjectInspector => pi match { // We think HiveVarchar is also a String case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() => - hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue - case hvoi: HiveVarcharObjectInspector => hvoi.getPrimitiveJavaObject(data).getValue + UTF8String(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue) + case hvoi: HiveVarcharObjectInspector => + UTF8String(hvoi.getPrimitiveJavaObject(data).getValue) case x: StringObjectInspector if x.preferWritable() => - x.getPrimitiveWritableObject(data).toString + UTF8String(x.getPrimitiveWritableObject(data).toString) + case x: StringObjectInspector => + UTF8String(x.getPrimitiveJavaObject(data)) case x: IntObjectInspector if x.preferWritable() => x.get(data) case x: BooleanObjectInspector if x.preferWritable() => x.get(data) case x: FloatObjectInspector if x.preferWritable() => x.get(data) @@ -340,7 +344,9 @@ private[hive] trait HiveInspectors { */ protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match { case _: 
JavaHiveVarcharObjectInspector => - (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size) + (o: Any) => + val s = o.asInstanceOf[UTF8String].toString + new HiveVarchar(s, s.size) case _: JavaHiveDecimalObjectInspector => (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toJavaBigDecimal) @@ -409,7 +415,7 @@ private[hive] trait HiveInspectors { case x: PrimitiveObjectInspector => x match { // TODO we don't support the HiveVarcharObjectInspector yet. case _: StringObjectInspector if x.preferWritable() => HiveShim.getStringWritable(a) - case _: StringObjectInspector => a.asInstanceOf[java.lang.String] + case _: StringObjectInspector => a.asInstanceOf[UTF8String].toString() case _: IntObjectInspector if x.preferWritable() => HiveShim.getIntWritable(a) case _: IntObjectInspector => a.asInstanceOf[java.lang.Integer] case _: BooleanObjectInspector if x.preferWritable() => HiveShim.getBooleanWritable(a) @@ -593,7 +599,7 @@ private[hive] trait HiveInspectors { case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].") // ideally, we don't test the foldable here(but in optimizer), however, some of the // Hive UDF / UDAF requires its argument to be constant objectinspector, we do it eagerly. - case _ if expr.foldable => toInspector(Literal(expr.eval(), expr.dataType)) + case _ if expr.foldable => toInspector(Literal.create(expr.eval(), expr.dataType)) // For those non constant expression, map to object inspector according to its data type case _ => toInspector(expr.dataType) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index d1a99555e90c6..f1c0bd92aa23d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import java.io.IOException import java.util.{List => JList} +import com.google.common.base.Objects import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} import org.apache.hadoop.hive.metastore.api.{FieldSchema, Partition => TPartition, Table => TTable} import org.apache.hadoop.hive.metastore.{TableType, Warehouse} @@ -32,7 +33,7 @@ import org.apache.hadoop.util.ReflectionUtils import org.apache.spark.Logging import org.apache.spark.sql.{SaveMode, AnalysisException, SQLContext} -import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Catalog, OverrideCatalog} +import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NoSuchTableException, Catalog, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical @@ -66,10 +67,11 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val cacheLoader = new CacheLoader[QualifiedTableName, LogicalPlan]() { override def load(in: QualifiedTableName): LogicalPlan = { logDebug(s"Creating new cached data source for $in") - val table = synchronized { + val table = HiveMetastoreCatalog.this.synchronized { client.getTable(in.database, in.name) } - val userSpecifiedSchema = + + def schemaStringFromParts: Option[String] = { Option(table.getProperty("spark.sql.sources.schema.numParts")).map { numParts => val parts = (0 until numParts.toInt).map { index => val part = table.getProperty(s"spark.sql.sources.schema.part.${index}") 
@@ -81,10 +83,19 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with part } - // Stick all parts back to a single schema string in the JSON representation - // and convert it back to a StructType. - DataType.fromJson(parts.mkString).asInstanceOf[StructType] + // Stick all parts back to a single schema string. + parts.mkString } + } + + // Originally, we used spark.sql.sources.schema to store the schema of a data source table. + // After SPARK-6024, we removed this flag. + // Although we are not using spark.sql.sources.schema any more, we still need to support reading it. + val schemaString = + Option(table.getProperty("spark.sql.sources.schema")).orElse(schemaStringFromParts) + + val userSpecifiedSchema = + schemaString.map(s => DataType.fromJson(s).asInstanceOf[StructType]) // It does not appear that the ql client for the metastore has a way to enumerate all the // SerDe properties directly... @@ -105,7 +116,15 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } override def refreshTable(databaseName: String, tableName: String): Unit = { - cachedDataSourceTables.refresh(QualifiedTableName(databaseName, tableName).toLowerCase) + // refreshTable does not eagerly reload the cache. It just invalidates the cache. + // Next time we use the table, it will be populated in the cache. + // Since we also cache ParquetRelations converted from Hive Parquet tables and + // adding converted ParquetRelations into the cache is not defined in the load function + // of the cache (instead, we add the cache entry in convertToParquetRelation), + // it is better here to invalidate the cache to avoid confusing warning logs from the + // cache loader (e.g. cannot find data source provider, which is only defined for + // data source tables). + invalidateTable(databaseName, tableName) } def invalidateTable(databaseName: String, tableName: String): Unit = { @@ -172,12 +191,16 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with def lookupRelation( tableIdentifier: Seq[String], - alias: Option[String]): LogicalPlan = synchronized { + alias: Option[String]): LogicalPlan = { val tableIdent = processTableIdentifier(tableIdentifier) val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse( hive.sessionState.getCurrentDatabase) val tblName = tableIdent.last - val table = try client.getTable(databaseName, tblName) catch { + val table = try { + synchronized { + client.getTable(databaseName, tblName) + } + } catch { case te: org.apache.hadoop.hive.ql.metadata.InvalidTableException => throw new NoSuchTableException } @@ -199,7 +222,9 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } else { val partitions: Seq[Partition] = if (table.isPartitioned) { - HiveShim.getAllPartitionsOf(client, table).toSeq + synchronized { + HiveShim.getAllPartitionsOf(client, table).toSeq + } } else { Nil } @@ -211,11 +236,50 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with private def convertToParquetRelation(metastoreRelation: MetastoreRelation): LogicalRelation = { val metastoreSchema = StructType.fromAttributes(metastoreRelation.output) + val mergeSchema = hive.convertMetastoreParquetWithSchemaMerging // NOTE: Instead of passing Metastore schema directly to `ParquetRelation2`, we have to // serialize the Metastore schema to JSON and pass it as a data source option because of the // evil case insensitivity issue, which is reconciled within `ParquetRelation2`. 
- if (metastoreRelation.hiveQlTable.isPartitioned) { + val parquetOptions = Map( + ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json, + ParquetRelation2.MERGE_SCHEMA -> mergeSchema.toString) + val tableIdentifier = + QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) + + def getCached( + tableIdentifier: QualifiedTableName, + pathsInMetastore: Seq[String], + schemaInMetastore: StructType, + partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { + cachedDataSourceTables.getIfPresent(tableIdentifier) match { + case null => None // Cache miss + case logical@LogicalRelation(parquetRelation: ParquetRelation2) => + // If we have the same paths, same schema, and same partition spec, + // we will use the cached Parquet Relation. + val useCached = + parquetRelation.paths.toSet == pathsInMetastore.toSet && + logical.schema.sameType(metastoreSchema) && + parquetRelation.maybePartitionSpec == partitionSpecInMetastore + + if (useCached) { + Some(logical) + } else { + // If the cached relation is not updated, we invalidate it right away. + cachedDataSourceTables.invalidate(tableIdentifier) + None + } + case other => + logWarning( + s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} should be stored " + + s"as Parquet. However, we are getting a ${other} from the metastore cache. " + + s"This cached entry will be invalidated.") + cachedDataSourceTables.invalidate(tableIdentifier) + None + } + } + + val result = if (metastoreRelation.hiveQlTable.isPartitioned) { val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys) val partitionColumnDataTypes = partitionSchema.map(_.dataType) val partitions = metastoreRelation.hiveQlPartitions.map { p => @@ -227,19 +291,31 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } val partitionSpec = PartitionSpec(partitionSchema, partitions) val paths = partitions.map(_.path) - LogicalRelation( - ParquetRelation2( - paths, - Map(ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json), - None, - Some(partitionSpec))(hive)) + + val cached = getCached(tableIdentifier, paths, metastoreSchema, Some(partitionSpec)) + val parquetRelation = cached.getOrElse { + val created = + LogicalRelation(ParquetRelation2(paths, parquetOptions, None, Some(partitionSpec))(hive)) + cachedDataSourceTables.put(tableIdentifier, created) + created + } + + parquetRelation } else { val paths = Seq(metastoreRelation.hiveQlTable.getDataLocation.toString) - LogicalRelation( - ParquetRelation2( - paths, - Map(ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json))(hive)) + + val cached = getCached(tableIdentifier, paths, metastoreSchema, None) + val parquetRelation = cached.getOrElse { + val created = + LogicalRelation(ParquetRelation2(paths, parquetOptions)(hive)) + cachedDataSourceTables.put(tableIdentifier, created) + created + } + + parquetRelation } + + result.newInstance() } override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = synchronized { @@ -451,7 +527,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with // Collects all `MetastoreRelation`s which should be replaced val toBeReplaced = plan.collect { // Write path - case InsertIntoTable(relation: MetastoreRelation, _, _, _) + case InsertIntoTable(relation: MetastoreRelation, _, _, _, _) // Inserting into partitioned table is not supported in Parquet data source (yet). 
if !relation.hiveQlTable.isPartitioned && hive.convertMetastoreParquet && @@ -459,10 +535,10 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) - (relation -> relation.output, parquetRelation, attributedRewrites) + (relation, parquetRelation, attributedRewrites) // Write path - case InsertIntoHiveTable(relation: MetastoreRelation, _, _, _) + case InsertIntoHiveTable(relation: MetastoreRelation, _, _, _, _) // Inserting into partitioned table is not supported in Parquet data source (yet). if !relation.hiveQlTable.isPartitioned && hive.convertMetastoreParquet && @@ -470,7 +546,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) - (relation -> relation.output, parquetRelation, attributedRewrites) + (relation, parquetRelation, attributedRewrites) // Read path case p @ PhysicalOperation(_, _, relation: MetastoreRelation) @@ -479,34 +555,29 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) - (relation -> relation.output, parquetRelation, attributedRewrites) + (relation, parquetRelation, attributedRewrites) } - // Quick fix for SPARK-6450: Notice that we're using both the MetastoreRelation instances and - // their output attributes as the key of the map. This is because MetastoreRelation.equals - // doesn't take output attributes into account, thus multiple MetastoreRelation instances - // pointing to the same table get collapsed into a single entry in the map. A proper fix for - // this should be overriding equals & hashCode in MetastoreRelation. val relationMap = toBeReplaced.map(r => (r._1, r._2)).toMap val attributedRewrites = AttributeMap(toBeReplaced.map(_._3).fold(Nil)(_ ++: _)) // Replaces all `MetastoreRelation`s with corresponding `ParquetRelation2`s, and fixes // attribute IDs referenced in other nodes. 
plan.transformUp { - case r: MetastoreRelation if relationMap.contains(r -> r.output) => - val parquetRelation = relationMap(r -> r.output) + case r: MetastoreRelation if relationMap.contains(r) => + val parquetRelation = relationMap(r) val alias = r.alias.getOrElse(r.tableName) Subquery(alias, parquetRelation) - case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite) - if relationMap.contains(r -> r.output) => - val parquetRelation = relationMap(r -> r.output) - InsertIntoTable(parquetRelation, partition, child, overwrite) + case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) + if relationMap.contains(r) => + val parquetRelation = relationMap(r) + InsertIntoTable(parquetRelation, partition, child, overwrite, ifNotExists) - case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite) - if relationMap.contains(r -> r.output) => - val parquetRelation = relationMap(r -> r.output) - InsertIntoTable(parquetRelation, partition, child, overwrite) + case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) + if relationMap.contains(r) => + val parquetRelation = relationMap(r) + InsertIntoTable(parquetRelation, partition, child, overwrite, ifNotExists) case other => other.transformExpressions { case a: Attribute if a.resolved => attributedRewrites.getOrElse(a, a) @@ -627,7 +698,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with // Wait until children are resolved. case p: LogicalPlan if !p.childrenResolved => p - case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) => + case p @ InsertIntoTable(table: MetastoreRelation, _, child, _, _) => castChildOutput(p, table, child) } @@ -644,7 +715,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with .forall { case (left, right) => left.sameType(right) }) { // If both types ignoring nullability of ArrayType, MapType, StructType are the same, // use InsertIntoHiveTable instead of InsertIntoTable. - InsertIntoHiveTable(p.table, p.partition, p.child, p.overwrite) + InsertIntoHiveTable(p.table, p.partition, p.child, p.overwrite, p.ifNotExists) } else { // Only do the casting when child output data types differ from table output data types. val castedChildOutput = child.output.zip(table.output).map { @@ -682,7 +753,8 @@ private[hive] case class InsertIntoHiveTable( table: LogicalPlan, partition: Map[String, Option[String]], child: LogicalPlan, - overwrite: Boolean) + overwrite: Boolean, + ifNotExists: Boolean) extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil @@ -697,10 +769,23 @@ private[hive] case class MetastoreRelation (databaseName: String, tableName: String, alias: Option[String]) (val table: TTable, val partitions: Seq[TPartition]) (@transient sqlContext: SQLContext) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { self: Product => + override def equals(other: scala.Any): Boolean = other match { + case relation: MetastoreRelation => + databaseName == relation.databaseName && + tableName == relation.tableName && + alias == relation.alias && + output == relation.output + case _ => false + } + + override def hashCode(): Int = { + Objects.hashCode(databaseName, tableName, alias, output) + } + // TODO: Can we use org.apache.hadoop.hive.ql.metadata.Table as the type of table and // use org.apache.hadoop.hive.ql.metadata.Partition as the type of elements of partitions. 
// Right now, using org.apache.hadoop.hive.ql.metadata.Table and @@ -778,6 +863,10 @@ private[hive] case class MetastoreRelation /** An attribute map for determining the ordinal for non-partition columns. */ val columnOrdinals = AttributeMap(attributes.zipWithIndex) + + override def newInstance(): MetastoreRelation = { + MetastoreRelation(databaseName, tableName, alias)(table, partitions)(sqlContext) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index c45c4ad70fae9..fd305eb480e63 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.hive import java.sql.Date +import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo} + import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.hive.conf.HiveConf @@ -111,13 +113,16 @@ private[hive] object HiveQl { "TOK_REVOKE", + "TOK_SHOW_COMPACTIONS", "TOK_SHOW_CREATETABLE", "TOK_SHOW_GRANT", "TOK_SHOW_ROLE_GRANT", + "TOK_SHOW_ROLE_PRINCIPALS", "TOK_SHOW_ROLES", "TOK_SHOW_SET_ROLE", "TOK_SHOW_TABLESTATUS", "TOK_SHOW_TBLPROPERTIES", + "TOK_SHOW_TRANSACTIONS", "TOK_SHOWCOLUMNS", "TOK_SHOWDATABASES", "TOK_SHOWFUNCTIONS", @@ -479,7 +484,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // Just fake explain for any of the native commands. case Token("TOK_EXPLAIN", explainArgs) if noExplainCommands.contains(explainArgs.head.getText) => - ExplainCommand(NoRelation) + ExplainCommand(OneRowRelation) case Token("TOK_EXPLAIN", explainArgs) if "TOK_CREATETABLE" == explainArgs.head.getText => val Some(crtTbl) :: _ :: extended :: Nil = @@ -574,11 +579,23 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_QUERY", queryArgs) if Seq("TOK_FROM", "TOK_INSERT").contains(queryArgs.head.getText) => - val (fromClause: Option[ASTNode], insertClauses) = queryArgs match { - case Token("TOK_FROM", args: Seq[ASTNode]) :: insertClauses => - (Some(args.head), insertClauses) - case Token("TOK_INSERT", _) :: Nil => (None, queryArgs) - } + val (fromClause: Option[ASTNode], insertClauses, cteRelations) = + queryArgs match { + case Token("TOK_FROM", args: Seq[ASTNode]) :: insertClauses => + // check if has CTE + insertClauses.last match { + case Token("TOK_CTE", cteClauses) => + val cteRelations = cteClauses.map(node => { + val relation = nodeToRelation(node).asInstanceOf[Subquery] + (relation.alias, relation) + }).toMap + (Some(args.head), insertClauses.init, Some(cteRelations)) + + case _ => (Some(args.head), insertClauses, None) + } + + case Token("TOK_INSERT", _) :: Nil => (None, queryArgs, None) + } // Return one query for each insert clause. 
val queries = insertClauses.map { case Token("TOK_INSERT", singleInsert) => @@ -622,7 +639,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val relations = fromClause match { case Some(f) => nodeToRelation(f) - case None => NoRelation + case None => OneRowRelation } val withWhere = whereClause.map { whereNode => @@ -659,7 +676,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C AttributeReference("value", StringType)()), true) } - def matchSerDe(clause: Seq[ASTNode]) = clause match { + def matchSerDe(clause: Seq[ASTNode]) + : (Seq[(String, String)], String, Seq[(String, String)]) = clause match { case Token("TOK_SERDEPROPS", propsClause) :: Nil => val rowFormat = propsClause.map { case Token(name, Token(value, Nil) :: Nil) => (name, value) @@ -791,7 +809,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } // If there are multiple INSERTS just UNION them together into on query. - queries.reduceLeft(Union) + val query = queries.reduceLeft(Union) + + // return With plan if there is CTE + cteRelations.map(With(query, _)).getOrElse(query) case Token("TOK_UNION", left :: right :: Nil) => Union(nodeToPlan(left), nodeToPlan(right)) @@ -981,7 +1002,27 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C cleanIdentifier(key.toLowerCase) -> None }.toMap).getOrElse(Map.empty) - InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite) + InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, false) + + case Token(destinationToken(), + Token("TOK_TAB", + tableArgs) :: + Token("TOK_IFNOTEXISTS", + ifNotExists) :: Nil) => + val Some(tableNameParts) :: partitionClause :: Nil = + getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) + + val tableIdent = extractTableIdent(tableNameParts) + + val partitionKeys = partitionClause.map(_.getChildren.map { + // Parse partitions. We also make keys case insensitive. + case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => + cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value)) + case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) => + cleanIdentifier(key.toLowerCase) -> None + }.toMap).getOrElse(Map.empty) + + InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, true) case a: ASTNode => throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") @@ -1060,7 +1101,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token(".", qualifier :: Token(attr, Nil) :: Nil) => nodeToExpr(qualifier) match { case UnresolvedAttribute(qualifierName) => - UnresolvedAttribute(qualifierName + "." 
+ cleanIdentifier(attr)) + UnresolvedAttribute(qualifierName :+ cleanIdentifier(attr)) case other => UnresolvedGetField(other, attr) } @@ -1201,7 +1242,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C CreateArray(children.map(nodeToExpr)) case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) => - Substring(nodeToExpr(string), nodeToExpr(pos), Literal(Integer.MAX_VALUE, IntegerType)) + Substring(nodeToExpr(string), nodeToExpr(pos), Literal.create(Integer.MAX_VALUE, IntegerType)) case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) => Substring(nodeToExpr(string), nodeToExpr(pos), nodeToExpr(length)) case Token("TOK_FUNCTION", Token(COALESCE(), Nil) :: list) => Coalesce(list.map(nodeToExpr)) @@ -1213,9 +1254,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C UnresolvedFunction(name, UnresolvedStar(None) :: Nil) /* Literals */ - case Token("TOK_NULL", Nil) => Literal(null, NullType) - case Token(TRUE(), Nil) => Literal(true, BooleanType) - case Token(FALSE(), Nil) => Literal(false, BooleanType) + case Token("TOK_NULL", Nil) => Literal.create(null, NullType) + case Token(TRUE(), Nil) => Literal.create(true, BooleanType) + case Token(FALSE(), Nil) => Literal.create(false, BooleanType) case Token("TOK_STRINGLITERALSEQUENCE", strings) => Literal(strings.map(s => BaseSemanticAnalyzer.unescapeSQLString(s.getText)).mkString) @@ -1226,21 +1267,21 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C try { if (ast.getText.endsWith("L")) { // Literal bigint. - v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toLong, LongType) + v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toLong, LongType) } else if (ast.getText.endsWith("S")) { // Literal smallint. - v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toShort, ShortType) + v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toShort, ShortType) } else if (ast.getText.endsWith("Y")) { // Literal tinyint. 
- v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toByte, ByteType) + v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toByte, ByteType) } else if (ast.getText.endsWith("BD") || ast.getText.endsWith("D")) { // Literal decimal val strVal = ast.getText.stripSuffix("D").stripSuffix("B") v = Literal(Decimal(strVal)) } else { - v = Literal(ast.getText.toDouble, DoubleType) - v = Literal(ast.getText.toLong, LongType) - v = Literal(ast.getText.toInt, IntegerType) + v = Literal.create(ast.getText.toDouble, DoubleType) + v = Literal.create(ast.getText.toLong, LongType) + v = Literal.create(ast.getText.toInt, IntegerType) } } catch { case nfe: NumberFormatException => // Do nothing @@ -1283,8 +1324,13 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C Explode(attributes, nodeToExpr(child)) case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => + val functionInfo: FunctionInfo = + Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse( + sys.error(s"Couldn't find function $functionName")) + val functionClassName = functionInfo.getFunctionClass.getName + HiveGenericUdtf( - new HiveFunctionWrapper(functionName), + new HiveFunctionWrapper(functionClassName), attributes, children.map(nodeToExpr)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 5f7e897295117..a6f4fbe8aba06 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -17,24 +17,21 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.catalyst.expressions.Row - import scala.collection.JavaConversions._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate +import org.apache.spark.sql.catalyst.expressions.{Row, _} import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.sources.DescribeCommand -import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} -import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand, _} import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, CreateTableUsing} +import org.apache.spark.sql.sources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand} import org.apache.spark.sql.types.StringType @@ -131,7 +128,7 @@ private[hive] trait HiveStrategies { val partitionValues = part.getValues var i = 0 while (i < partitionValues.size()) { - inputData(i) = partitionValues(i) + inputData(i) = CatalystTypeConverters.convertToCatalyst(partitionValues(i)) i += 1 } pruningCondition(inputData) @@ -184,12 +181,14 @@ private[hive] trait HiveStrategies { object DataSinks extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) => + case logical.InsertIntoTable( + table: MetastoreRelation, partition, 
child, overwrite, ifNotExists) => execution.InsertIntoHiveTable( - table, partition, planLater(child), overwrite) :: Nil - case hive.InsertIntoHiveTable(table: MetastoreRelation, partition, child, overwrite) => + table, partition, planLater(child), overwrite, ifNotExists) :: Nil + case hive.InsertIntoHiveTable( + table: MetastoreRelation, partition, child, overwrite, ifNotExists) => execution.InsertIntoHiveTable( - table, partition, planLater(child), overwrite) :: Nil + table, partition, planLater(child), overwrite, ifNotExists) :: Nil case _ => Nil } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 3563472c7ae81..e556c74ffb015 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -35,6 +35,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.DateUtils +import org.apache.spark.util.Utils /** * A trait for subclasses that handle table scans. @@ -76,7 +77,9 @@ class HadoopTableReader( override def makeRDDForTable(hiveTable: HiveTable): RDD[Row] = makeRDDForTable( hiveTable, - relation.tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]], + Class.forName( + relation.tableDesc.getSerdeClassName, true, sc.sessionState.getConf.getClassLoader) + .asInstanceOf[Class[Deserializer]], filterOpt = None) /** @@ -142,7 +145,46 @@ class HadoopTableReader( partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]], filterOpt: Option[PathFilter]): RDD[Row] = { - val hivePartitionRDDs = partitionToDeserializer.map { case (partition, partDeserializer) => + + // SPARK-5068: get FileStatus and do the filtering locally when the path does not exist + def verifyPartitionPath( + partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]]): + Map[HivePartition, Class[_ <: Deserializer]] = { + if (!sc.conf.verifyPartitionPath) { + partitionToDeserializer + } else { + var existPathSet = collection.mutable.Set[String]() + var pathPatternSet = collection.mutable.Set[String]() + partitionToDeserializer.filter { + case (partition, partDeserializer) => + def updateExistPathSetByPathPattern(pathPatternStr: String) { + val pathPattern = new Path(pathPatternStr) + val fs = pathPattern.getFileSystem(sc.hiveconf) + val matches = fs.globStatus(pathPattern) + matches.foreach(fileStatus => existPathSet += fileStatus.getPath.toString) + } + // convert /demo/data/year/month/day to /demo/data/*/*/*/ + def getPathPatternByPath(parNum: Int, tempPath: Path): String = { + var path = tempPath + for (i <- (1 to parNum)) path = path.getParent + val tails = (1 to parNum).map(_ => "*").mkString("/", "/", "/") + path.toString + tails + } + + val partPath = HiveShim.getDataLocationPath(partition) + val partNum = Utilities.getPartitionDesc(partition).getPartSpec.size(); + var pathPatternStr = getPathPatternByPath(partNum, partPath) + if (!pathPatternSet.contains(pathPatternStr)) { + pathPatternSet += pathPatternStr + updateExistPathSetByPathPattern(pathPatternStr) + } + existPathSet.contains(partPath.toString) + } + } + } + + val hivePartitionRDDs = verifyPartitionPath(partitionToDeserializer) + .map { case (partition, partDeserializer) => val partDesc = Utilities.getPartitionDesc(partition) val partPath = HiveShim.getDataLocationPath(partition) val 
inputPathStr = applyFilterIfNeeded(partPath, filterOpt) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index fade9e5852eaa..76a1965f3cb25 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -67,7 +67,7 @@ case class CreateTableAsSelect( new org.apache.hadoop.hive.metastore.api.AlreadyExistsException(s"$database.$tableName") } } else { - hiveContext.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true)).toRdd + hiveContext.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true, false)).toRdd } Seq.empty[Row] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index da53d30354551..89995a91b1a92 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -45,12 +45,13 @@ case class InsertIntoHiveTable( table: MetastoreRelation, partition: Map[String, Option[String]], child: SparkPlan, - overwrite: Boolean) extends UnaryNode with HiveInspectors { + overwrite: Boolean, + ifNotExists: Boolean) extends UnaryNode with HiveInspectors { @transient val sc: HiveContext = sqlContext.asInstanceOf[HiveContext] @transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass @transient private lazy val hiveContext = new Context(sc.hiveconf) - @transient private lazy val db = Hive.get(sc.hiveconf) + @transient private lazy val catalog = sc.catalog private def newSerializer(tableDesc: TableDesc): Serializer = { val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] @@ -72,7 +73,6 @@ case class InsertIntoHiveTable( val outputFileFormatClassName = fileSinkConf.getTableInfo.getOutputFileFormatClassName assert(outputFileFormatClassName != null, "Output format class not set") conf.value.set("mapred.output.format.class", outputFileFormatClassName) - conf.value.setOutputCommitter(classOf[FileOutputCommitter]) FileOutputFormat.setOutputPath( conf.value, @@ -200,38 +200,55 @@ case class InsertIntoHiveTable( orderedPartitionSpec.put(entry.getName,partitionSpec.get(entry.getName).getOrElse("")) } val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) - db.validatePartitionNameCharacters(partVals) + catalog.synchronized { + catalog.client.validatePartitionNameCharacters(partVals) + } // inheritTableSpecs is set to true. It should be set to false for a IMPORT query // which is currently considered as a Hive native command. val inheritTableSpecs = true // TODO: Correctly set isSkewedStoreAsSubdir. 
val isSkewedStoreAsSubdir = false if (numDynamicPartitions > 0) { - db.loadDynamicPartitions( - outputPath, - qualifiedTableName, - orderedPartitionSpec, - overwrite, - numDynamicPartitions, - holdDDLTime, - isSkewedStoreAsSubdir - ) + catalog.synchronized { + catalog.client.loadDynamicPartitions( + outputPath, + qualifiedTableName, + orderedPartitionSpec, + overwrite, + numDynamicPartitions, + holdDDLTime, + isSkewedStoreAsSubdir) + } } else { - db.loadPartition( + // scalastyle:off + // ifNotExists is only valid with static partition, refer to + // https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DML#LanguageManualDML-InsertingdataintoHiveTablesfromqueries + // scalastyle:on + val oldPart = catalog.synchronized { + catalog.client.getPartition( + catalog.client.getTable(qualifiedTableName), partitionSpec, false) + } + if (oldPart == null || !ifNotExists) { + catalog.synchronized { + catalog.client.loadPartition( + outputPath, + qualifiedTableName, + orderedPartitionSpec, + overwrite, + holdDDLTime, + inheritTableSpecs, + isSkewedStoreAsSubdir) + } + } + } + } else { + catalog.synchronized { + catalog.client.loadTable( outputPath, qualifiedTableName, - orderedPartitionSpec, overwrite, - holdDDLTime, - inheritTableSpecs, - isSkewedStoreAsSubdir) + holdDDLTime) } - } else { - db.loadTable( - outputPath, - qualifiedTableName, - overwrite, - holdDDLTime) } // Invalidate the cache. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 8efed7f0299bf..cab0fdd35723a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.hive.execution -import java.io.{BufferedReader, InputStreamReader} -import java.io.{DataInputStream, DataOutputStream, EOFException} +import java.io.{BufferedReader, DataInputStream, DataOutputStream, EOFException, InputStreamReader} import java.util.Properties import scala.collection.JavaConversions._ @@ -28,12 +27,13 @@ import org.apache.hadoop.hive.serde2.AbstractSerDe import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.execution._ -import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} import org.apache.spark.sql.hive.HiveShim._ +import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} +import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils /** @@ -121,14 +121,13 @@ case class ScriptTransformation( if (outputSerde == null) { val prevLine = curLine curLine = reader.readLine() - if (!ioschema.schemaLess) { - new GenericRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) + new GenericRow(CatalystTypeConverters.convertToCatalyst( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))) .asInstanceOf[Array[Any]]) } else { - new GenericRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) + new GenericRow(CatalystTypeConverters.convertToCatalyst( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)) .asInstanceOf[Array[Any]]) } } 
else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 4345ffbf30f77..a40a1e53117cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -22,11 +22,11 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ /** * Analyzes the given table in the current database to generate statistics, which will be @@ -58,12 +58,13 @@ case class DropTable( try { hiveContext.cacheManager.tryUncacheQuery(hiveContext.table(tableName)) } catch { - // This table's metadata is not in + // This table's metadata is not in Hive metastore (e.g. the table does not exist). case _: org.apache.hadoop.hive.ql.metadata.InvalidTableException => + case _: org.apache.spark.sql.catalyst.analysis.NoSuchTableException => // Other Throwables can be caused by users providing wrong parameters in OPTIONS // (e.g. invalid paths). We catch it and log a warning message. // Users should be able to drop such kinds of tables regardless if there is an error. - case e: Throwable => log.warn(s"${e.getMessage}") + case e: Throwable => log.warn(s"${e.getMessage}", e) } hiveContext.invalidateTable(tableName) hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") @@ -75,11 +76,17 @@ case class DropTable( private[hive] case class AddJar(path: String) extends RunnableCommand { + override val output: Seq[Attribute] = { + val schema = StructType( + StructField("result", IntegerType, false) :: Nil) + schema.toAttributes + } + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] hiveContext.runSqlHive(s"ADD JAR $path") hiveContext.sparkContext.addJar(path) - Seq.empty[Row] + Seq(Row(0)) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index ba2bf67aed684..8398da268174d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.hive -import java.io.IOException import java.text.NumberFormat import java.util.Date @@ -118,19 +117,7 @@ private[hive] class SparkHiveWriterContainer( } protected def commit() { - if (committer.needsTaskCommit(taskContext)) { - try { - committer.commitTask(taskContext) - logInfo (taID + ": Committed") - } catch { - case e: IOException => - logError("Error committing the output of task: " + taID.value, e) - committer.abortTask(taskContext) - throw e - } - } else { - logInfo("No need to commit output of task: " + taID.value) - } + SparkHadoopMapRedUtil.commitTask(committer, taskContext, jobID, splitID, attemptID) } private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { @@ -213,7 +200,7 @@ private[spark] class 
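With the `commands.scala` change above, `ADD JAR` is no longer a result-less command: it declares a single `result` column of `IntegerType` and returns `Row(0)`, so callers receive a well-formed result set. A hedged usage sketch, assuming an existing `HiveContext`; the jar path is made up.

```
import org.apache.spark.sql.Row
import org.apache.spark.sql.hive.HiveContext

// Sketch only: hiveContext and the jar path are assumptions for illustration.
def addJarExample(hiveContext: HiveContext): Unit = {
  val result = hiveContext.sql("ADD JAR /tmp/extra-udfs.jar").collect()
  // After this patch the command reports a single Row(0) instead of no rows.
  assert(result.toSeq == Seq(Row(0)))
}
```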
SparkHiveDynamicPartitionWriterContainer( .zip(row.toSeq.takeRight(dynamicPartColNames.length)) .map { case (col, rawVal) => val string = if (rawVal == null) null else String.valueOf(rawVal) - val colString = + val colString = if (string == null || string.isEmpty) { defaultPartName } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index a3497eadd67f6..6570fa1043900 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -262,12 +262,6 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { |WITH SERDEPROPERTIES ('field.delim'='\\t') """.stripMargin.cmd, "INSERT OVERWRITE TABLE serdeins SELECT * FROM src".cmd), - TestTable("sales", - s"""CREATE TABLE IF NOT EXISTS sales (key STRING, value INT) - |ROW FORMAT SERDE '${classOf[RegexSerDe].getCanonicalName}' - |WITH SERDEPROPERTIES ("input.regex" = "([^ ]*)\t([^ ]*)") - """.stripMargin.cmd, - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/sales.txt")}' INTO TABLE sales".cmd), TestTable("episodes", s"""CREATE TABLE episodes (title STRING, air_date STRING, doctor INT) |ROW FORMAT SERDE '${classOf[AvroSerDe].getCanonicalName}' diff --git a/sql/hive/src/test/resources/golden/CTE feature #1-0-eedabbfe6ba8799f7b7782fb47a82768 b/sql/hive/src/test/resources/golden/CTE feature #1-0-eedabbfe6ba8799f7b7782fb47a82768 new file mode 100644 index 0000000000000..f6ba75da254ca --- /dev/null +++ b/sql/hive/src/test/resources/golden/CTE feature #1-0-eedabbfe6ba8799f7b7782fb47a82768 @@ -0,0 +1,3 @@ +5 +5 +5 diff --git a/sql/hive/src/test/resources/golden/CTE feature #2-0-aa03d104251f97e36bc52279cb9931c9 b/sql/hive/src/test/resources/golden/CTE feature #2-0-aa03d104251f97e36bc52279cb9931c9 new file mode 100644 index 0000000000000..ca7b591095e28 --- /dev/null +++ b/sql/hive/src/test/resources/golden/CTE feature #2-0-aa03d104251f97e36bc52279cb9931c9 @@ -0,0 +1,4 @@ +val_4 +val_5 +val_5 +val_5 diff --git a/sql/hive/src/test/resources/golden/CTE feature #3-0-b5d4bf3c0ee92b2fda0ca24f422383f2 b/sql/hive/src/test/resources/golden/CTE feature #3-0-b5d4bf3c0ee92b2fda0ca24f422383f2 new file mode 100644 index 0000000000000..b8626c4cff284 --- /dev/null +++ b/sql/hive/src/test/resources/golden/CTE feature #3-0-b5d4bf3c0ee92b2fda0ca24f422383f2 @@ -0,0 +1 @@ +4 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-0-d5edc0daa94b33915df794df3b710774 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-0-d5edc0daa94b33915df794df3b710774 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-1-9eb9372f4855928fae16f5fa554b3a62 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-1-9eb9372f4855928fae16f5fa554b3a62 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-10-ec2cef3d37146c450c60202a572f5cab b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-10-ec2cef3d37146c450c60202a572f5cab new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-11-8854d6001200fc11529b2e2da755e5a2 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-11-8854d6001200fc11529b2e2da755e5a2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git 
a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-12-71ff68fda0aa7a36cb50d8fab0d70d25 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-12-71ff68fda0aa7a36cb50d8fab0d70d25 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-13-7e4e7d7003fc6ef17bc19c3461ad899 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-13-7e4e7d7003fc6ef17bc19c3461ad899 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-14-ec2cef3d37146c450c60202a572f5cab b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-14-ec2cef3d37146c450c60202a572f5cab new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-15-a3b2e230efde74e970ae8a3b55f383fc b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-15-a3b2e230efde74e970ae8a3b55f383fc new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-2-8396c17a66e3d9a374d4361873b9bfe3 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-2-8396c17a66e3d9a374d4361873b9bfe3 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-3-3876bb356dd8af7e78d061093d555457 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-3-3876bb356dd8af7e78d061093d555457 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-4-528e23afb272c2e69004c86ddaa70ee b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-4-528e23afb272c2e69004c86ddaa70ee new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-5-de5d56456c28d63775554e56355911d2 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-5-de5d56456c28d63775554e56355911d2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-6-3efdc331b3b4bdac3e60c757600fff53 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-6-3efdc331b3b4bdac3e60c757600fff53 new file mode 100644 index 0000000000000..185a91c110d6f --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-6-3efdc331b3b4bdac3e60c757600fff53 @@ -0,0 +1,5 @@ +98 val_98 +98 val_98 +97 val_97 +97 val_97 +96 val_96 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-7-92f6af82704504968de078c133f222f8 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-7-92f6af82704504968de078c133f222f8 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-8-316cad7c63ddd4fb043be2affa5b0a67 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-8-316cad7c63ddd4fb043be2affa5b0a67 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-9-3efdc331b3b4bdac3e60c757600fff53 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-9-3efdc331b3b4bdac3e60c757600fff53 new file mode 100644 index 0000000000000..185a91c110d6f --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-9-3efdc331b3b4bdac3e60c757600fff53 @@ -0,0 +1,5 @@ +98 
val_98 +98 val_98 +97 val_97 +97 val_97 +96 val_96 diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-10-89737a8857b5b61cc909e0c797f86aea b/sql/hive/src/test/resources/golden/leftsemijoin-10-89737a8857b5b61cc909e0c797f86aea index 25ce912507d55..a1963ba81e0da 100644 --- a/sql/hive/src/test/resources/golden/leftsemijoin-10-89737a8857b5b61cc909e0c797f86aea +++ b/sql/hive/src/test/resources/golden/leftsemijoin-10-89737a8857b5b61cc909e0c797f86aea @@ -1,4 +1,2 @@ Hank 2 -Hank 2 -Joe 2 Joe 2 diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-8-73cad58a10a1483ccb15e94a857013 b/sql/hive/src/test/resources/golden/leftsemijoin-8-73cad58a10a1483ccb15e94a857013 index 25ce912507d55..a1963ba81e0da 100644 --- a/sql/hive/src/test/resources/golden/leftsemijoin-8-73cad58a10a1483ccb15e94a857013 +++ b/sql/hive/src/test/resources/golden/leftsemijoin-8-73cad58a10a1483ccb15e94a857013 @@ -1,4 +1,2 @@ Hank 2 -Hank 2 -Joe 2 Joe 2 diff --git a/sql/hive/src/test/resources/hive-hcatalog-core-0.13.1.jar b/sql/hive/src/test/resources/hive-hcatalog-core-0.13.1.jar new file mode 100644 index 0000000000000..37af9aafad8a4 Binary files /dev/null and b/sql/hive/src/test/resources/hive-hcatalog-core-0.13.1.jar differ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala deleted file mode 100644 index 0270e63557963..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import scala.collection.JavaConversions._ - -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.util._ - - -/** - * *** DUPLICATED FROM sql/core. *** - * - * It is hard to have maven allow one subproject depend on another subprojects test code. - * So, we duplicate this code here. - */ -class QueryTest extends PlanTest { - - /** - * Runs the plan and makes sure the answer contains all of the keywords, or the - * none of keywords are listed in the answer - * @param rdd the [[DataFrame]] to be executed - * @param exists true for make sure the keywords are listed in the output, otherwise - * to make sure none of the keyword are not listed in the output - * @param keywords keyword in string array - */ - def checkExistence(rdd: DataFrame, exists: Boolean, keywords: String*) { - val outputs = rdd.collect().map(_.mkString).mkString - for (key <- keywords) { - if (exists) { - assert(outputs.contains(key), s"Failed for $rdd ($key doens't exist in result)") - } else { - assert(!outputs.contains(key), s"Failed for $rdd ($key existed in the result)") - } - } - } - - /** - * Runs the plan and makes sure the answer matches the expected result. 
- * @param rdd the [[DataFrame]] to be executed - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - */ - protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = { - QueryTest.checkAnswer(rdd, expectedAnswer) match { - case Some(errorMessage) => fail(errorMessage) - case None => - } - } - - protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = { - checkAnswer(rdd, Seq(expectedAnswer)) - } - - def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { - test(sqlString) { - checkAnswer(sqlContext.sql(sqlString), expectedAnswer) - } - } -} - -object QueryTest { - /** - * Runs the plan and makes sure the answer matches the expected result. - * If there was exception during the execution or the contents of the DataFrame does not - * match the expected result, an error message will be returned. Otherwise, a [[None]] will - * be returned. - * @param rdd the [[DataFrame]] to be executed - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - */ - def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Option[String] = { - val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty - def prepareAnswer(answer: Seq[Row]): Seq[Row] = { - // Converts data to types that we can do equality comparison using Scala collections. - // For BigDecimal type, the Scala type has a better definition of equality test (similar to - // Java's java.math.BigDecimal.compareTo). - val converted: Seq[Row] = answer.map { s => - Row.fromSeq(s.toSeq.map { - case d: java.math.BigDecimal => BigDecimal(d) - case o => o - }) - } - if (!isSorted) converted.sortBy(_.toString) else converted - } - val sparkAnswer = try rdd.collect().toSeq catch { - case e: Exception => - val errorMessage = - s""" - |Exception thrown while executing query: - |${rdd.queryExecution} - |== Exception == - |$e - |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} - """.stripMargin - return Some(errorMessage) - } - - if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { - val errorMessage = - s""" - |Results do not match for query: - |${rdd.logicalPlan} - |== Analyzed Plan == - |${rdd.queryExecution.analyzed} - |== Physical Plan == - |${rdd.queryExecution.executedPlan} - |== Results == - |${sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: - prepareAnswer(expectedAnswer).map(_.toString), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} - """.stripMargin - return Some(errorMessage) - } - - return None - } - - def checkAnswer(rdd: DataFrame, expectedAnswer: java.util.List[Row]): String = { - checkAnswer(rdd, expectedAnswer.toSeq) match { - case Some(errorMessage) => errorMessage - case None => null - } - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala deleted file mode 100644 index 98f1c0e69e29d..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
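The duplicated `QueryTest` helper deleted above compared query output order-insensitively unless the plan contained a sort, after normalizing values such as `java.math.BigDecimal`. For reference, a simplified, Spark-independent sketch of that comparison (row contents are plain `Seq[Any]` here, an assumption made to keep the snippet self-contained):

```
// Simplified version of the deleted prepareAnswer logic: normalize values with
// awkward equality, then sort rows by a string form unless the query itself
// guaranteed an ordering.
def prepareAnswer(rows: Seq[Seq[Any]], isSorted: Boolean): Seq[Seq[Any]] = {
  val converted = rows.map(_.map {
    case d: java.math.BigDecimal => BigDecimal(d)
    case other => other
  })
  if (isSorted) converted else converted.sortBy(_.mkString("|"))
}

// Rows in a different order still compare equal for an unsorted query.
assert(
  prepareAnswer(Seq(Seq(2, "b"), Seq(1, "a")), isSorted = false) ==
    prepareAnswer(Seq(Seq(1, "a"), Seq(2, "b")), isSorted = false))
```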
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.plans - -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExprId} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.util._ -import org.scalatest.FunSuite - -/** - * *** DUPLICATED FROM sql/catalyst/plans. *** - * - * It is hard to have maven allow one subproject depend on another subprojects test code. - * So, we duplicate this code here. - */ -class PlanTest extends FunSuite { - - /** - * Since attribute references are given globally unique ids during analysis, - * we must normalize them to check if two different queries are identical. - */ - protected def normalizeExprIds(plan: LogicalPlan) = { - plan transformAllExpressions { - case a: AttributeReference => - AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) - case a: Alias => - Alias(a.child, a.name)(exprId = ExprId(0)) - } - } - - /** Fails the test if the two plans do not match */ - protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { - val normalized1 = normalizeExprIds(plan1) - val normalized2 = normalizeExprIds(plan2) - if (normalized1 != normalized2) - fail( - s""" - |== FAIL: Plans do not match === - |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} - """.stripMargin) - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 221a0c263d36c..c188264072a84 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -24,21 +24,6 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest} import org.apache.spark.storage.RDDBlockId class CachedTableSuite extends QueryTest { - /** - * Throws a test failed exception when the number of cached tables differs from the expected - * number. 
- */ - def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { - val planWithCaching = query.queryExecution.withCachedData - val cachedData = planWithCaching collect { - case cached: InMemoryRelation => cached - } - - assert( - cachedData.size == numCachedTables, - s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + - planWithCaching) - } def rddIdOf(tableName: String): Int = { val executedPlan = table(tableName).queryExecution.executedPlan diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index 968557c9c4686..d960a30e00738 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -136,7 +136,7 @@ class ErrorPositionSuite extends QueryTest with BeforeAndAfter { * @param query the query to analyze * @param token a unique token in the string that should be indicated by the exception */ - def positionTest(name: String, query: String, token: String) = { + def positionTest(name: String, query: String, token: String): Unit = { def parseTree = Try(quietly(HiveQl.dumpTree(HiveQl.getAst(query)))).getOrElse("") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 3181cfe40016c..2a7374cc172b7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -79,9 +79,9 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { Literal(Decimal(BigDecimal(123.123))) :: Literal(new java.sql.Timestamp(123123)) :: Literal(Array[Byte](1,2,3)) :: - Literal(Seq[Int](1,2,3), ArrayType(IntegerType)) :: - Literal(Map[Int, Int](1->2, 2->1), MapType(IntegerType, IntegerType)) :: - Literal(Row(1,2.0d,3.0f), + Literal.create(Seq[Int](1,2,3), ArrayType(IntegerType)) :: + Literal.create(Map[Int, Int](1->2, 2->1), MapType(IntegerType, IntegerType)) :: + Literal.create(Row(1,2.0d,3.0f), StructType(StructField("c1", IntegerType) :: StructField("c2", DoubleType) :: StructField("c3", FloatType) :: Nil)) :: @@ -116,21 +116,20 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { } def checkDataType(dt1: Seq[DataType], dt2: Seq[DataType]): Unit = { - dt1.zip(dt2).map { - case (dd1, dd2) => - assert(dd1.getClass === dd2.getClass) // DecimalType doesn't has the default precision info + dt1.zip(dt2).foreach { case (dd1, dd2) => + assert(dd1.getClass === dd2.getClass) // DecimalType doesn't has the default precision info } } def checkValues(row1: Seq[Any], row2: Seq[Any]): Unit = { - row1.zip(row2).map { - case (r1, r2) => checkValue(r1, r2) + row1.zip(row2).foreach { case (r1, r2) => + checkValue(r1, r2) } } def checkValues(row1: Seq[Any], row2: Row): Unit = { - row1.zip(row2.toSeq).map { - case (r1, r2) => checkValue(r1, r2) + row1.zip(row2.toSeq).foreach { case (r1, r2) => + checkValue(r1, r2) } } @@ -141,7 +140,7 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { assert(r1.compare(r2) === 0) case (r1: Array[Byte], r2: Array[Byte]) if r1 != null && r2 != null && r1.length == r2.length => - r1.zip(r2).map { case (b1, b2) => assert(b1 === b2) } + r1.zip(r2).foreach { case (b1, b2) => assert(b1 === b2) } case (r1, r2) => assert(r1 === r2) } } @@ -166,7 +165,8 @@ class HiveInspectorSuite extends 
FunSuite with HiveInspectors { val constantData = constantExprs.map(_.eval()) val constantNullData = constantData.map(_ => null) val constantWritableOIs = constantExprs.map(e => toWritableInspector(e.dataType)) - val constantNullWritableOIs = constantExprs.map(e => toInspector(Literal(null, e.dataType))) + val constantNullWritableOIs = + constantExprs.map(e => toInspector(Literal.create(null, e.dataType))) checkValues(constantData, constantData.zip(constantWritableOIs).map { case (d, oi) => unwrap(wrap(d, oi), oi) @@ -202,7 +202,8 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { case (t, idx) => StructField(s"c_$idx", t) }) - checkValues(row, unwrap(wrap(Row.fromSeq(row), toInspector(dt)), toInspector(dt)).asInstanceOf[Row]) + checkValues(row, + unwrap(wrap(Row.fromSeq(row), toInspector(dt)), toInspector(dt)).asInstanceOf[Row]) checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) } @@ -212,8 +213,10 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { val d = row(0) :: row(0) :: Nil checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) - checkValue(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) - checkValue(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + checkValue(d, + unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + checkValue(d, + unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) } test("wrap / unwrap Map Type") { @@ -222,7 +225,9 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { val d = Map(row(0) -> row(1)) checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) - checkValue(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) - checkValue(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + checkValue(d, + unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + checkValue(d, + unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index aad48ada52642..fa8e11ffec2b4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hive +import org.apache.spark.sql.hive.test.TestHive import org.scalatest.FunSuite import org.apache.spark.sql.test.ExamplePointUDT @@ -36,4 +37,11 @@ class HiveMetastoreCatalogSuite extends FunSuite { assert(HiveMetastoreTypes.toMetastoreType(udt) === HiveMetastoreTypes.toMetastoreType(udt.sqlType)) } + + test("duplicated metastore relations") { + import TestHive.implicits._ + val df = TestHive.sql("SELECT * FROM src") + println(df.queryExecution) + df.as('a).join(df.as('b), $"a.key" === $"b.key") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 8011952e0d535..ecb990e8aac91 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -115,11 +115,36 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("SPARK-4203:random partition directory order") { sql("CREATE TABLE tmp_table (key int, value string)") val tmpDir = Utils.createTempDir() - sql(s"CREATE TABLE table_with_partition(c1 string) PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) location '${tmpDir.toURI.toString}' ") - sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='1') SELECT 'blarr' FROM tmp_table") - sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='2') SELECT 'blarr' FROM tmp_table") - sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='3') SELECT 'blarr' FROM tmp_table") - sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='4') SELECT 'blarr' FROM tmp_table") + sql( + s""" + |CREATE TABLE table_with_partition(c1 string) + |PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) + |location '${tmpDir.toURI.toString}' + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (p1='a',p2='b',p3='c',p4='c',p5='1') + |SELECT 'blarr' FROM tmp_table + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (p1='a',p2='b',p3='c',p4='c',p5='2') + |SELECT 'blarr' FROM tmp_table + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (p1='a',p2='b',p3='c',p4='c',p5='3') + |SELECT 'blarr' FROM tmp_table + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (p1='a',p2='b',p3='c',p4='c',p5='4') + |SELECT 'blarr' FROM tmp_table + """.stripMargin) def listFolders(path: File, acc: List[String]): List[List[String]] = { val dir = path.listFiles() val folders = dir.filter(_.isDirectory).toList @@ -196,34 +221,42 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { testData.registerTempTable("testData") val testDatawithNull = TestHive.sparkContext.parallelize( - (1 to 10).map(i => ThreeCloumntable(i, i.toString,null))).toDF() + (1 to 10).map(i => ThreeCloumntable(i, i.toString, null))).toDF() val tmpDir = Utils.createTempDir() - sql(s"CREATE TABLE table_with_partition(key int,value string) PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') SELECT key,value FROM testData") + sql( + s""" + |CREATE TABLE table_with_partition(key int,value string) + |PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (ds='1') SELECT key,value FROM testData + """.stripMargin) // test schema the same between partition and table sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") checkAnswer(sql("select key,value from table_with_partition where ds='1' "), - testData.collect.toSeq + testData.collect().toSeq ) // test difference type of field sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") checkAnswer(sql("select key,value from table_with_partition where ds='1' "), - testData.collect.toSeq + testData.collect().toSeq ) // add column to table sql("ALTER TABLE table_with_partition ADD COLUMNS(key1 string)") checkAnswer(sql("select key,value,key1 from table_with_partition where ds='1' "), - testDatawithNull.collect.toSeq + 
testDatawithNull.collect().toSeq ) // change column name to table sql("ALTER TABLE table_with_partition CHANGE COLUMN key keynew BIGINT") checkAnswer(sql("select keynew,value from table_with_partition where ds='1' "), - testData.collect.toSeq + testData.collect().toSeq ) sql("DROP TABLE table_with_partition") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index e5ad0bf552073..e09c702c8969e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -25,6 +25,8 @@ import org.scalatest.BeforeAndAfterEach import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.metastore.TableType +import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.mapred.InvalidInputException import org.apache.spark.sql._ @@ -682,6 +684,27 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { assert(schema === actualSchema) } + test("SPARK-6655 still support a schema stored in spark.sql.sources.schema") { + val tableName = "spark6655" + val schema = StructType(StructField("int", IntegerType, true) :: Nil) + // Manually create the metadata in metastore. + val tbl = new Table("default", tableName) + tbl.setProperty("spark.sql.sources.provider", "json") + tbl.setProperty("spark.sql.sources.schema", schema.json) + tbl.setProperty("EXTERNAL", "FALSE") + tbl.setTableType(TableType.MANAGED_TABLE) + tbl.setSerdeParam("path", catalog.hiveDefaultTableFilePath(tableName)) + catalog.synchronized { + catalog.client.createTable(tbl) + } + + invalidateTable(tableName) + val actualSchema = table(tableName).schema + assert(schema === actualSchema) + sql(s"drop table $tableName") + } + + test("insert into a table") { def createDF(from: Int, to: Int): DataFrame = createDataFrame((from to to).map(i => Tuple2(i, s"str$i"))).toDF("c1", "c2") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala new file mode 100644 index 0000000000000..a787fa5546e76 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
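The SPARK-6655 test above covers tables whose schema was stored under the legacy `spark.sql.sources.schema` table property as a single JSON string. A short sketch of that round-trip, assuming `DataType.fromJson` is available to parse the stored property back into a `StructType`:

```
import org.apache.spark.sql.types._

// The schema is serialized to JSON and kept as a metastore table property;
// reading the property back reconstructs the same StructType.
val schema = StructType(StructField("int", IntegerType, nullable = true) :: Nil)
val storedProperty: String = schema.json            // value of spark.sql.sources.schema
val restored = DataType.fromJson(storedProperty).asInstanceOf[StructType]
assert(restored == schema)
```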
+ */ + +package org.apache.spark.sql.hive + +import com.google.common.io.Files + +import org.apache.spark.sql.{QueryTest, _} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.util.Utils + + +class QueryPartitionSuite extends QueryTest { + import org.apache.spark.sql.hive.test.TestHive.implicits._ + + test("SPARK-5068: query data when path doesn't exists"){ + val testData = TestHive.sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString))).toDF() + testData.registerTempTable("testData") + + val tmpDir = Files.createTempDir() + // create the table for test + sql(s"CREATE TABLE table_with_partition(key int,value string) " + + s"PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + + "SELECT key,value FROM testData") + + // test for the exist path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect + ++ testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect) + + // delete the path of one partition + val folders = tmpDir.listFiles.filter(_.isDirectory) + Utils.deleteRecursively(folders(0)) + + // test for after delete the path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect + ++ testData.toSchemaRDD.collect) + + sql("DROP TABLE table_with_partition") + sql("DROP TABLE createAndInsertTest") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 1e05a024b8807..00a69de9e4262 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -120,7 +120,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { // Try to analyze a temp table sql("""SELECT * FROM src""").registerTempTable("tempTable") - intercept[NotImplementedError] { + intercept[UnsupportedOperationException] { analyze("tempTable") } catalog.unregisterTable(Seq("tempTable")) @@ -142,7 +142,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { after: () => Unit, query: String, expectedAnswer: Seq[Row], - ct: ClassTag[_]) = { + ct: ClassTag[_]): Unit = { before() var df = sql(query) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala index 42a82c1fbf5c7..a3f5921a0cb23 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.hive.test.TestHive._ class BigDataBenchmarkSuite extends HiveComparisonTest { val testDataDirectory = new File("target" + File.separator + "big-data-benchmark-testdata") + val userVisitPath = new File(testDataDirectory, "uservisits").getCanonicalPath val testTables = Seq( TestTable( "rankings", @@ -63,7 
+64,7 @@ class BigDataBenchmarkSuite extends HiveComparisonTest { | searchWord STRING, | duration INT) | ROW FORMAT DELIMITED FIELDS TERMINATED BY "," - | STORED AS TEXTFILE LOCATION "${new File(testDataDirectory, "uservisits").getCanonicalPath}" + | STORED AS TEXTFILE LOCATION "$userVisitPath" """.stripMargin.cmd), TestTable( "documents", @@ -83,7 +84,10 @@ class BigDataBenchmarkSuite extends HiveComparisonTest { "SELECT pageURL, pageRank FROM rankings WHERE pageRank > 1") createQueryTest("query2", - "SELECT SUBSTR(sourceIP, 1, 10), SUM(adRevenue) FROM uservisits GROUP BY SUBSTR(sourceIP, 1, 10)") + """ + |SELECT SUBSTR(sourceIP, 1, 10), SUM(adRevenue) FROM uservisits + |GROUP BY SUBSTR(sourceIP, 1, 10) + """.stripMargin) createQueryTest("query3", """ @@ -113,8 +117,8 @@ class BigDataBenchmarkSuite extends HiveComparisonTest { |CREATE TABLE url_counts_total AS | SELECT SUM(count) AS totalCount, destpage | FROM url_counts_partial GROUP BY destpage - |-- The following queries run, but generate different results in HIVE likely because the UDF is not deterministic - |-- given different input splits. + |-- The following queries run, but generate different results in HIVE + |-- likely because the UDF is not deterministic given different input splits. |-- SELECT CAST(SUM(count) AS INT) FROM url_counts_partial |-- SELECT COUNT(*) FROM url_counts_partial |-- SELECT * FROM url_counts_partial diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 8f3285242091c..027056d4b865f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -138,7 +138,7 @@ abstract class HiveComparisonTest case _ => plan.children.iterator.exists(isSorted) } - val orderedAnswer = hiveQuery.logical match { + val orderedAnswer = hiveQuery.analyzed match { // Clean out non-deterministic time schema info. // Hack: Hive simply prints the result of a SET command to screen, // and does not return it as a query answer. @@ -255,8 +255,9 @@ abstract class HiveComparisonTest .filterNot(_ contains "hive.outerjoin.supports.filters") .filterNot(_ contains "hive.exec.post.hooks") - if (allQueries != queryList) + if (allQueries != queryList) { logWarning(s"Simplifications made on unsupported operations for test $testCaseName") + } lazy val consoleTestCase = { val quotes = "\"\"\"" @@ -299,19 +300,22 @@ abstract class HiveComparisonTest val hiveQueries = queryList.map(new TestHive.HiveQLQueryExecution(_)) // Make sure we can at least parse everything before attempting hive execution. - hiveQueries.foreach(_.logical) + hiveQueries.foreach(_.analyzed) val computedResults = (queryList.zipWithIndex, hiveQueries, hiveCacheFiles).zipped.map { case ((queryString, i), hiveQuery, cachedAnswerFile)=> try { // Hooks often break the harness and don't really affect our test anyway, don't // even try running them. - if (installHooksCommand.findAllMatchIn(queryString).nonEmpty) + if (installHooksCommand.findAllMatchIn(queryString).nonEmpty) { sys.error("hive exec hooks not supported for tests.") + } - logWarning(s"Running query ${i+1}/${queryList.size} with hive.") + logWarning(s"Running query ${i + 1}/${queryList.size} with hive.") // Analyze the query with catalyst to ensure test tables are loaded. 
val answer = hiveQuery.analyzed match { - case _: ExplainCommand => Nil // No need to execute EXPLAIN queries as we don't check the output. + case _: ExplainCommand => + // No need to execute EXPLAIN queries as we don't check the output. + Nil case _ => TestHive.runSqlHive(queryString) } @@ -394,21 +398,24 @@ abstract class HiveComparisonTest case tf: org.scalatest.exceptions.TestFailedException => throw tf case originalException: Exception => if (System.getProperty("spark.hive.canarytest") != null) { - // When we encounter an error we check to see if the environment is still okay by running a simple query. - // If this fails then we halt testing since something must have gone seriously wrong. + // When we encounter an error we check to see if the environment is still + // okay by running a simple query. If this fails then we halt testing since + // something must have gone seriously wrong. try { new TestHive.HiveQLQueryExecution("SELECT key FROM src").stringResult() TestHive.runSqlHive("SELECT key FROM src") } catch { case e: Exception => - logError(s"FATAL ERROR: Canary query threw $e This implies that the testing environment has likely been corrupted.") - // The testing setup traps exits so wait here for a long time so the developer can see when things started - // to go wrong. + logError(s"FATAL ERROR: Canary query threw $e This implies that the " + + "testing environment has likely been corrupted.") + // The testing setup traps exits so wait here for a long time so the developer + // can see when things started to go wrong. Thread.sleep(1000000) } } - // If the canary query didn't fail then the environment is still okay, so just throw the original exception. + // If the canary query didn't fail then the environment is still okay, + // so just throw the original exception. throw originalException } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index c939e6e99d28a..bdb53ddf59c19 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -22,10 +22,12 @@ import org.apache.spark.sql.hive.test.TestHive class HivePlanTest extends QueryTest { import TestHive._ + import TestHive.implicits._ test("udf constant folding") { - val optimized = sql("SELECT cos(null) FROM src").queryExecution.optimizedPlan - val correctAnswer = sql("SELECT cast(null as double) FROM src").queryExecution.optimizedPlan + Seq.empty[Tuple1[Int]].toDF("a").registerTempTable("t") + val optimized = sql("SELECT cos(null) FROM t").queryExecution.optimizedPlan + val correctAnswer = sql("SELECT cast(null as double) FROM t").queryExecution.optimizedPlan comparePlans(optimized, correctAnswer) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala index 02518d516261b..f7b37dae0a5f3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala @@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.util._ /** * A framework for running the query tests that are listed as a set of text files. * - * TestSuites that derive from this class must provide a map of testCaseName -> testCaseFiles that should be included. 
- * Additionally, there is support for whitelisting and blacklisting tests as development progresses. + * TestSuites that derive from this class must provide a map of testCaseName -> testCaseFiles + * that should be included. Additionally, there is support for whitelisting and blacklisting + * tests as development progresses. */ abstract class HiveQueryFileTest extends HiveComparisonTest { /** A list of tests deemed out of scope and thus completely disregarded */ @@ -54,15 +55,17 @@ abstract class HiveQueryFileTest extends HiveComparisonTest { case (testCaseName, testCaseFile) => if (blackList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_)) { logDebug(s"Blacklisted test skipped $testCaseName") - } else if (realWhiteList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_) || runAll) { + } else if (realWhiteList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_) || + runAll) { // Build a test case and submit it to scala test framework... val queriesString = fileToString(testCaseFile) createQueryTest(testCaseName, queriesString) } else { // Only output warnings for the built in whitelist as this clutters the output when the user // trying to execute a single test from the commandline. - if(System.getProperty(whiteListProperty) == null && !runAll) + if (System.getProperty(whiteListProperty) == null && !runAll) { ignore(testCaseName) {} + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index de140fc72a2c3..300b1f7920473 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -37,7 +37,8 @@ import org.apache.spark.sql.hive.test.TestHive._ case class TestData(a: Int, b: String) /** - * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. + * A set of test cases expressed in Hive QL that are not covered by the tests + * included in the hive distribution. 
*/ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { private val originalTimeZone = TimeZone.getDefault @@ -237,7 +238,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } createQueryTest("modulus", - "SELECT 11 % 10, IF((101.1 % 100.0) BETWEEN 1.01 AND 1.11, \"true\", \"false\"), (101 / 2) % 10 FROM src LIMIT 1") + "SELECT 11 % 10, IF((101.1 % 100.0) BETWEEN 1.01 AND 1.11, \"true\", \"false\"), " + + "(101 / 2) % 10 FROM src LIMIT 1") test("Query expressed in SQL") { setConf("spark.sql.dialect", "sql") @@ -309,7 +311,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { "SELECT * FROM src a JOIN src b ON a.key = b.key") createQueryTest("small.cartesian", - "SELECT a.key, b.key FROM (SELECT key FROM src WHERE key < 1) a JOIN (SELECT key FROM src WHERE key = 2) b") + "SELECT a.key, b.key FROM (SELECT key FROM src WHERE key < 1) a JOIN " + + "(SELECT key FROM src WHERE key = 2) b") createQueryTest("length.udf", "SELECT length(\"test\") FROM src LIMIT 1") @@ -457,6 +460,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("lateral view3", "FROM src SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX") + // scalastyle:off createQueryTest("lateral view4", """ |create table src_lv1 (key string, value string); @@ -466,6 +470,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |insert overwrite table src_lv1 SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX |insert overwrite table src_lv2 SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX """.stripMargin) + // scalastyle:on createQueryTest("lateral view5", "FROM src SELECT explode(array(key+3, key+4))") @@ -537,6 +542,21 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("select null from table", "SELECT null FROM src LIMIT 1") + createQueryTest("CTE feature #1", + "with q1 as (select key from src) select * from q1 where key = 5") + + createQueryTest("CTE feature #2", + """with q1 as (select * from src where key= 5), + |q2 as (select * from src s2 where key = 4) + |select value from q1 union all select value from q2 + """.stripMargin) + + createQueryTest("CTE feature #3", + """with q1 as (select key from src) + |from q1 + |select * where key = 4 + """.stripMargin) + test("predicates contains an empty AttributeSet() references") { sql( """ @@ -584,7 +604,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } - def isExplanation(result: DataFrame) = { + def isExplanation(result: DataFrame): Boolean = { val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } explanation.contains("== Physical Plan ==") } @@ -793,6 +813,21 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql("DROP TABLE alter1") } + test("ADD JAR command 2") { + // this is a test case from mapjoin_addjar.q + val testJar = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath + val testData = TestHive.getHiveFile("data/files/sample.json").getCanonicalPath + if (HiveShim.version == "0.13.1") { + sql(s"ADD JAR $testJar") + sql( + """CREATE TABLE t1(a string, b string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'""".stripMargin) + sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE t1""") + sql("select * from src join t1 on src.key = t1.a") + sql("DROP TABLE t1") + } + } + test("ADD FILE command") { val testFile = 
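The three "CTE feature" tests added above exercise Hive-style `WITH` clauses. For reference, a hedged sketch of running the same query shapes from Scala; it assumes a running `HiveContext` with the standard TestHive `src` table (key INT, value STRING).

```
import org.apache.spark.sql.hive.HiveContext

// Sketch only: assumes a HiveContext and the usual `src` test table.
def cteExamples(hiveContext: HiveContext): Unit = {
  // CTE feature #1: a single named subquery.
  hiveContext.sql(
    "with q1 as (select key from src) select * from q1 where key = 5").show()

  // CTE feature #2: two CTEs combined with UNION ALL.
  hiveContext.sql(
    """with q1 as (select * from src where key = 5),
      |q2 as (select * from src s2 where key = 4)
      |select value from q1 union all select value from q2
    """.stripMargin).show()
}
```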
TestHive.getHiveFile("data/files/v1.txt").getCanonicalFile sql(s"ADD FILE $testFile") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index f4440e5b7846a..8ad3627504229 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -25,7 +25,8 @@ case class Nested(a: Int, B: Int) case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) /** - * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. + * A set of test cases expressed in Hive QL that are not covered by the tests + * included in the hive distribution. */ class HiveResolutionSuite extends HiveComparisonTest { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala index 7486bfa82b00b..5586a793618bd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala @@ -25,17 +25,25 @@ import org.apache.spark.sql.hive.test.TestHive * A set of tests that validates support for Hive SerDe. */ class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll { - - override def beforeAll() = { + override def beforeAll(): Unit = { + import TestHive._ + import org.apache.hadoop.hive.serde2.RegexSerDe + super.beforeAll() TestHive.cacheTables = false + sql(s"""CREATE TABLE IF NOT EXISTS sales (key STRING, value INT) + |ROW FORMAT SERDE '${classOf[RegexSerDe].getCanonicalName}' + |WITH SERDEPROPERTIES ("input.regex" = "([^ ]*)\t([^ ]*)") + """.stripMargin) + sql(s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/sales.txt")}' INTO TABLE sales") } + // table sales is not a cache table, and will be clear after reset + createQueryTest("Read with RegexSerDe", "SELECT * FROM sales", false) + createQueryTest( "Read and write with LazySimpleSerDe (tab separated)", "SELECT * from serdeins") - createQueryTest("Read with RegexSerDe", "SELECT * FROM sales") - createQueryTest("Read with AvroSerDe", "SELECT * FROM episodes") createQueryTest("Read Partitioned with AvroSerDe", "SELECT * FROM episodes_part") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index ab0e0443c7faa..f0f04f8c73fb4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -35,8 +35,10 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { val nullVal = "null" baseTypes.init.foreach { i => - createQueryTest(s"case when then $i else $nullVal end ", s"SELECT case when true then $i else $nullVal end FROM src limit 1") - createQueryTest(s"case when then $nullVal else $i end ", s"SELECT case when true then $nullVal else $i end FROM src limit 1") + createQueryTest(s"case when then $i else $nullVal end ", + s"SELECT case when true then $i else $nullVal end FROM src limit 1") + createQueryTest(s"case when then $nullVal else $i end ", + s"SELECT case when true then $nullVal else $i end FROM src limit 1") } test("[SPARK-2210] boolean cast on 
boolean value should be removed") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index d7c5d1a25a82b..7f49eac490572 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -123,9 +123,10 @@ class HiveUdfSuite extends QueryTest { IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil).toDF() testData.registerTempTable("integerTable") - sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '${classOf[UDFIntegerToString].getName}'") + val udfName = classOf[UDFIntegerToString].getName + sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '$udfName'") checkAnswer( - sql("SELECT testUDFIntegerToString(i) FROM integerTable"), //.collect(), + sql("SELECT testUDFIntegerToString(i) FROM integerTable"), Seq(Row("1"), Row("2"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") @@ -141,7 +142,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'") checkAnswer( - sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), //.collect(), + sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), Seq(Row(0), Row(2), Row(13))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") @@ -156,7 +157,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'") checkAnswer( - sql("SELECT testUDFListString(l) FROM listStringTable"), //.collect(), + sql("SELECT testUDFListString(l) FROM listStringTable"), Seq(Row("a,b,c"), Row("d,e"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") @@ -170,7 +171,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'") checkAnswer( - sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), //.collect(), + sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), Seq(Row("hello world"), Row("hello goodbye"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf") @@ -187,7 +188,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") checkAnswer( - sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), //.collect(), + sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), Seq(Row("0, 0"), Row("2, 2"), Row("13, 13"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") @@ -247,7 +248,8 @@ class PairUdf extends GenericUDF { override def initialize(p1: Array[ObjectInspector]): ObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector( Seq("id", "value"), - Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, PrimitiveObjectInspectorFactory.javaIntObjectInspector) + Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector) ) override def evaluate(args: Array[DeferredObject]): AnyRef = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 8474d850c9c6c..067b577f1560e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ 
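The HiveUdfSuite edits above are largely line-length cleanups around one recurring pattern: register a temporary function, query with it, then drop it. A hedged sketch of that cycle; the UDF class name here is a placeholder, not one of the test UDFs.

```
import org.apache.spark.sql.hive.HiveContext

// Sketch of the register / use / drop cycle the tests follow.
// 'com.example.udf.UDFIntegerToString' is a hypothetical class name.
def temporaryUdfExample(hiveContext: HiveContext): Unit = {
  hiveContext.sql(
    "CREATE TEMPORARY FUNCTION toStr AS 'com.example.udf.UDFIntegerToString'")
  try {
    hiveContext.sql("SELECT toStr(key) FROM src LIMIT 1").collect()
  } finally {
    hiveContext.sql("DROP TEMPORARY FUNCTION IF EXISTS toStr")
  }
}
```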
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -143,7 +143,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { sql: String, expectedOutputColumns: Seq[String], expectedScannedColumns: Seq[String], - expectedPartValues: Seq[Seq[String]]) = { + expectedPartValues: Seq[Seq[String]]): Unit = { test(s"$testCaseName - pruning test") { val plan = new TestHive.HiveQLQueryExecution(sql).executedPlan val actualOutputColumns = plan.output.map(_.name) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 1187228f4c3db..47b4cb9ca61ff 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -34,15 +34,122 @@ case class Nested3(f3: Int) case class NestedArray2(b: Seq[Int]) case class NestedArray1(a: NestedArray2) +case class Order( + id: Int, + make: String, + `type`: String, + price: Int, + pdate: String, + customer: String, + city: String, + state: String, + month: Int) + /** * A collection of hive query tests where we generate the answers ourselves instead of depending on * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is * valid, but Hive currently cannot execute it. */ class SQLQuerySuite extends QueryTest { + test("SPARK-6835: udtf in lateral view") { + val df = Seq((1, 1)).toDF("c1", "c2") + df.registerTempTable("table1") + val query = sql("SELECT c1, v FROM table1 LATERAL VIEW stack(3, 1, c1 + 1, c1 + 2) d AS v") + checkAnswer(query, Row(1, 1) :: Row(1, 2) :: Row(1, 3) :: Nil) + } + + test("SPARK-6851: Self-joined converted parquet tables") { + val orders = Seq( + Order(1, "Atlas", "MTB", 234, "2015-01-07", "John D", "Pacifica", "CA", 20151), + Order(3, "Swift", "MTB", 285, "2015-01-17", "John S", "Redwood City", "CA", 20151), + Order(4, "Atlas", "Hybrid", 303, "2015-01-23", "Jones S", "San Mateo", "CA", 20151), + Order(7, "Next", "MTB", 356, "2015-01-04", "Jane D", "Daly City", "CA", 20151), + Order(10, "Next", "YFlikr", 187, "2015-01-09", "John D", "Fremont", "CA", 20151), + Order(11, "Swift", "YFlikr", 187, "2015-01-23", "John D", "Hayward", "CA", 20151), + Order(2, "Next", "Hybrid", 324, "2015-02-03", "Jane D", "Daly City", "CA", 20152), + Order(5, "Next", "Street", 187, "2015-02-08", "John D", "Fremont", "CA", 20152), + Order(6, "Atlas", "Street", 154, "2015-02-09", "John D", "Pacifica", "CA", 20152), + Order(8, "Swift", "Hybrid", 485, "2015-02-19", "John S", "Redwood City", "CA", 20152), + Order(9, "Atlas", "Split", 303, "2015-02-28", "Jones S", "San Mateo", "CA", 20152)) + + val orderUpdates = Seq( + Order(1, "Atlas", "MTB", 434, "2015-01-07", "John D", "Pacifica", "CA", 20151), + Order(11, "Swift", "YFlikr", 137, "2015-01-23", "John D", "Hayward", "CA", 20151)) + + orders.toDF.registerTempTable("orders1") + orderUpdates.toDF.registerTempTable("orderupdates1") + + sql( + """CREATE TABLE orders( + | id INT, + | make String, + | type String, + | price INT, + | pdate String, + | customer String, + | city String) + |PARTITIONED BY (state STRING, month INT) + |STORED AS PARQUET + """.stripMargin) + + sql( + """CREATE TABLE orderupdates( + | id INT, + | make String, + | type String, + | price INT, + | pdate String, + | customer String, + | city String) + |PARTITIONED BY (state STRING, month INT) + |STORED AS PARQUET + """.stripMargin) + + sql("set 
hive.exec.dynamic.partition.mode=nonstrict") + sql("INSERT INTO TABLE orders PARTITION(state, month) SELECT * FROM orders1") + sql("INSERT INTO TABLE orderupdates PARTITION(state, month) SELECT * FROM orderupdates1") + + checkAnswer( + sql( + """ + |select orders.state, orders.month + |from orders + |join ( + | select distinct orders.state,orders.month + | from orders + | join orderupdates + | on orderupdates.id = orders.id) ao + | on ao.state = orders.state and ao.month = orders.month + """.stripMargin), + (1 to 6).map(_ => Row("CA", 20151))) + } + + test("SPARK-5371: union with null and sum") { + val df = Seq((1, 1)).toDF("c1", "c2") + df.registerTempTable("table1") + + val query = sql( + """ + |SELECT + | MIN(c1), + | MIN(c2) + |FROM ( + | SELECT + | SUM(c1) c1, + | NULL c2 + | FROM table1 + | UNION ALL + | SELECT + | NULL c1, + | SUM(c2) c2 + | FROM table1 + |) a + """.stripMargin) + checkAnswer(query, Row(1, 1) :: Nil) + } test("explode nested Field") { - Seq(NestedArray1(NestedArray2(Seq(1,2,3)))).toDF.registerTempTable("nestedArray") + Seq(NestedArray1(NestedArray2(Seq(1, 2, 3)))).toDF.registerTempTable("nestedArray") checkAnswer( sql("SELECT ints FROM nestedArray LATERAL VIEW explode(a.b) a AS ints"), Row(1) :: Row(2) :: Row(3) :: Nil) @@ -398,7 +505,7 @@ class SQLQuerySuite extends QueryTest { } test("resolve udtf with single alias") { - val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i+1}]}""")) + val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) jsonRDD(rdd).registerTempTable("data") val df = sql("SELECT explode(a) AS val FROM data") val col = df("val") @@ -411,7 +518,7 @@ class SQLQuerySuite extends QueryTest { // is not in a valid state (cannot be executed). Because of this bug, the analysis rule of // PreInsertionCasts will actually start to work before ImplicitGenerate and then // generates an invalid query plan. - val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i+1}]}""")) + val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) jsonRDD(rdd).registerTempTable("data") val originalConf = getConf("spark.sql.hive.convertCTAS", "false") setConf("spark.sql.hive.convertCTAS", "false") @@ -433,4 +540,25 @@ class SQLQuerySuite extends QueryTest { dropTempTable("data") setConf("spark.sql.hive.convertCTAS", originalConf) } + + test("sanity test for SPARK-6618") { + (1 to 100).par.map { i => + val tableName = s"SPARK_6618_table_$i" + sql(s"CREATE TABLE $tableName (col1 string)") + catalog.lookupRelation(Seq(tableName)) + table(tableName) + tables() + sql(s"DROP TABLE $tableName") + } + } + + test("SPARK-5203 union with different decimal precision") { + Seq.empty[(Decimal, Decimal)] + .toDF("d1", "d2") + .select($"d1".cast(DecimalType(10, 15)).as("d")) + .registerTempTable("dn") + + sql("select d from dn union all select d * 2 from dn") + .queryExecution.analyzed + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 432d65a874518..d5dd0bf58e702 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -1,4 +1,3 @@ - /* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. 
See the NOTICE file distributed with @@ -26,8 +25,10 @@ import org.apache.spark.sql.{QueryTest, SQLConf, SaveMode} import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import org.apache.spark.sql.hive.execution.HiveTableScan +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.json.JSONRelation import org.apache.spark.sql.sources.{InsertIntoDataSource, LogicalRelation} import org.apache.spark.sql.parquet.{ParquetRelation2, ParquetTableScan} import org.apache.spark.sql.SaveMode @@ -390,6 +391,114 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { sql("DROP TABLE ms_convert") } + + test("Caching converted data source Parquet Relations") { + def checkCached(tableIdentifer: catalog.QualifiedTableName): Unit = { + // Converted test_parquet should be cached. + catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) match { + case null => fail("Converted test_parquet should be cached in the cache.") + case logical @ LogicalRelation(parquetRelation: ParquetRelation2) => // OK + case other => + fail( + "The cached test_parquet should be a Parquet Relation. " + + s"However, $other is returned form the cache.") + } + } + + sql("DROP TABLE IF EXISTS test_insert_parquet") + sql("DROP TABLE IF EXISTS test_parquet_partitioned_cache_test") + + sql( + """ + |create table test_insert_parquet + |( + | intField INT, + | stringField STRING + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + var tableIdentifer = catalog.QualifiedTableName("default", "test_insert_parquet") + + // First, make sure the converted test_parquet is not cached. + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + // Table lookup will make the table cached. + table("test_insert_parquet") + checkCached(tableIdentifer) + // For insert into non-partitioned table, we will do the conversion, + // so the converted test_insert_parquet should be cached. + invalidateTable("test_insert_parquet") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + sql( + """ + |INSERT INTO TABLE test_insert_parquet + |select a, b from jt + """.stripMargin) + checkCached(tableIdentifer) + // Make sure we can read the data. + checkAnswer( + sql("select * from test_insert_parquet"), + sql("select a, b from jt").collect()) + // Invalidate the cache. + invalidateTable("test_insert_parquet") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + + // Create a partitioned table. 
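For readers unfamiliar with the cache this test probes: `catalog.cachedDataSourceTables` is a Guava cache keyed by `QualifiedTableName`, which is why the test can distinguish "not yet converted" (`getIfPresent` returns `null`) from "cached as a `ParquetRelation2`". Below is a minimal, self-contained sketch of that lookup/invalidate pattern using a plain Guava cache and a stand-in key class instead of Spark's internal catalog; names here are illustrative, not Spark API.

```scala
import com.google.common.cache.{Cache, CacheBuilder}

object CachedRelationSketch {
  // Stand-in for the catalog's QualifiedTableName key (illustrative only).
  case class QualifiedTableName(database: String, name: String)

  def main(args: Array[String]): Unit = {
    val cache: Cache[QualifiedTableName, String] =
      CacheBuilder.newBuilder().maximumSize(100).build[QualifiedTableName, String]()

    val key = QualifiedTableName("default", "test_insert_parquet")

    // Nothing has looked the table up yet, so getIfPresent reports a miss.
    assert(cache.getIfPresent(key) == null)

    // A table lookup populates the cache (simulated with put here).
    cache.put(key, "converted ParquetRelation2")
    assert(cache.getIfPresent(key) != null)

    // Invalidating the table boils down to evicting the entry for that key.
    cache.invalidate(key)
    assert(cache.getIfPresent(key) == null)
  }
}
```

The partitioned table created next in the test goes through the same lookup and invalidation checks.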
+ sql( + """ + |create table test_parquet_partitioned_cache_test + |( + | intField INT, + | stringField STRING + |) + |PARTITIONED BY (date string) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + tableIdentifer = catalog.QualifiedTableName("default", "test_parquet_partitioned_cache_test") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + sql( + """ + |INSERT INTO TABLE test_parquet_partitioned_cache_test + |PARTITION (date='2015-04-01') + |select a, b from jt + """.stripMargin) + // Right now, insert into a partitioned Parquet is not supported in data source Parquet. + // So, we expect it is not cached. + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + sql( + """ + |INSERT INTO TABLE test_parquet_partitioned_cache_test + |PARTITION (date='2015-04-02') + |select a, b from jt + """.stripMargin) + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + + // Make sure we can cache the partitioned table. + table("test_parquet_partitioned_cache_test") + checkCached(tableIdentifer) + // Make sure we can read the data. + checkAnswer( + sql("select STRINGField, date, intField from test_parquet_partitioned_cache_test"), + sql( + """ + |select b, '2015-04-01', a FROM jt + |UNION ALL + |select b, '2015-04-02', a FROM jt + """.stripMargin).collect()) + + invalidateTable("test_parquet_partitioned_cache_test") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + + sql("DROP TABLE test_insert_parquet") + sql("DROP TABLE test_parquet_partitioned_cache_test") + } } class ParquetDataSourceOffMetastoreSuite extends ParquetMetastoreSuiteBase { @@ -578,6 +687,22 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { sql("DROP TABLE alwaysNullable") } + + test("Aggregation attribute names can't contain special chars \" ,;{}()\\n\\t=\"") { + val tempDir = Utils.createTempDir() + val filePath = new File(tempDir, "testParquet").getCanonicalPath + val filePath2 = new File(tempDir, "testParquet2").getCanonicalPath + + val df = Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str") + val df2 = df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").max("y.int") + intercept[RuntimeException](df2.saveAsParquetFile(filePath)) + + val df3 = df2.toDF("str", "max_int") + df3.saveAsParquetFile(filePath2) + val df4 = parquetFile(filePath2) + checkAnswer(df4, Row("1", 1) :: Row("2", 2) :: Row("3", 3) :: Nil) + assert(df4.columns === Array("str", "max_int")) + } } class ParquetDataSourceOffSourceSuite extends ParquetSourceSuiteBase { @@ -761,7 +886,11 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll test(s"SPARK-5775 read struct from $table") { checkAnswer( - sql(s"SELECT p, structField.intStructField, structField.stringStructField FROM $table WHERE p = 1"), + sql( + s""" + |SELECT p, structField.intStructField, structField.stringStructField + |FROM $table WHERE p = 1 + """.stripMargin), (1 to 10).map(i => Row(1, i, f"${i}_string"))) } diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala index 0ed93c2c5b1fa..33e96eaabfbf6 100644 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ 
b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala @@ -41,7 +41,7 @@ import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory} import org.apache.hadoop.io.{NullWritable, Writable} import org.apache.hadoop.mapred.InputFormat -import org.apache.spark.sql.types.{Decimal, DecimalType} +import org.apache.spark.sql.types.{UTF8String, Decimal, DecimalType} private[hive] case class HiveFunctionWrapper(functionClassName: String) extends java.io.Serializable { @@ -135,7 +135,7 @@ private[hive] object HiveShim { PrimitiveCategory.VOID, null) def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String]) + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) def getIntWritable(value: Any): hadoopIo.IntWritable = if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index 7577309900209..d331c210e8939 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -17,37 +17,35 @@ package org.apache.spark.sql.hive -import java.util -import java.util.{ArrayList => JArrayList} -import java.util.Properties import java.rmi.server.UID +import java.util.{Properties, ArrayList => JArrayList} import scala.collection.JavaConversions._ import scala.language.implicitConversions +import com.esotericsoftware.kryo.Kryo import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.{NullWritable, Writable} -import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.hive.common.StatsSetupConst -import org.apache.hadoop.hive.common.`type`.{HiveDecimal} +import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Context -import org.apache.hadoop.hive.ql.metadata.{Table, Hive, Partition} +import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} +import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, DecimalTypeInfo, TypeInfoFactory} -import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, PrimitiveObjectInspector, ObjectInspector} -import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils} -import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable +import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorConverters, PrimitiveObjectInspector} +import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfo, TypeInfoFactory} +import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo} +import org.apache.hadoop.io.{NullWritable, Writable} +import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.{io 
=> hadoopIo} import org.apache.spark.Logging -import org.apache.spark.sql.types.{Decimal, DecimalType} - +import org.apache.spark.sql.types.{Decimal, DecimalType, UTF8String} /** * This class provides the UDF creation and also the UDF instance serialization and @@ -63,18 +61,14 @@ private[hive] case class HiveFunctionWrapper(var functionClassName: String) // for Serialization def this() = this(null) - import java.io.{OutputStream, InputStream} - import com.esotericsoftware.kryo.Kryo import org.apache.spark.util.Utils._ - import org.apache.hadoop.hive.ql.exec.Utilities - import org.apache.hadoop.hive.ql.exec.UDF @transient private val methodDeSerialize = { val method = classOf[Utilities].getDeclaredMethod( "deserializeObjectByKryo", classOf[Kryo], - classOf[InputStream], + classOf[java.io.InputStream], classOf[Class[_]]) method.setAccessible(true) @@ -87,7 +81,7 @@ private[hive] case class HiveFunctionWrapper(var functionClassName: String) "serializeObjectByKryo", classOf[Kryo], classOf[Object], - classOf[OutputStream]) + classOf[java.io.OutputStream]) method.setAccessible(true) method @@ -224,7 +218,7 @@ private[hive] object HiveShim { TypeInfoFactory.voidTypeInfo, null) def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String]) + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) def getIntWritable(value: Any): hadoopIo.IntWritable = if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) diff --git a/streaming/pom.xml b/streaming/pom.xml index 23a8358d45c2a..5ca55a4f680bb 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -97,34 +97,6 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes - - - org.apache.maven.plugins - maven-jar-plugin - - - - test-jar - - - - test-jar-on-test-compile - test-compile - - test-jar - - - - - org.apache.maven.plugins maven-shade-plugin diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index f73b463d07779..0a50485118588 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.{SparkException, SparkConf, Logging} import org.apache.spark.io.CompressionCodec -import org.apache.spark.util.MetadataCleaner +import org.apache.spark.util.{MetadataCleaner, Utils} import org.apache.spark.streaming.scheduler.JobGenerator @@ -139,8 +139,11 @@ class CheckpointWriter( // Write checkpoint to temp file fs.delete(tempFile, true) // just in case it exists val fos = fs.create(tempFile) - fos.write(bytes) - fos.close() + Utils.tryWithSafeFinally { + fos.write(bytes) + } { + fos.close() + } // If the checkpoint file exists, back it up // If the backup exists as well, just delete it, otherwise rename will fail @@ -187,9 +190,11 @@ class CheckpointWriter( val bos = new ByteArrayOutputStream() val zos = compressionCodec.compressedOutputStream(bos) val oos = new ObjectOutputStream(zos) - oos.writeObject(checkpoint) - oos.close() - bos.close() + Utils.tryWithSafeFinally { + oos.writeObject(checkpoint) + } { + oos.close() + } try { executor.execute(new CheckpointWriteHandler( checkpoint.checkpointTime, bos.toByteArray, clearCheckpointDataLater)) @@ -234,7 +239,7 @@ object CheckpointReader extends Logging { val checkpointPath 
= new Path(checkpointDir) // TODO(rxin): Why is this a def?! - def fs = checkpointPath.getFileSystem(hadoopConf) + def fs: FileSystem = checkpointPath.getFileSystem(hadoopConf) // Try to find the checkpoint files val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs).reverse @@ -248,18 +253,24 @@ object CheckpointReader extends Logging { checkpointFiles.foreach(file => { logInfo("Attempting to load checkpoint from file " + file) try { - val fis = fs.open(file) - // ObjectInputStream uses the last defined user-defined class loader in the stack - // to find classes, which maybe the wrong class loader. Hence, a inherited version - // of ObjectInputStream is used to explicitly use the current thread's default class - // loader to find and load classes. This is a well know Java issue and has popped up - // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) - val zis = compressionCodec.compressedInputStream(fis) - val ois = new ObjectInputStreamWithLoader(zis, - Thread.currentThread().getContextClassLoader) - val cp = ois.readObject.asInstanceOf[Checkpoint] - ois.close() - fs.close() + var ois: ObjectInputStreamWithLoader = null + var cp: Checkpoint = null + Utils.tryWithSafeFinally { + val fis = fs.open(file) + // ObjectInputStream uses the last defined user-defined class loader in the stack + // to find classes, which maybe the wrong class loader. Hence, a inherited version + // of ObjectInputStream is used to explicitly use the current thread's default class + // loader to find and load classes. This is a well know Java issue and has popped up + // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) + val zis = compressionCodec.compressedInputStream(fis) + ois = new ObjectInputStreamWithLoader(zis, + Thread.currentThread().getContextClassLoader) + cp = ois.readObject.asInstanceOf[Checkpoint] + } { + if (ois != null) { + ois.close() + } + } cp.validate() logInfo("Checkpoint successfully loaded from file " + file) logInfo("Checkpoint was generated at time " + cp.checkpointTime) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index 73030e15c5661..808dcc174cf9a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -169,7 +169,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def flatMap[U](f: FlatMapFunction[T, U]): JavaDStream[U] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala + def fn: (T) => Iterable[U] = (x: T) => f.call(x).asScala new JavaDStream(dstream.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -179,7 +179,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairDStream[K2, V2] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala + def fn: (T) => Iterable[(K2, V2)] = (x: T) => f.call(x).asScala def cm: ClassTag[(K2, V2)] = fakeClassTag new JavaPairDStream(dstream.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -190,7 +190,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * of the RDD. 
*/ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[U] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } new JavaDStream(dstream.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -201,7 +203,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]) : JavaPairDStream[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[(K2, V2)] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } new JavaPairDStream(dstream.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index f94f2d0e8bd31..93baad19e3ee1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -526,7 +526,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairDStream[K, U] = { import scala.collection.JavaConverters._ - def fn = (x: V) => f.apply(x).asScala + def fn: (V) => Iterable[U] = (x: V) => f.apply(x).asScala implicit val cm: ClassTag[U] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]] dstream.flatMapValues(fn) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index e3db01c1e12c6..4095a7cc84946 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -192,7 +192,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { converter: JFunction[InputStream, java.lang.Iterable[T]], storageLevel: StorageLevel) : JavaReceiverInputDStream[T] = { - def fn = (x: InputStream) => converter.call(x).toIterator + def fn: (InputStream) => Iterator[T] = (x: InputStream) => converter.call(x).toIterator implicit val cmt: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] ssc.socketStream(hostname, port, fn, storageLevel) @@ -313,7 +313,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cmk: ClassTag[K] = ClassTag(kClass) implicit val cmv: ClassTag[V] = ClassTag(vClass) implicit val cmf: ClassTag[F] = ClassTag(fClass) - def fn = (x: Path) => filter.call(x).booleanValue() + def fn: (Path) => Boolean = (x: Path) => filter.call(x).booleanValue() ssc.fileStream[K, V, F](directory, fn, newFilesOnly) } @@ -344,7 +344,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cmk: ClassTag[K] = ClassTag(kClass) implicit val cmv: ClassTag[V] = ClassTag(vClass) implicit val cmf: ClassTag[F] = ClassTag(fClass) - def fn = (x: Path) => filter.call(x).booleanValue() + def fn: (Path) => Boolean = (x: Path) => filter.call(x).booleanValue() ssc.fileStream[K, V, F](directory, fn, newFilesOnly, conf) } @@ -625,7 +625,7 @@ class JavaStreamingContext(val ssc: 
StreamingContext) extends Closeable { * Stop the execution of the streams. * @param stopSparkContext Stop the associated SparkContext or not */ - def stop(stopSparkContext: Boolean) = ssc.stop(stopSparkContext) + def stop(stopSparkContext: Boolean): Unit = ssc.stop(stopSparkContext) /** * Stop the execution of the streams. @@ -633,7 +633,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * @param stopGracefully Stop gracefully by waiting for the processing of all * received data to be completed */ - def stop(stopSparkContext: Boolean, stopGracefully: Boolean) = { + def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = { ssc.stop(stopSparkContext, stopGracefully) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 795c5aa6d585b..24f99a2b929f5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -839,7 +839,7 @@ object DStream { /** Filtering function that excludes non-user classes for a streaming application */ def streamingExclustionFunction(className: String): Boolean = { - def doesMatch(r: Regex) = r.findFirstIn(className).isDefined + def doesMatch(r: Regex): Boolean = r.findFirstIn(className).isDefined val isSparkClass = doesMatch(SPARK_CLASS_REGEX) val isSparkExampleClass = doesMatch(SPARK_EXAMPLES_CLASS_REGEX) val isSparkStreamingTestClass = doesMatch(SPARK_STREAMING_TESTCLASS_REGEX) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 42514d8b47dcf..f4963a78e1d18 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.util.RecurringTimer -import org.apache.spark.util.SystemClock +import org.apache.spark.util.{SystemClock, Utils} /** Listener object for BlockGenerator events */ private[streaming] trait BlockGeneratorListener { @@ -79,9 +79,9 @@ private[streaming] class BlockGenerator( private case class Block(id: StreamBlockId, buffer: ArrayBuffer[Any]) private val clock = new SystemClock() - private val blockInterval = conf.getLong("spark.streaming.blockInterval", 200) + private val blockIntervalMs = conf.getTimeAsMs("spark.streaming.blockInterval", "200ms") private val blockIntervalTimer = - new RecurringTimer(clock, blockInterval, updateCurrentBuffer, "BlockGenerator") + new RecurringTimer(clock, blockIntervalMs, updateCurrentBuffer, "BlockGenerator") private val blockQueueSize = conf.getInt("spark.streaming.blockQueueSize", 10) private val blocksForPushing = new ArrayBlockingQueue[Block](blockQueueSize) private val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } @@ -132,7 +132,7 @@ private[streaming] class BlockGenerator( val newBlockBuffer = currentBuffer currentBuffer = new ArrayBuffer[Any] if (newBlockBuffer.size > 0) { - val blockId = StreamBlockId(receiverId, time - blockInterval) + val blockId = StreamBlockId(receiverId, time - blockIntervalMs) val newBlock = new Block(blockId, newBlockBuffer) listener.onGenerateBlock(blockId) 
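The switch above from `conf.getLong(...)` to `conf.getTimeAsMs("spark.streaming.blockInterval", "200ms")` means the block interval can now be written with an explicit unit suffix. A small sketch of setting and reading the property, assuming a Spark build that has `SparkConf.getTimeAsMs` (as the patched code does); the values are just examples.

```scala
import org.apache.spark.SparkConf

object BlockIntervalConfSketch {
  def main(args: Array[String]): Unit = {
    // `new SparkConf(false)` skips loading system properties; fine for a demo.
    val conf = new SparkConf(false)
      .set("spark.streaming.blockInterval", "250ms") // unit suffix now allowed

    // Falls back to "200ms" when the key is unset, matching BlockGenerator.
    val blockIntervalMs = conf.getTimeAsMs("spark.streaming.blockInterval", "200ms")
    println(s"blocks will be cut every $blockIntervalMs ms") // prints 250
  }
}
```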
blocksForPushing.put(newBlock) // put is blocking when queue is full diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 4946806d2ee95..58e56638a2dca 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -24,7 +24,7 @@ import akka.actor.{ActorRef, Props, Actor} import org.apache.spark.{SparkEnv, Logging} import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time} import org.apache.spark.streaming.util.RecurringTimer -import org.apache.spark.util.{Clock, ManualClock} +import org.apache.spark.util.{Clock, ManualClock, Utils} /** Event classes for JobGenerator */ private[scheduler] sealed trait JobGeneratorEvent @@ -104,17 +104,15 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { if (processReceivedData) { logInfo("Stopping JobGenerator gracefully") val timeWhenStopStarted = System.currentTimeMillis() - val stopTimeout = conf.getLong( - "spark.streaming.gracefulStopTimeout", - 10 * ssc.graph.batchDuration.milliseconds - ) + val stopTimeoutMs = conf.getTimeAsMs( + "spark.streaming.gracefulStopTimeout", s"${10 * ssc.graph.batchDuration.milliseconds}ms") val pollTime = 100 // To prevent graceful stop to get stuck permanently def hasTimedOut: Boolean = { - val timedOut = (System.currentTimeMillis() - timeWhenStopStarted) > stopTimeout + val timedOut = (System.currentTimeMillis() - timeWhenStopStarted) > stopTimeoutMs if (timedOut) { - logWarning("Timed out while stopping the job generator (timeout = " + stopTimeout + ")") + logWarning("Timed out while stopping the job generator (timeout = " + stopTimeoutMs + ")") } timedOut } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index d6a93acbe711b..95f1857b4c377 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -105,6 +105,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { if (jobSet.jobs.isEmpty) { logInfo("No jobs added for time " + jobSet.time) } else { + listenerBus.post(StreamingListenerBatchSubmitted(jobSet.toBatchInfo)) jobSets.put(jobSet.time, jobSet) jobSet.jobs.foreach(job => jobExecutor.execute(new JobHandler(job))) logInfo("Added jobs for time " + jobSet.time) @@ -134,10 +135,13 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { private def handleJobStart(job: Job) { val jobSet = jobSets.get(job.time) - if (!jobSet.hasStarted) { + val isFirstJobOfJobSet = !jobSet.hasStarted + jobSet.handleJobStart(job) + if (isFirstJobOfJobSet) { + // "StreamingListenerBatchStarted" should be posted after calling "handleJobStart" to get the + // correct "jobSet.processingStartTime". 
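The reordering in `JobScheduler` matters to anyone consuming these events: `StreamingListenerBatchSubmitted` is now posted when the job set is queued, and `StreamingListenerBatchStarted` only after `handleJobStart` has stamped the processing start time. As a hedged sketch (assuming the standard `StreamingListener` trait and registration via `ssc.addStreamingListener`; the asserts are only to illustrate the ordering), a listener relying on that contract could look like this:

```scala
import org.apache.spark.streaming.scheduler.{
  StreamingListener, StreamingListenerBatchStarted, StreamingListenerBatchSubmitted}

// Register with: ssc.addStreamingListener(new BatchTimingListener)
class BatchTimingListener extends StreamingListener {

  override def onBatchSubmitted(submitted: StreamingListenerBatchSubmitted): Unit = {
    // Posted as soon as the job set is queued; processing has not begun yet.
    assert(submitted.batchInfo.processingStartTime.isEmpty,
      "a submitted batch should not have a processing start time yet")
  }

  override def onBatchStarted(started: StreamingListenerBatchStarted): Unit = {
    // Posted after handleJobStart, so the start time is already populated.
    assert(started.batchInfo.processingStartTime.isDefined,
      "a started batch should carry its processing start time")
  }
}
```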
listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo)) } - jobSet.handleJobStart(job) logInfo("Starting job " + job.id + " from job set of time " + jobSet.time) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala new file mode 100644 index 0000000000000..df1c0a10704c3 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.ui + +import scala.xml.Node + +import org.apache.spark.streaming.scheduler.BatchInfo +import org.apache.spark.ui.UIUtils + +private[ui] abstract class BatchTableBase(tableId: String) { + + protected def columns: Seq[Node] = { + Batch Time + Input Size + Scheduling Delay + Processing Time + } + + protected def baseRow(batch: BatchInfo): Seq[Node] = { + val batchTime = batch.batchTime.milliseconds + val formattedBatchTime = UIUtils.formatDate(batch.batchTime.milliseconds) + val eventCount = batch.receivedBlockInfo.values.map { + receivers => receivers.map(_.numRecords).sum + }.sum + val schedulingDelay = batch.schedulingDelay + val formattedSchedulingDelay = schedulingDelay.map(UIUtils.formatDuration).getOrElse("-") + val processingTime = batch.processingDelay + val formattedProcessingTime = processingTime.map(UIUtils.formatDuration).getOrElse("-") + + {formattedBatchTime} + {eventCount.toString} events + + {formattedSchedulingDelay} + + + {formattedProcessingTime} + + } + + private def batchTable: Seq[Node] = { + + + {columns} + + + {renderRows} + +
    + } + + def toNodeSeq: Seq[Node] = { + batchTable + } + + /** + * Return HTML for all rows of this table. + */ + protected def renderRows: Seq[Node] +} + +private[ui] class ActiveBatchTable(runningBatches: Seq[BatchInfo], waitingBatches: Seq[BatchInfo]) + extends BatchTableBase("active-batches-table") { + + override protected def columns: Seq[Node] = super.columns ++ Status + + override protected def renderRows: Seq[Node] = { + // The "batchTime"s of "waitingBatches" must be greater than "runningBatches"'s, so display + // waiting batches before running batches + waitingBatches.flatMap(batch => {waitingBatchRow(batch)}) ++ + runningBatches.flatMap(batch => {runningBatchRow(batch)}) + } + + private def runningBatchRow(batch: BatchInfo): Seq[Node] = { + baseRow(batch) ++ processing + } + + private def waitingBatchRow(batch: BatchInfo): Seq[Node] = { + baseRow(batch) ++ queued + } +} + +private[ui] class CompletedBatchTable(batches: Seq[BatchInfo]) + extends BatchTableBase("completed-batches-table") { + + override protected def columns: Seq[Node] = super.columns ++ Total Delay + + override protected def renderRows: Seq[Node] = { + batches.flatMap(batch => {completedBatchRow(batch)}) + } + + private def completedBatchRow(batch: BatchInfo): Seq[Node] = { + val totalDelay = batch.totalDelay + val formattedTotalDelay = totalDelay.map(UIUtils.formatDuration).getOrElse("-") + baseRow(batch) ++ + + {formattedTotalDelay} + + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index e4bd067cacb77..be1e8686cf9fa 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -33,7 +33,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) private val waitingBatchInfos = new HashMap[Time, BatchInfo] private val runningBatchInfos = new HashMap[Time, BatchInfo] - private val completedaBatchInfos = new Queue[BatchInfo] + private val completedBatchInfos = new Queue[BatchInfo] private val batchInfoLimit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 100) private var totalCompletedBatches = 0L private var totalReceivedRecords = 0L @@ -62,7 +62,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = { synchronized { - runningBatchInfos(batchSubmitted.batchInfo.batchTime) = batchSubmitted.batchInfo + waitingBatchInfos(batchSubmitted.batchInfo.batchTime) = batchSubmitted.batchInfo } } @@ -79,8 +79,8 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) synchronized { waitingBatchInfos.remove(batchCompleted.batchInfo.batchTime) runningBatchInfos.remove(batchCompleted.batchInfo.batchTime) - completedaBatchInfos.enqueue(batchCompleted.batchInfo) - if (completedaBatchInfos.size > batchInfoLimit) completedaBatchInfos.dequeue() + completedBatchInfos.enqueue(batchCompleted.batchInfo) + if (completedBatchInfos.size > batchInfoLimit) completedBatchInfos.dequeue() totalCompletedBatches += 1L batchCompleted.batchInfo.receivedBlockInfo.foreach { case (_, infos) => @@ -118,7 +118,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } def retainedCompletedBatches: Seq[BatchInfo] = synchronized { - completedaBatchInfos.toSeq + 
completedBatchInfos.toSeq } def processingDelayDistribution: Option[Distribution] = synchronized { @@ -149,7 +149,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) }.toMap } - def lastReceivedBatchRecords: Map[Int, Long] = { + def lastReceivedBatchRecords: Map[Int, Long] = synchronized { val lastReceivedBlockInfoOption = lastReceivedBatch.map(_.receivedBlockInfo) lastReceivedBlockInfoOption.map { lastReceivedBlockInfo => (0 until numReceivers).map { receiverId => @@ -160,24 +160,24 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } } - def receiverInfo(receiverId: Int): Option[ReceiverInfo] = { + def receiverInfo(receiverId: Int): Option[ReceiverInfo] = synchronized { receiverInfos.get(receiverId) } - def lastCompletedBatch: Option[BatchInfo] = { - completedaBatchInfos.sortBy(_.batchTime)(Time.ordering).lastOption + def lastCompletedBatch: Option[BatchInfo] = synchronized { + completedBatchInfos.sortBy(_.batchTime)(Time.ordering).lastOption } - def lastReceivedBatch: Option[BatchInfo] = { + def lastReceivedBatch: Option[BatchInfo] = synchronized { retainedBatches.lastOption } - private def retainedBatches: Seq[BatchInfo] = synchronized { + private def retainedBatches: Seq[BatchInfo] = { (waitingBatchInfos.values.toSeq ++ - runningBatchInfos.values.toSeq ++ completedaBatchInfos).sortBy(_.batchTime)(Time.ordering) + runningBatchInfos.values.toSeq ++ completedBatchInfos).sortBy(_.batchTime)(Time.ordering) } private def extractDistribution(getMetric: BatchInfo => Option[Long]): Option[Distribution] = { - Distribution(completedaBatchInfos.flatMap(getMetric(_)).map(_.toDouble)) + Distribution(completedBatchInfos.flatMap(getMetric(_)).map(_.toDouble)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index bfe8086fcf8fe..07fa285642eec 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -37,20 +37,23 @@ private[ui] class StreamingPage(parent: StreamingTab) /** Render the page */ def render(request: HttpServletRequest): Seq[Node] = { - val content = + val content = listener.synchronized { generateBasicStats() ++

    ++

    Statistics over last {listener.retainedCompletedBatches.size} processed batches

    ++ generateReceiverStats() ++ - generateBatchStatsTable() + generateBatchStatsTable() ++ + generateBatchListTables() + } UIUtils.headerSparkPage("Streaming", content, parent, Some(5000)) } /** Generate basic stats of the streaming program */ private def generateBasicStats(): Seq[Node] = { val timeSinceStart = System.currentTimeMillis() - startTime + // scalastyle:off
      <li>
-        <strong>Started at: </strong> {startTime.toString}
+        <strong>Started at: </strong> {UIUtils.formatDate(startTime)}
      </li>
      <li>
        <strong>Time since start: </strong>{formatDurationVerbose(timeSinceStart)}
@@ -62,18 +65,19 @@ private[ui] class StreamingPage(parent: StreamingTab)
        <strong>Batch interval: </strong>{formatDurationVerbose(listener.batchDuration)}
      </li>
      <li>
-        <strong>Processed batches: </strong>{listener.numTotalCompletedBatches}
+        <strong>Completed batches: </strong>{listener.numTotalCompletedBatches}
      </li>
      <li>
-        <strong>Waiting batches: </strong>{listener.numUnprocessedBatches}
+        <strong>Active batches: </strong>{listener.numUnprocessedBatches}
      </li>
      <li>
-        <strong>Received records: </strong>{listener.numTotalReceivedRecords}
+        <strong>Received events: </strong>{listener.numTotalReceivedRecords}
      </li>
      <li>
-        <strong>Processed records: </strong>{listener.numTotalProcessedRecords}
+        <strong>Processed events: </strong>{listener.numTotalProcessedRecords}
      </li>
    + // scalastyle:on } /** Generate stats of data received by the receivers in the streaming program */ @@ -85,10 +89,10 @@ private[ui] class StreamingPage(parent: StreamingTab) "Receiver", "Status", "Location", - "Records in last batch\n[" + formatDate(Calendar.getInstance().getTime()) + "]", - "Minimum rate\n[records/sec]", - "Median rate\n[records/sec]", - "Maximum rate\n[records/sec]", + "Events in last batch\n[" + formatDate(Calendar.getInstance().getTime()) + "]", + "Minimum rate\n[events/sec]", + "Median rate\n[events/sec]", + "Maximum rate\n[events/sec]", "Last Error" ) val dataRows = (0 until listener.numReceivers).map { receiverId => @@ -189,5 +193,26 @@ private[ui] class StreamingPage(parent: StreamingTab) } UIUtils.listingTable(headers, generateDataRow, data, fixedWidth = true) } + + private def generateBatchListTables(): Seq[Node] = { + val runningBatches = listener.runningBatches.sortBy(_.batchTime.milliseconds).reverse + val waitingBatches = listener.waitingBatches.sortBy(_.batchTime.milliseconds).reverse + val completedBatches = listener.retainedCompletedBatches. + sortBy(_.batchTime.milliseconds).reverse + + val activeBatchesContent = { +

    Active Batches ({runningBatches.size + waitingBatches.size})

    ++ + new ActiveBatchTable(runningBatches, waitingBatches).toNodeSeq + } + + val completedBatchesContent = { +

    + Completed Batches (last {completedBatches.size} out of {listener.numTotalCompletedBatches}) +

    ++ + new CompletedBatchTable(completedBatches).toNodeSeq + } + + activeBatchesContent ++ completedBatchesContent + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala index a7850812bd612..ca2f319f174a2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala @@ -72,7 +72,8 @@ object RawTextSender extends Logging { } catch { case e: IOException => logError("Client disconnected") - socket.close() + } finally { + socket.close() } } } diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/streaming/src/test/resources/log4j.properties +++ b/streaming/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index cf191715d29d6..87bc20f79c3cd 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -171,7 +171,9 @@ class BasicOperationsSuite extends TestSuiteBase { test("flatMapValues") { testOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), - (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)), + (s: DStream[String]) => { + s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)) + }, Seq( Seq(("a", 2), ("a", 12), ("b", 1), ("b", 11)), Seq(("", 2), ("", 12)), Seq() ), true ) @@ -474,7 +476,7 @@ class BasicOperationsSuite extends TestSuiteBase { stream.foreachRDD(_ => {}) // Dummy output stream ssc.start() Thread.sleep(2000) - def getInputFromSlice(fromMillis: Long, toMillis: Long) = { + def getInputFromSlice(fromMillis: Long, toMillis: Long): Set[Int] = { stream.slice(new Time(fromMillis), new Time(toMillis)).flatMap(_.collect()).toSet } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 91a2b2bba461d..54c30440a6e8d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -43,7 +43,7 @@ class CheckpointSuite extends TestSuiteBase { var ssc: StreamingContext = null - override def batchDuration = Milliseconds(500) + override def batchDuration: Duration = Milliseconds(500) override def beforeFunction() { super.beforeFunction() @@ -72,7 +72,7 @@ class CheckpointSuite extends TestSuiteBase { val input = (1 to 10).map(_ => Seq("a")).toSeq val operation = (st: DStream[String]) => { val updateFunc = (values: Seq[Int], state: Option[Int]) => { - Some((values.sum + state.getOrElse(0))) + Some(values.sum + state.getOrElse(0)) } st.map(x => (x, 1)) .updateStateByKey(updateFunc) @@ -199,7 +199,12 @@ class CheckpointSuite extends TestSuiteBase { testCheckpointedOperation( Seq( Seq("a", "a", 
"b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ), (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), - Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), Seq() ), 3 ) } @@ -212,7 +217,8 @@ class CheckpointSuite extends TestSuiteBase { val n = 10 val w = 4 val input = (1 to n).map(_ => Seq("a")).toSeq - val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) + val output = Seq( + Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) val operation = (st: DStream[String]) => { st.map(x => (x, 1)) .reduceByKeyAndWindow(_ + _, _ - _, batchDuration * w, batchDuration) @@ -236,7 +242,13 @@ class CheckpointSuite extends TestSuiteBase { classOf[TextOutputFormat[Text, IntWritable]]) output }, - Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq()), 3 ) } finally { @@ -259,7 +271,13 @@ class CheckpointSuite extends TestSuiteBase { classOf[NewTextOutputFormat[Text, IntWritable]]) output }, - Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq()), 3 ) } finally { @@ -298,7 +316,13 @@ class CheckpointSuite extends TestSuiteBase { output } }, - Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq()), 3 ) } finally { @@ -533,7 +557,8 @@ class CheckpointSuite extends TestSuiteBase { * Advances the manual clock on the streaming scheduler by given number of batches. * It also waits for the expected amount of time for each batch. 
*/ - def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = { + def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = + { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] logInfo("Manual clock before advancing = " + clock.getTimeMillis()) for (i <- 1 to numBatches.toInt) { @@ -543,7 +568,7 @@ class CheckpointSuite extends TestSuiteBase { logInfo("Manual clock after advancing = " + clock.getTimeMillis()) Thread.sleep(batchDuration.milliseconds) - val outputStream = ssc.graph.getOutputStreams.filter { dstream => + val outputStream = ssc.graph.getOutputStreams().filter { dstream => dstream.isInstanceOf[TestOutputStreamWithPartitions[V]] }.head.asInstanceOf[TestOutputStreamWithPartitions[V]] outputStream.output.map(_.flatten) @@ -552,4 +577,4 @@ class CheckpointSuite extends TestSuiteBase { private object CheckpointSuite extends Serializable { var batchThreeShouldBlockIndefinitely: Boolean = true -} \ No newline at end of file +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala index 26435d8515815..0c4c06534a693 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -29,9 +29,9 @@ class FailureSuite extends TestSuiteBase with Logging { val directory = Utils.createTempDir() val numBatches = 30 - override def batchDuration = Milliseconds(1000) + override def batchDuration: Duration = Milliseconds(1000) - override def useManualClock = false + override def useManualClock: Boolean = false override def afterFunction() { Utils.deleteRecursively(directory) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 7ed6320a3d0bc..e6ac4975c5e68 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -52,7 +52,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] val outputStream = new TestOutputStream(networkStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) + def output: ArrayBuffer[String] = outputBuffer.flatMap(x => x) outputStream.register() ssc.start() @@ -164,7 +164,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val countStream = networkStream.count val outputBuffer = new ArrayBuffer[Seq[Long]] with SynchronizedBuffer[Seq[Long]] val outputStream = new TestOutputStream(countStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) + def output: ArrayBuffer[Long] = outputBuffer.flatMap(x => x) outputStream.register() ssc.start() @@ -196,7 +196,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val queueStream = ssc.queueStream(queue, oneAtATime = true) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] val outputStream = new TestOutputStream(queueStream, outputBuffer) - def output = outputBuffer.filter(_.size > 0) + def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) outputStream.register() ssc.start() @@ -204,7 +204,7 @@ class InputStreamsSuite extends TestSuiteBase with 
BeforeAndAfter { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq("1", "2", "3", "4", "5") val expectedOutput = input.map(Seq(_)) - //Thread.sleep(1000) + val inputIterator = input.toIterator for (i <- 0 until input.size) { // Enqueue more than 1 item per tick but they should dequeue one at a time @@ -239,7 +239,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val queueStream = ssc.queueStream(queue, oneAtATime = false) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] val outputStream = new TestOutputStream(queueStream, outputBuffer) - def output = outputBuffer.filter(_.size > 0) + def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) outputStream.register() ssc.start() @@ -352,7 +352,8 @@ class TestServer(portToBind: Int = 0) extends Logging { logInfo("New connection") try { clientSocket.setTcpNoDelay(true) - val outputStream = new BufferedWriter(new OutputStreamWriter(clientSocket.getOutputStream)) + val outputStream = new BufferedWriter( + new OutputStreamWriter(clientSocket.getOutputStream)) while(clientSocket.isConnected) { val msg = queue.poll(100, TimeUnit.MILLISECONDS) @@ -384,7 +385,7 @@ class TestServer(portToBind: Int = 0) extends Logging { def stop() { servingThread.interrupt() } - def port = serverSocket.getLocalPort + def port: Int = serverSocket.getLocalPort } /** This is a receiver to test multiple threads inserting data using block generator */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 18a477f92094d..c090eaec2928d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -24,20 +24,20 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.postfixOps -import akka.actor.{ActorSystem, Props} import org.apache.hadoop.conf.Configuration import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage._ import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.util._ -import org.apache.spark.util.{AkkaUtils, ManualClock, Utils} +import org.apache.spark.util.{ManualClock, Utils} import WriteAheadLogBasedBlockHandler._ import WriteAheadLogSuite._ @@ -54,22 +54,19 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche val manualClock = new ManualClock val blockManagerSize = 10000000 - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var blockManagerMaster: BlockManagerMaster = null var blockManager: BlockManager = null var tempDirectory: File = null before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem - conf.set("spark.driver.port", boundPort.toString) + rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) + conf.set("spark.driver.port", rpcEnv.address.port.toString) - blockManagerMaster = new BlockManagerMaster( - 
actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf, true) + blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) - blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer, + blockManager = new BlockManager("bm", rpcEnv, blockManagerMaster, serializer, blockManagerSize, conf, mapOutputTracker, shuffleManager, new NioBlockTransferService(conf, securityMgr), securityMgr, 0) blockManager.initialize("app-id") @@ -87,9 +84,9 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche blockManagerMaster.stop() blockManagerMaster = null } - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null Utils.deleteRecursively(tempDirectory) } @@ -99,7 +96,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche testBlockStoring(handler) { case (data, blockIds, storeResults) => // Verify the data in block manager is correct val storedData = blockIds.flatMap { blockId => - blockManager.getLocal(blockId).map { _.data.map {_.toString}.toList }.getOrElse(List.empty) + blockManager.getLocal(blockId).map(_.data.map(_.toString).toList).getOrElse(List.empty) }.toList storedData shouldEqual data @@ -123,7 +120,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche testBlockStoring(handler) { case (data, blockIds, storeResults) => // Verify the data in block manager is correct val storedData = blockIds.flatMap { blockId => - blockManager.getLocal(blockId).map { _.data.map {_.toString}.toList }.getOrElse(List.empty) + blockManager.getLocal(blockId).map(_.data.map(_.toString).toList).getOrElse(List.empty) }.toList storedData shouldEqual data diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index 42fad769f0c1a..b63b37d9f9cef 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -228,7 +228,8 @@ class ReceivedBlockTrackerSuite * Get all the data written in the given write ahead log files. By default, it will read all * files in the test log directory. 
*/ - def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles): Seq[ReceivedBlockTrackerLogEvent] = { + def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles) + : Seq[ReceivedBlockTrackerLogEvent] = { logFiles.flatMap { file => new WriteAheadLogReader(file, hadoopConf).toSeq }.map { byteBuffer => @@ -244,7 +245,8 @@ class ReceivedBlockTrackerSuite } /** Create batch allocation object from the given info */ - def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]): BatchAllocationEvent = { + def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]) + : BatchAllocationEvent = { BatchAllocationEvent(time, AllocatedBlocks(Map((streamId -> blockInfos)))) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index aa20ad0b5374e..91261a9db7360 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -131,11 +131,11 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { test("block generator") { val blockGeneratorListener = new FakeBlockGeneratorListener - val blockInterval = 200 - val conf = new SparkConf().set("spark.streaming.blockInterval", blockInterval.toString) + val blockIntervalMs = 200 + val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms") val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf) val expectedBlocks = 5 - val waitTime = expectedBlocks * blockInterval + (blockInterval / 2) + val waitTime = expectedBlocks * blockIntervalMs + (blockIntervalMs / 2) val generatedData = new ArrayBuffer[Int] // Generate blocks @@ -157,15 +157,15 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { test("block generator throttling") { val blockGeneratorListener = new FakeBlockGeneratorListener - val blockInterval = 100 + val blockIntervalMs = 100 val maxRate = 100 - val conf = new SparkConf().set("spark.streaming.blockInterval", blockInterval.toString). + val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms"). 
set("spark.streaming.receiver.maxRate", maxRate.toString) val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf) val expectedBlocks = 20 - val waitTime = expectedBlocks * blockInterval + val waitTime = expectedBlocks * blockIntervalMs val expectedMessages = maxRate * waitTime / 1000 - val expectedMessagesPerBlock = maxRate * blockInterval / 1000 + val expectedMessagesPerBlock = maxRate * blockIntervalMs / 1000 val generatedData = new ArrayBuffer[Int] // Generate blocks @@ -308,7 +308,7 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { val errors = new ArrayBuffer[Throwable] /** Check if all data structures are clean */ - def isAllEmpty = { + def isAllEmpty: Boolean = { singles.isEmpty && byteBuffers.isEmpty && iterators.isEmpty && arrayBuffers.isEmpty && errors.isEmpty } @@ -320,24 +320,21 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { def pushBytes( bytes: ByteBuffer, optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] - ) { + optionalBlockId: Option[StreamBlockId]) { byteBuffers += bytes } def pushIterator( iterator: Iterator[_], optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] - ) { + optionalBlockId: Option[StreamBlockId]) { iterators += iterator } def pushArrayBuffer( arrayBuffer: ArrayBuffer[_], optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] - ) { + optionalBlockId: Option[StreamBlockId]) { arrayBuffers += arrayBuffer } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 2e5005ef6ff14..58353a5f97c8a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -73,9 +73,9 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("from conf with settings") { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) - myConf.set("spark.cleaner.ttl", "10") + myConf.set("spark.cleaner.ttl", "10s") ssc = new StreamingContext(myConf, batchDuration) - assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === 10) + assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) } test("from existing SparkContext") { @@ -85,24 +85,26 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("from existing SparkContext with settings") { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) - myConf.set("spark.cleaner.ttl", "10") + myConf.set("spark.cleaner.ttl", "10s") ssc = new StreamingContext(myConf, batchDuration) - assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === 10) + assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) } test("from checkpoint") { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) - myConf.set("spark.cleaner.ttl", "10") + myConf.set("spark.cleaner.ttl", "10s") val ssc1 = new StreamingContext(myConf, batchDuration) addInputStream(ssc1).register() ssc1.start() val cp = new Checkpoint(ssc1, Time(1000)) - assert(cp.sparkConfPairs.toMap.getOrElse("spark.cleaner.ttl", "-1") === "10") + assert( + Utils.timeStringAsSeconds(cp.sparkConfPairs + .toMap.getOrElse("spark.cleaner.ttl", "-1")) === 10) ssc1.stop() val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp)) - assert(newCp.createSparkConf().getInt("spark.cleaner.ttl", -1) === 10) + 
assert(newCp.createSparkConf().getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) ssc = new StreamingContext(null, newCp, null) - assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === 10) + assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) } test("start and stop state check") { @@ -176,7 +178,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("stop gracefully") { val conf = new SparkConf().setMaster(master).setAppName(appName) - conf.set("spark.cleaner.ttl", "3600") + conf.set("spark.cleaner.ttl", "3600s") sc = new SparkContext(conf) for (i <- 1 to 4) { logInfo("==================================\n\n\n") @@ -207,13 +209,13 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("stop slow receiver gracefully") { val conf = new SparkConf().setMaster(master).setAppName(appName) - conf.set("spark.streaming.gracefulStopTimeout", "20000") + conf.set("spark.streaming.gracefulStopTimeout", "20000s") sc = new SparkContext(conf) logInfo("==================================\n\n\n") ssc = new StreamingContext(sc, Milliseconds(100)) var runningCount = 0 SlowTestReceiver.receivedAllRecords = false - //Create test receiver that sleeps in onStop() + // Create test receiver that sleeps in onStop() val totalNumRecords = 15 val recordsPerSecond = 1 val input = ssc.receiverStream(new SlowTestReceiver(totalNumRecords, recordsPerSecond)) @@ -370,7 +372,8 @@ object TestReceiver { } /** Custom receiver for testing whether a slow receiver can be shutdown gracefully or not */ -class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int) extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging { +class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int) + extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging { var receivingThreadOption: Option[Thread] = None diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index f52562b0a0f73..7210439509541 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -38,18 +38,46 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { // To make sure that the processing start and end times in collected // information are different for successive batches - override def batchDuration = Milliseconds(100) - override def actuallyWait = true + override def batchDuration: Duration = Milliseconds(100) + override def actuallyWait: Boolean = true test("batch info reporting") { val ssc = setupStreams(input, operation) val collector = new BatchInfoCollector ssc.addStreamingListener(collector) runStreams(ssc, input.size, input.size) - val batchInfos = collector.batchInfos - batchInfos should have size 4 - batchInfos.foreach(info => { + // SPARK-6766: batch info should be submitted + val batchInfosSubmitted = collector.batchInfosSubmitted + batchInfosSubmitted should have size 4 + + batchInfosSubmitted.foreach(info => { + info.schedulingDelay should be (None) + info.processingDelay should be (None) + info.totalDelay should be (None) + }) + + isInIncreasingOrder(batchInfosSubmitted.map(_.submissionTime)) should be (true) + + // SPARK-6766: processingStartTime of batch info should not be None when starting + val batchInfosStarted = collector.batchInfosStarted + batchInfosStarted should have size 4 + + 
batchInfosStarted.foreach(info => { + info.schedulingDelay should not be None + info.schedulingDelay.get should be >= 0L + info.processingDelay should be (None) + info.totalDelay should be (None) + }) + + isInIncreasingOrder(batchInfosStarted.map(_.submissionTime)) should be (true) + isInIncreasingOrder(batchInfosStarted.map(_.processingStartTime.get)) should be (true) + + // test onBatchCompleted + val batchInfosCompleted = collector.batchInfosCompleted + batchInfosCompleted should have size 4 + + batchInfosCompleted.foreach(info => { info.schedulingDelay should not be None info.processingDelay should not be None info.totalDelay should not be None @@ -58,9 +86,9 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { info.totalDelay.get should be >= 0L }) - isInIncreasingOrder(batchInfos.map(_.submissionTime)) should be (true) - isInIncreasingOrder(batchInfos.map(_.processingStartTime.get)) should be (true) - isInIncreasingOrder(batchInfos.map(_.processingEndTime.get)) should be (true) + isInIncreasingOrder(batchInfosCompleted.map(_.submissionTime)) should be (true) + isInIncreasingOrder(batchInfosCompleted.map(_.processingStartTime.get)) should be (true) + isInIncreasingOrder(batchInfosCompleted.map(_.processingEndTime.get)) should be (true) } test("receiver info reporting") { @@ -99,9 +127,20 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { /** Listener that collects information on processed batches */ class BatchInfoCollector extends StreamingListener { - val batchInfos = new ArrayBuffer[BatchInfo] + val batchInfosCompleted = new ArrayBuffer[BatchInfo] + val batchInfosStarted = new ArrayBuffer[BatchInfo] + val batchInfosSubmitted = new ArrayBuffer[BatchInfo] + + override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted) { + batchInfosSubmitted += batchSubmitted.batchInfo + } + + override def onBatchStarted(batchStarted: StreamingListenerBatchStarted) { + batchInfosStarted += batchStarted.batchInfo + } + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { - batchInfos += batchCompleted.batchInfo + batchInfosCompleted += batchCompleted.batchInfo } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 3565d621e8a6c..c3cae8aeb6d15 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -53,8 +53,9 @@ class TestInputStream[T: ClassTag](ssc_ : StreamingContext, input: Seq[Seq[T]], val selectedInput = if (index < input.size) input(index) else Seq[T]() // lets us test cases where RDDs are not created - if (selectedInput == null) + if (selectedInput == null) { return None + } val rdd = ssc.sc.makeRDD(selectedInput, numPartitions) logInfo("Created RDD " + rdd.id + " with " + selectedInput) @@ -104,7 +105,9 @@ class TestOutputStreamWithPartitions[T: ClassTag](parent: DStream[T], output.clear() } - def toTestOutputStream = new TestOutputStream[T](this.parent, this.output.map(_.flatten)) + def toTestOutputStream: TestOutputStream[T] = { + new TestOutputStream[T](this.parent, this.output.map(_.flatten)) + } } /** @@ -148,34 +151,34 @@ class BatchCounter(ssc: StreamingContext) { trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Name of the framework for Spark context - def framework = this.getClass.getSimpleName + def framework: String = this.getClass.getSimpleName // Master 
for Spark context - def master = "local[2]" + def master: String = "local[2]" // Batch duration - def batchDuration = Seconds(1) + def batchDuration: Duration = Seconds(1) // Directory where the checkpoint data will be saved - lazy val checkpointDir = { + lazy val checkpointDir: String = { val dir = Utils.createTempDir() logDebug(s"checkpointDir: $dir") dir.toString } // Number of partitions of the input parallel collections created for testing - def numInputPartitions = 2 + def numInputPartitions: Int = 2 // Maximum time to wait before the test times out - def maxWaitTimeMillis = 10000 + def maxWaitTimeMillis: Int = 10000 // Whether to use manual clock or not - def useManualClock = true + def useManualClock: Boolean = true // Whether to actually wait in real time before changing manual clock - def actuallyWait = false + def actuallyWait: Boolean = false - //// A SparkConf to use in tests. Can be modified before calling setupStreams to configure things. + // A SparkConf to use in tests. Can be modified before calling setupStreams to configure things. val conf = new SparkConf() .setMaster(master) .setAppName(framework) @@ -346,7 +349,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Wait until expected number of output items have been generated val startTime = System.currentTimeMillis() - while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { + while (output.size < numExpectedOutput && + System.currentTimeMillis() - startTime < maxWaitTimeMillis) { logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput) ssc.awaitTerminationOrTimeout(50) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index 87a0395efbf2a..205ddf6dbe9b0 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -32,7 +32,8 @@ import org.apache.spark._ /** * Selenium tests for the Spark Web UI. 
*/ -class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { +class UISeleniumSuite + extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { implicit var webDriver: WebDriver = _ @@ -74,6 +75,17 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before val statisticText = findAll(cssSelector("li strong")).map(_.text).toSeq statisticText should contain("Network receivers:") statisticText should contain("Batch interval:") + + val h4Text = findAll(cssSelector("h4")).map(_.text).toSeq + h4Text should contain("Active Batches (0)") + h4Text should contain("Completed Batches (last 0 out of 0)") + + findAll(cssSelector("""#active-batches-table th""")).map(_.text).toSeq should be { + List("Batch Time", "Input Size", "Scheduling Delay", "Processing Time", "Status") + } + findAll(cssSelector("""#completed-batches-table th""")).map(_.text).toSeq should be { + List("Batch Time", "Input Size", "Scheduling Delay", "Processing Time", "Total Delay") + } } ssc.stop(false) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala index a5d2bb2fde16c..c39ad05f41520 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala @@ -22,9 +22,9 @@ import org.apache.spark.storage.StorageLevel class WindowOperationsSuite extends TestSuiteBase { - override def maxWaitTimeMillis = 20000 // large window tests can sometimes take longer + override def maxWaitTimeMillis: Int = 20000 // large window tests can sometimes take longer - override def batchDuration = Seconds(1) // making sure its visible in this class + override def batchDuration: Duration = Seconds(1) // making sure its visible in this class val largerSlideInput = Seq( Seq(("a", 1)), diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index 7a6a2f3e577dd..c3602a5b73732 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -28,10 +28,13 @@ import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBloc import org.apache.spark.streaming.util.{WriteAheadLogFileSegment, WriteAheadLogWriter} import org.apache.spark.util.Utils -class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { +class WriteAheadLogBackedBlockRDDSuite + extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { + val conf = new SparkConf() .setMaster("local[2]") .setAppName(this.getClass.getSimpleName) + val hadoopConf = new Configuration() var sparkContext: SparkContext = null @@ -86,7 +89,8 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w * @param numPartitionsInWAL Number of partitions to write to the Write Ahead Log * @param testStoreInBM Test whether blocks read from log are stored back into block manager */ - private def testRDD(numPartitionsInBM: Int, numPartitionsInWAL: Int, testStoreInBM: Boolean = false) { + private def testRDD( + numPartitionsInBM: Int, numPartitionsInWAL: Int, testStoreInBM: Boolean = false) { val numBlocks = 
numPartitionsInBM + numPartitionsInWAL val data = Seq.fill(numBlocks, 10)(scala.util.Random.nextString(50)) @@ -110,7 +114,7 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w "Unexpected blocks in BlockManager" ) - // Make sure that the right `numPartitionsInWAL` blocks are in write ahead logs, and other are not + // Make sure that the right `numPartitionsInWAL` blocks are in WALs, and other are not require( segments.takeRight(numPartitionsInWAL).forall(s => new File(s.path.stripPrefix("file://")).exists()), @@ -152,6 +156,6 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w } private def generateFakeSegments(count: Int): Seq[WriteAheadLogFileSegment] = { - Array.fill(count)(new WriteAheadLogFileSegment("random", 0l, 0)) + Array.fill(count)(new WriteAheadLogFileSegment("random", 0L, 0)) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala index 4150b60635ed6..7865b06c2e3c2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala @@ -90,7 +90,7 @@ class JobGeneratorSuite extends TestSuiteBase { val receiverTracker = ssc.scheduler.receiverTracker // Get the blocks belonging to a batch - def getBlocksOfBatch(batchTime: Long) = { + def getBlocksOfBatch(batchTime: Long): Seq[ReceivedBlockInfo] = { receiverTracker.getBlocksOfBatchAndStream(Time(batchTime), inputStream.id) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala new file mode 100644 index 0000000000000..94b1985116feb --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.ui + +import org.scalatest.Matchers + +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.{Duration, Time, Milliseconds, TestSuiteBase} + +class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { + + val input = (1 to 4).map(Seq(_)).toSeq + val operation = (d: DStream[Int]) => d.map(x => x) + + override def batchDuration: Duration = Milliseconds(100) + + test("onBatchSubmitted, onBatchStarted, onBatchCompleted, " + + "onReceiverStarted, onReceiverError, onReceiverStopped") { + val ssc = setupStreams(input, operation) + val listener = new StreamingJobProgressListener(ssc) + + val receivedBlockInfo = Map( + 0 -> Array(ReceivedBlockInfo(0, 100, null), ReceivedBlockInfo(0, 200, null)), + 1 -> Array(ReceivedBlockInfo(1, 300, null)) + ) + + // onBatchSubmitted + val batchInfoSubmitted = BatchInfo(Time(1000), receivedBlockInfo, 1000, None, None) + listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) + listener.waitingBatches should be (List(batchInfoSubmitted)) + listener.runningBatches should be (Nil) + listener.retainedCompletedBatches should be (Nil) + listener.lastCompletedBatch should be (None) + listener.numUnprocessedBatches should be (1) + listener.numTotalCompletedBatches should be (0) + listener.numTotalProcessedRecords should be (0) + listener.numTotalReceivedRecords should be (0) + + // onBatchStarted + val batchInfoStarted = BatchInfo(Time(1000), receivedBlockInfo, 1000, Some(2000), None) + listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) + listener.waitingBatches should be (Nil) + listener.runningBatches should be (List(batchInfoStarted)) + listener.retainedCompletedBatches should be (Nil) + listener.lastCompletedBatch should be (None) + listener.numUnprocessedBatches should be (1) + listener.numTotalCompletedBatches should be (0) + listener.numTotalProcessedRecords should be (0) + listener.numTotalReceivedRecords should be (600) + + // onBatchCompleted + val batchInfoCompleted = BatchInfo(Time(1000), receivedBlockInfo, 1000, Some(2000), None) + listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) + listener.waitingBatches should be (Nil) + listener.runningBatches should be (Nil) + listener.retainedCompletedBatches should be (List(batchInfoCompleted)) + listener.lastCompletedBatch should be (Some(batchInfoCompleted)) + listener.numUnprocessedBatches should be (0) + listener.numTotalCompletedBatches should be (1) + listener.numTotalProcessedRecords should be (600) + listener.numTotalReceivedRecords should be (600) + + // onReceiverStarted + val receiverInfoStarted = ReceiverInfo(0, "test", null, true, "localhost") + listener.onReceiverStarted(StreamingListenerReceiverStarted(receiverInfoStarted)) + listener.receiverInfo(0) should be (Some(receiverInfoStarted)) + listener.receiverInfo(1) should be (None) + + // onReceiverError + val receiverInfoError = ReceiverInfo(1, "test", null, true, "localhost") + listener.onReceiverError(StreamingListenerReceiverError(receiverInfoError)) + listener.receiverInfo(0) should be (Some(receiverInfoStarted)) + listener.receiverInfo(1) should be (Some(receiverInfoError)) + listener.receiverInfo(2) should be (None) + + // onReceiverStopped + val receiverInfoStopped = ReceiverInfo(2, "test", null, true, "localhost") + listener.onReceiverStopped(StreamingListenerReceiverStopped(receiverInfoStopped)) + listener.receiverInfo(0) should be 
(Some(receiverInfoStarted)) + listener.receiverInfo(1) should be (Some(receiverInfoError)) + listener.receiverInfo(2) should be (Some(receiverInfoStopped)) + listener.receiverInfo(3) should be (None) + } + + test("Remove the old completed batches when exceeding the limit") { + val ssc = setupStreams(input, operation) + val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 100) + val listener = new StreamingJobProgressListener(ssc) + + val receivedBlockInfo = Map( + 0 -> Array(ReceivedBlockInfo(0, 100, null), ReceivedBlockInfo(0, 200, null)), + 1 -> Array(ReceivedBlockInfo(1, 300, null)) + ) + val batchInfoCompleted = BatchInfo(Time(1000), receivedBlockInfo, 1000, Some(2000), None) + + for(_ <- 0 until (limit + 10)) { + listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) + } + + listener.retainedCompletedBatches.size should be (limit) + listener.numTotalCompletedBatches should be(limit + 10) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 8335659667f22..a3919c43b95b4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -291,7 +291,7 @@ object WriteAheadLogSuite { manager } - /** Read data from a segments of a log file directly and return the list of byte buffers.*/ + /** Read data from a segments of a log file directly and return the list of byte buffers. */ def readDataManually(segments: Seq[WriteAheadLogFileSegment]): Seq[String] = { segments.map { segment => val reader = HdfsUtils.getInputStream(segment.path, hadoopConf) diff --git a/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala b/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala index d0bf328f2b74d..d66750463033a 100644 --- a/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala @@ -25,7 +25,8 @@ package org.apache.spark.streamingtest */ class ImplicitSuite { - // We only want to test if `implict` works well with the compiler, so we don't need a real DStream. + // We only want to test if `implicit` works well with the compiler, + // so we don't need a real DStream. 
def mockDStream[T]: org.apache.spark.streaming.dstream.DStream[T] = null def testToPairDStreamFunctions(): Unit = { diff --git a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala index 8d0f09933c8d3..583823c90c5c6 100644 --- a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala +++ b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala @@ -17,7 +17,7 @@ package org.apache.spark.tools -import java.lang.reflect.Method +import java.lang.reflect.{Type, Method} import scala.collection.mutable.ArrayBuffer import scala.language.existentials @@ -302,7 +302,7 @@ object JavaAPICompletenessChecker { private def isExcludedByInterface(method: Method): Boolean = { val excludedInterfaces = Set("org.apache.spark.Logging", "org.apache.hadoop.mapreduce.HadoopMapReduceUtil") - def toComparisionKey(method: Method) = + def toComparisionKey(method: Method): (Class[_], String, Type) = (method.getReturnType, method.getName, method.getGenericReturnType) val interfaces = method.getDeclaringClass.getInterfaces.filter { i => excludedInterfaces.contains(i.getName) diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index 15ee95070a3d3..f2d135397ce2f 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.Utils * Writes simulated shuffle output from several threads and records the observed throughput. */ object StoragePerfTester { - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { /** Total amount of data to generate. Distributed evenly amongst maps and reduce splits. 
*/ val dataSizeMb = Utils.memoryStringToMb(sys.env.getOrElse("OUTPUT_DATA", "1g")) @@ -58,8 +58,8 @@ object StoragePerfTester { val sc = new SparkContext("local[4]", "Write Tester", conf) val hashShuffleManager = sc.env.shuffleManager.asInstanceOf[HashShuffleManager] - def writeOutputBytes(mapId: Int, total: AtomicLong) = { - val shuffle = hashShuffleManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits, + def writeOutputBytes(mapId: Int, total: AtomicLong): Unit = { + val shuffle = hashShuffleManager.shuffleBlockResolver.forMapTask(1, mapId, numOutputSplits, new KryoSerializer(sc.conf), new ShuffleWriteMetrics()) val writers = shuffle.writers for (i <- 1 to recordsPerMap) { @@ -78,7 +78,7 @@ object StoragePerfTester { val totalBytes = new AtomicLong() for (task <- 1 to numMaps) { executor.submit(new Runnable() { - override def run() = { + override def run(): Unit = { try { writeOutputBytes(task, totalBytes) latch.countDown() diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 3d18690cd9cbf..c357b7ae9d4da 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -24,22 +24,20 @@ import java.lang.reflect.InvocationTargetException import java.net.{Socket, URL} import java.util.concurrent.atomic.AtomicReference -import akka.actor._ -import akka.remote._ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv} import org.apache.spark.SparkException import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil} import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.scheduler.cluster.YarnSchedulerBackend import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{AkkaUtils, ChildFirstURLClassLoader, MutableURLClassLoader, - SignalLogger, Utils} +import org.apache.spark.util._ /** * Common application master functionality for Spark on Yarn. @@ -72,8 +70,8 @@ private[spark] class ApplicationMaster( @volatile private var allocator: YarnAllocator = _ // Fields used in client mode. - private var actorSystem: ActorSystem = null - private var actor: ActorRef = _ + private var rpcEnv: RpcEnv = null + private var amEndpoint: RpcEndpointRef = _ // Fields used in cluster mode. private val sparkContextRef = new AtomicReference[SparkContext](null) @@ -162,7 +160,7 @@ private[spark] class ApplicationMaster( * status to SUCCEEDED in cluster mode to handle if the user calls System.exit * from the application code. */ - final def getDefaultFinalStatus() = { + final def getDefaultFinalStatus(): FinalApplicationStatus = { if (isClusterMode) { FinalApplicationStatus.SUCCEEDED } else { @@ -175,31 +173,35 @@ private[spark] class ApplicationMaster( * This means the ResourceManager will not retry the application attempt on your behalf if * a failure occurred. 
*/ - final def unregister(status: FinalApplicationStatus, diagnostics: String = null) = synchronized { - if (!unregistered) { - logInfo(s"Unregistering ApplicationMaster with $status" + - Option(diagnostics).map(msg => s" (diag message: $msg)").getOrElse("")) - unregistered = true - client.unregister(status, Option(diagnostics).getOrElse("")) + final def unregister(status: FinalApplicationStatus, diagnostics: String = null): Unit = { + synchronized { + if (!unregistered) { + logInfo(s"Unregistering ApplicationMaster with $status" + + Option(diagnostics).map(msg => s" (diag message: $msg)").getOrElse("")) + unregistered = true + client.unregister(status, Option(diagnostics).getOrElse("")) + } } } - final def finish(status: FinalApplicationStatus, code: Int, msg: String = null) = synchronized { - if (!finished) { - val inShutdown = Utils.inShutdown() - logInfo(s"Final app status: ${status}, exitCode: ${code}" + - Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) - exitCode = code - finalStatus = status - finalMsg = msg - finished = true - if (!inShutdown && Thread.currentThread() != reporterThread && reporterThread != null) { - logDebug("shutting down reporter thread") - reporterThread.interrupt() - } - if (!inShutdown && Thread.currentThread() != userClassThread && userClassThread != null) { - logDebug("shutting down user thread") - userClassThread.interrupt() + final def finish(status: FinalApplicationStatus, code: Int, msg: String = null): Unit = { + synchronized { + if (!finished) { + val inShutdown = Utils.inShutdown() + logInfo(s"Final app status: $status, exitCode: $code" + + Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) + exitCode = code + finalStatus = status + finalMsg = msg + finished = true + if (!inShutdown && Thread.currentThread() != reporterThread && reporterThread != null) { + logDebug("shutting down reporter thread") + reporterThread.interrupt() + } + if (!inShutdown && Thread.currentThread() != userClassThread && userClassThread != null) { + logDebug("shutting down user thread") + userClassThread.interrupt() + } } } } @@ -221,6 +223,7 @@ private[spark] class ApplicationMaster( val appId = client.getAttemptId().getApplicationId().toString() val historyAddress = sparkConf.getOption("spark.yarn.historyServer.address") + .map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) } .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}" } .getOrElse("") @@ -236,22 +239,21 @@ private[spark] class ApplicationMaster( } /** - * Create an actor that communicates with the driver. + * Create an [[RpcEndpoint]] that communicates with the driver. * * In cluster mode, the AM and the driver belong to same process - * so the AM actor need not monitor lifecycle of the driver. + * so the AMEndpoint need not monitor lifecycle of the driver. 
*/ - private def runAMActor( + private def runAMEndpoint( host: String, port: String, isClusterMode: Boolean): Unit = { - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(actorSystem), + val driverEndpont = rpcEnv.setupEndpointRef( SparkEnv.driverActorSystemName, - host, - port, - YarnSchedulerBackend.ACTOR_NAME) - actor = actorSystem.actorOf(Props(new AMActor(driverUrl, isClusterMode)), name = "YarnAM") + RpcAddress(host, port.toInt), + YarnSchedulerBackend.ENDPOINT_NAME) + amEndpoint = + rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverEndpont, isClusterMode)) } private def runDriver(securityMgr: SecurityManager): Unit = { @@ -268,8 +270,8 @@ private[spark] class ApplicationMaster( ApplicationMaster.EXIT_SC_NOT_INITED, "Timed out waiting for SparkContext.") } else { - actorSystem = sc.env.actorSystem - runAMActor( + rpcEnv = sc.env.rpcEnv + runAMEndpoint( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port"), isClusterMode = true) @@ -279,8 +281,7 @@ private[spark] class ApplicationMaster( } private def runExecutorLauncher(securityMgr: SecurityManager): Unit = { - actorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, - conf = sparkConf, securityManager = securityMgr)._1 + rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, 0, sparkConf, securityMgr) waitForSparkDriver() addAmIpFilter() registerAM(sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) @@ -295,7 +296,7 @@ private[spark] class ApplicationMaster( // we want to be reasonably responsive without causing too many requests to RM. val schedulerInterval = - sparkConf.getLong("spark.yarn.scheduler.heartbeat.interval-ms", 5000) + sparkConf.getTimeAsMs("spark.yarn.scheduler.heartbeat.interval-ms", "5s") // must be <= expiryInterval / 2. val interval = math.max(0, math.min(expiryInterval / 2, schedulerInterval)) @@ -378,7 +379,8 @@ private[spark] class ApplicationMaster( logWarning( "spark.yarn.applicationMaster.waitTries is deprecated, use spark.yarn.am.waitTime") } - val totalWaitTime = sparkConf.getLong("spark.yarn.am.waitTime", waitTries.getOrElse(100000L)) + val totalWaitTime = sparkConf.getTimeAsMs("spark.yarn.am.waitTime", + s"${waitTries.getOrElse(100000L)}ms") val deadline = System.currentTimeMillis() + totalWaitTime while (sparkContextRef.get() == null && System.currentTimeMillis < deadline && !finished) { @@ -403,8 +405,8 @@ private[spark] class ApplicationMaster( // Spark driver should already be up since it launched us, but we don't want to // wait forever, so wait 100 seconds max to match the cluster mode setting. - val totalWaitTime = sparkConf.getLong("spark.yarn.am.waitTime", 100000L) - val deadline = System.currentTimeMillis + totalWaitTime + val totalWaitTimeMs = sparkConf.getTimeAsMs("spark.yarn.am.waitTime", "100s") + val deadline = System.currentTimeMillis + totalWaitTimeMs while (!driverUp && !finished && System.currentTimeMillis < deadline) { try { @@ -427,7 +429,7 @@ private[spark] class ApplicationMaster( sparkConf.set("spark.driver.host", driverHost) sparkConf.set("spark.driver.port", driverPort.toString) - runAMActor(driverHost, driverPort.toString, isClusterMode = false) + runAMEndpoint(driverHost, driverPort.toString, isClusterMode = false) } /** Add the Yarn IP filter that is required for properly securing the UI. */ @@ -439,7 +441,7 @@ private[spark] class ApplicationMaster( System.setProperty("spark.ui.filters", amFilter) params.foreach { case (k, v) => System.setProperty(s"spark.$amFilter.param.$k", v) } } else { - actor ! 
AddWebUIFilter(amFilter, params.toMap, proxyBase) + amEndpoint.send(AddWebUIFilter(amFilter, params.toMap, proxyBase)) } } @@ -469,6 +471,9 @@ private[spark] class ApplicationMaster( System.setProperty("spark.submit.pyFiles", PythonRunner.formatPaths(args.pyFiles).mkString(",")) } + if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { + // TODO(davies): add R dependencies here + } val mainMethod = userClassLoader.loadClass(args.userClass) .getMethod("main", classOf[Array[String]]) @@ -501,44 +506,29 @@ private[spark] class ApplicationMaster( } /** - * An actor that communicates with the driver's scheduler backend. + * An [[RpcEndpoint]] that communicates with the driver's scheduler backend. */ - private class AMActor(driverUrl: String, isClusterMode: Boolean) extends Actor { - var driver: ActorSelection = _ - - override def preStart() = { - logInfo("Listen to driver: " + driverUrl) - driver = context.actorSelection(driverUrl) - // Send a hello message to establish the connection, after which - // we can monitor Lifecycle Events. - driver ! "Hello" - driver ! RegisterClusterManager - // In cluster mode, the AM can directly monitor the driver status instead - // of trying to deduce it from the lifecycle of the driver's actor - if (!isClusterMode) { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - } - } + private class AMEndpoint( + override val rpcEnv: RpcEnv, driver: RpcEndpointRef, isClusterMode: Boolean) + extends RpcEndpoint with Logging { - override def receive = { - case x: DisassociatedEvent => - logInfo(s"Driver terminated or disconnected! Shutting down. $x") - // In cluster mode, do not rely on the disassociated event to exit - // This avoids potentially reporting incorrect exit codes if the driver fails - if (!isClusterMode) { - finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) - } + override def onStart(): Unit = { + driver.send(RegisterClusterManager(self)) + } + override def receive: PartialFunction[Any, Unit] = { case x: AddWebUIFilter => logInfo(s"Add WebUI Filter. $x") - driver ! x + driver.send(x) + } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestExecutors(requestedTotal) => Option(allocator) match { case Some(a) => a.requestTotalExecutors(requestedTotal) case None => logWarning("Container allocator is not ready to request executors yet.") } - sender ! true + context.reply(true) case KillExecutors(executorIds) => logInfo(s"Driver requested to kill executor(s) ${executorIds.mkString(", ")}.") @@ -546,7 +536,16 @@ private[spark] class ApplicationMaster( case Some(a) => executorIds.foreach(a.killExecutor) case None => logWarning("Container allocator is not ready to kill executors yet.") } - sender ! true + context.reply(true) + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + logInfo(s"Driver terminated or disconnected! Shutting down. 
$remoteAddress") + // In cluster mode, do not rely on the disassociated event to exit + // This avoids potentially reporting incorrect exit codes if the driver fails + if (!isClusterMode) { + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + } } } @@ -567,7 +566,7 @@ object ApplicationMaster extends Logging { private var master: ApplicationMaster = _ - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { SignalLogger.register(log) val amArgs = new ApplicationMasterArguments(args) SparkHadoopUtil.get.runAsSparkUser { () => @@ -576,11 +575,11 @@ object ApplicationMaster extends Logging { } } - private[spark] def sparkContextInitialized(sc: SparkContext) = { + private[spark] def sparkContextInitialized(sc: SparkContext): Unit = { master.sparkContextInitialized(sc) } - private[spark] def sparkContextStopped(sc: SparkContext) = { + private[spark] def sparkContextStopped(sc: SparkContext): Boolean = { master.sparkContextStopped(sc) } @@ -592,7 +591,7 @@ object ApplicationMaster extends Logging { */ object ExecutorLauncher { - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { ApplicationMaster.main(args) } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index e1a992af3aae7..ae6dc1094d724 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -25,6 +25,7 @@ class ApplicationMasterArguments(val args: Array[String]) { var userJar: String = null var userClass: String = null var primaryPyFile: String = null + var primaryRFile: String = null var pyFiles: String = null var userArgs: Seq[String] = Seq[String]() var executorMemory = 1024 @@ -54,6 +55,10 @@ class ApplicationMasterArguments(val args: Array[String]) { primaryPyFile = value args = tail + case ("--primary-r-file") :: value :: tail => + primaryRFile = value + args = tail + case ("--py-files") :: value :: tail => pyFiles = value args = tail @@ -79,6 +84,11 @@ class ApplicationMasterArguments(val args: Array[String]) { } } + if (primaryPyFile != null && primaryRFile != null) { + System.err.println("Cannot have primary-py-file and primary-r-file at the same time") + System.exit(-1) + } + userArgs = userArgsBuffer.readOnly } @@ -92,6 +102,7 @@ class ApplicationMasterArguments(val args: Array[String]) { | --jar JAR_PATH Path to your application's JAR file | --class CLASS_NAME Name of your application's main class | --primary-py-file A main Python file + | --primary-r-file A main R file | --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to | place on the PYTHONPATH for Python apps. | --args ARGS Arguments to be passed to your application's main class. 
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 61f8fc3f5a014..1091ff54b0463 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -22,17 +22,21 @@ import java.nio.ByteBuffer import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap, ListBuffer, Map} +import scala.reflect.runtime.universe import scala.util.{Try, Success, Failure} import com.google.common.base.Objects import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission +import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.Master import org.apache.hadoop.mapreduce.MRJobConfig import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.hadoop.security.token.Token import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment @@ -40,6 +44,7 @@ import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.{YarnClient, YarnClientApplication} import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException import org.apache.hadoop.yarn.util.Records import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException} @@ -66,6 +71,8 @@ private[spark] class Client( private val executorMemoryOverhead = args.executorMemoryOverhead // MB private val distCacheMgr = new ClientDistributedCacheManager() private val isClusterMode = args.isClusterMode + private val fireAndForget = isClusterMode && + !sparkConf.getBoolean("spark.yarn.submit.waitAppCompletion", true) def stop(): Unit = yarnClient.stop() @@ -217,6 +224,7 @@ private[spark] class Client( val dst = new Path(fs.getHomeDirectory(), appStagingDir) val nns = getNameNodesToAccess(sparkConf) + dst obtainTokensForNamenodes(nns, hadoopConf, credentials) + obtainTokenForHiveMetastore(hadoopConf, credentials) val replication = sparkConf.getInt("spark.yarn.submit.file.replication", fs.getDefaultReplication(dst)).toShort @@ -488,6 +496,12 @@ private[spark] class Client( } else { Nil } + val primaryRFile = + if (args.primaryRFile != null) { + Seq("--primary-r-file", args.primaryRFile) + } else { + Nil + } val amClass = if (isClusterMode) { Class.forName("org.apache.spark.deploy.yarn.ApplicationMaster").getName @@ -497,12 +511,15 @@ private[spark] class Client( if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) { args.userArgs = ArrayBuffer(args.primaryPyFile, args.pyFiles) ++ args.userArgs } + if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { + args.userArgs = ArrayBuffer(args.primaryRFile) ++ args.userArgs + } val userArgs = args.userArgs.flatMap { arg => Seq("--arg", YarnSparkHadoopUtil.escapeForShell(arg)) } val amArgs = - Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ pyFiles ++ userArgs ++ - Seq( + Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ pyFiles ++ primaryRFile ++ + userArgs ++ Seq( "--executor-memory", args.executorMemory.toString + "m", "--executor-cores", args.executorCores.toString, "--num-executors ", args.numExecutors.toString) @@ -559,36 
+576,25 @@ private[spark] class Client( var lastState: YarnApplicationState = null while (true) { Thread.sleep(interval) - val report = getApplicationReport(appId) + val report: ApplicationReport = + try { + getApplicationReport(appId) + } catch { + case e: ApplicationNotFoundException => + logError(s"Application $appId not found.") + return (YarnApplicationState.KILLED, FinalApplicationStatus.KILLED) + } val state = report.getYarnApplicationState if (logApplicationReport) { logInfo(s"Application report for $appId (state: $state)") - val details = Seq[(String, String)]( - ("client token", getClientToken(report)), - ("diagnostics", report.getDiagnostics), - ("ApplicationMaster host", report.getHost), - ("ApplicationMaster RPC port", report.getRpcPort.toString), - ("queue", report.getQueue), - ("start time", report.getStartTime.toString), - ("final status", report.getFinalApplicationStatus.toString), - ("tracking URL", report.getTrackingUrl), - ("user", report.getUser) - ) - - // Use more loggable format if value is null or empty - val formattedDetails = details - .map { case (k, v) => - val newValue = Option(v).filter(_.nonEmpty).getOrElse("N/A") - s"\n\t $k: $newValue" } - .mkString("") // If DEBUG is enabled, log report details every iteration // Otherwise, log them every time the application changes state if (log.isDebugEnabled) { - logDebug(formattedDetails) + logDebug(formatReportDetails(report)) } else if (lastState != state) { - logInfo(formattedDetails) + logInfo(formatReportDetails(report)) } } @@ -609,24 +615,57 @@ private[spark] class Client( throw new SparkException("While loop is depleted! This should never happen...") } + private def formatReportDetails(report: ApplicationReport): String = { + val details = Seq[(String, String)]( + ("client token", getClientToken(report)), + ("diagnostics", report.getDiagnostics), + ("ApplicationMaster host", report.getHost), + ("ApplicationMaster RPC port", report.getRpcPort.toString), + ("queue", report.getQueue), + ("start time", report.getStartTime.toString), + ("final status", report.getFinalApplicationStatus.toString), + ("tracking URL", report.getTrackingUrl), + ("user", report.getUser) + ) + + // Use more loggable format if value is null or empty + details.map { case (k, v) => + val newValue = Option(v).filter(_.nonEmpty).getOrElse("N/A") + s"\n\t $k: $newValue" + }.mkString("") + } + /** - * Submit an application to the ResourceManager and monitor its state. - * This continues until the application has exited for any reason. + * Submit an application to the ResourceManager. + * If set spark.yarn.submit.waitAppCompletion to true, it will stay alive + * reporting the application's status until the application has exited for any reason. + * Otherwise, the client process will exit after submission. * If the application finishes with a failed, killed, or undefined status, * throw an appropriate SparkException. 
*/ def run(): Unit = { - val (yarnApplicationState, finalApplicationStatus) = monitorApplication(submitApplication()) - if (yarnApplicationState == YarnApplicationState.FAILED || - finalApplicationStatus == FinalApplicationStatus.FAILED) { - throw new SparkException("Application finished with failed status") - } - if (yarnApplicationState == YarnApplicationState.KILLED || - finalApplicationStatus == FinalApplicationStatus.KILLED) { - throw new SparkException("Application is killed") - } - if (finalApplicationStatus == FinalApplicationStatus.UNDEFINED) { - throw new SparkException("The final status of application is undefined") + val appId = submitApplication() + if (fireAndForget) { + val report = getApplicationReport(appId) + val state = report.getYarnApplicationState + logInfo(s"Application report for $appId (state: $state)") + logInfo(formatReportDetails(report)) + if (state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) { + throw new SparkException(s"Application $appId finished with status: $state") + } + } else { + val (yarnApplicationState, finalApplicationStatus) = monitorApplication(appId) + if (yarnApplicationState == YarnApplicationState.FAILED || + finalApplicationStatus == FinalApplicationStatus.FAILED) { + throw new SparkException(s"Application $appId finished with failed status") + } + if (yarnApplicationState == YarnApplicationState.KILLED || + finalApplicationStatus == FinalApplicationStatus.KILLED) { + throw new SparkException(s"Application $appId is killed") + } + if (finalApplicationStatus == FinalApplicationStatus.UNDEFINED) { + throw new SparkException(s"The final status of application $appId is undefined") + } } } } @@ -902,6 +941,64 @@ object Client extends Logging { } } + /** + * Obtains token for the Hive metastore and adds them to the credentials. 
+ */ + private def obtainTokenForHiveMetastore(conf: Configuration, credentials: Credentials) { + if (UserGroupInformation.isSecurityEnabled) { + val mirror = universe.runtimeMirror(getClass.getClassLoader) + + try { + val hiveClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.ql.metadata.Hive") + val hive = hiveClass.getMethod("get").invoke(null) + + val hiveConf = hiveClass.getMethod("getConf").invoke(hive) + val hiveConfClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") + + val hiveConfGet = (param:String) => Option(hiveConfClass + .getMethod("get", classOf[java.lang.String]) + .invoke(hiveConf, param)) + + val metastore_uri = hiveConfGet("hive.metastore.uris") + + // Check for local metastore + if (metastore_uri != None && metastore_uri.get.toString.size > 0) { + val metastore_kerberos_principal_conf_var = mirror.classLoader + .loadClass("org.apache.hadoop.hive.conf.HiveConf$ConfVars") + .getField("METASTORE_KERBEROS_PRINCIPAL").get("varname").toString + + val principal = hiveConfGet(metastore_kerberos_principal_conf_var) + + val username = Option(UserGroupInformation.getCurrentUser().getUserName) + if (principal != None && username != None) { + val tokenStr = hiveClass.getMethod("getDelegationToken", + classOf[java.lang.String], classOf[java.lang.String]) + .invoke(hive, username.get, principal.get).asInstanceOf[java.lang.String] + + val hive2Token = new Token[DelegationTokenIdentifier]() + hive2Token.decodeFromUrlString(tokenStr) + credentials.addToken(new Text("hive.server2.delegation.token"),hive2Token) + logDebug("Added hive.Server2.delegation.token to conf.") + hiveClass.getMethod("closeCurrent").invoke(null) + } else { + logError("Username or principal == NULL") + logError(s"""username=${username.getOrElse("(NULL)")}""") + logError(s"""principal=${principal.getOrElse("(NULL)")}""") + throw new IllegalArgumentException("username and/or principal is equal to null!") + } + } else { + logDebug("HiveMetaStore configured in localmode") + } + } catch { + case e:java.lang.NoSuchMethodException => { logInfo("Hive Method not found " + e); return } + case e:java.lang.ClassNotFoundException => { logInfo("Hive Class not found " + e); return } + case e:Exception => { logError("Unexpected Exception " + e) + throw new RuntimeException("Unexpected exception", e) + } + } + } + } + /** * Return whether the two file systems are the same. */ diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 3bc7eb1abf341..da6798cb1b279 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -32,6 +32,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var userClass: String = null var pyFiles: String = null var primaryPyFile: String = null + var primaryRFile: String = null var userArgs: ArrayBuffer[String] = new ArrayBuffer[String]() var executorMemory = 1024 // MB var executorCores = 1 @@ -150,6 +151,10 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) primaryPyFile = value args = tail + case ("--primary-r-file") :: value :: tail => + primaryRFile = value + args = tail + case ("--args" | "--arg") :: value :: tail => if (args(0) == "--args") { println("--args is deprecated. 
Use --arg instead.") @@ -228,6 +233,11 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) throw new IllegalArgumentException(getUsageMessage(args)) } } + + if (primaryPyFile != null && primaryRFile != null) { + throw new IllegalArgumentException("Cannot have primary-py-file and primary-r-file" + + " at the same time") + } } private def getUsageMessage(unknownParam: List[String] = null): String = { @@ -240,6 +250,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) | mode) | --class CLASS_NAME Name of your application's main class (required) | --primary-py-file A main Python file + | --primary-r-file A main R file | --arg ARG Argument to be passed to your application's main class. | Multiple invocations are possible, each will be passed in order. | --num-executors NUM Number of executors to start (Default: 2) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index c1d3f7320f53c..b06069c07f451 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -59,15 +59,15 @@ class ExecutorRunnable( val yarnConf: YarnConfiguration = new YarnConfiguration(conf) lazy val env = prepareEnvironment(container) - def run = { + override def run(): Unit = { logInfo("Starting Executor Container") nmClient = NMClient.createNMClient() nmClient.init(yarnConf) nmClient.start() - startContainer + startContainer() } - def startContainer = { + def startContainer(): java.util.Map[String, ByteBuffer] = { logInfo("Setting up ContainerLaunchContext") val ctx = Records.newRecord(classOf[ContainerLaunchContext]) @@ -290,10 +290,19 @@ class ExecutorRunnable( YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs) } + // lookup appropriate http scheme for container log urls + val yarnHttpPolicy = yarnConf.get( + YarnConfiguration.YARN_HTTP_POLICY_KEY, + YarnConfiguration.YARN_HTTP_POLICY_DEFAULT + ) + val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" + // Add log urls sys.env.get("SPARK_USER").foreach { user => - val baseUrl = "http://%s/node/containerlogs/%s/%s" - .format(container.getNodeHttpAddress, ConverterUtils.toString(container.getId), user) + val containerId = ConverterUtils.toString(container.getId) + val address = container.getNodeHttpAddress + val baseUrl = s"$httpScheme$address/node/containerlogs/$containerId/$user" + env("SPARK_LOG_URL_STDERR") = s"$baseUrl/stderr?start=0" env("SPARK_LOG_URL_STDOUT") = s"$baseUrl/stdout?start=0" } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index c98763e15b58f..b8f42dadcb464 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -112,7 +112,7 @@ private[yarn] class YarnAllocator( SparkEnv.driverActorSystemName, sparkConf.get("spark.driver.host"), sparkConf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) // For testing private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true) diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala 
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index c98763e15b58f..b8f42dadcb464 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -112,7 +112,7 @@ private[yarn] class YarnAllocator(
     SparkEnv.driverActorSystemName,
     sparkConf.get("spark.driver.host"),
     sparkConf.get("spark.driver.port"),
-    CoarseGrainedSchedulerBackend.ACTOR_NAME)
+    CoarseGrainedSchedulerBackend.ENDPOINT_NAME)
 
   // For testing
   private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true)
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index 8abdc26b43806..99c05329b4d73 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -34,7 +34,7 @@ private[spark] class YarnClientSchedulerBackend(
 
   private var client: Client = null
   private var appId: ApplicationId = null
-  @volatile private var stopping: Boolean = false
+  private var monitorThread: Thread = null
 
   /**
    * Create a Yarn client to submit an application to the ResourceManager.
@@ -57,7 +57,8 @@ private[spark] class YarnClientSchedulerBackend(
     client = new Client(args, conf)
     appId = client.submitApplication()
     waitForApplication()
-    asyncMonitorApplication()
+    monitorThread = asyncMonitorApplication()
+    monitorThread.start()
   }
 
   /**
@@ -123,34 +124,22 @@ private[spark] class YarnClientSchedulerBackend(
    * If the application has exited for any reason, stop the SparkContext.
    * This assumes both `client` and `appId` have already been set.
    */
-  private def asyncMonitorApplication(): Unit = {
+  private def asyncMonitorApplication(): Thread = {
     assert(client != null && appId != null, "Application has not been submitted yet!")
     val t = new Thread {
       override def run() {
-        while (!stopping) {
-          var state: YarnApplicationState = null
-          try {
-            val report = client.getApplicationReport(appId)
-            state = report.getYarnApplicationState()
-          } catch {
-            case e: ApplicationNotFoundException =>
-              state = YarnApplicationState.KILLED
-          }
-          if (state == YarnApplicationState.FINISHED ||
-            state == YarnApplicationState.KILLED ||
-            state == YarnApplicationState.FAILED) {
-            logError(s"Yarn application has already exited with state $state!")
-            sc.stop()
-            stopping = true
-          }
-          Thread.sleep(1000L)
+        try {
+          val (state, _) = client.monitorApplication(appId, logApplicationReport = false)
+          logError(s"Yarn application has already exited with state $state!")
+          sc.stop()
+        } catch {
+          case e: InterruptedException => logInfo("Interrupting monitor thread")
         }
-        Thread.currentThread().interrupt()
       }
     }
     t.setName("Yarn application state monitor")
     t.setDaemon(true)
-    t.start()
+    t
   }
 
   /**
@@ -158,7 +147,7 @@ private[spark] class YarnClientSchedulerBackend(
    */
   override def stop() {
     assert(client != null, "Attempted to stop this scheduler before starting it!")
-    stopping = true
+    monitorThread.interrupt()
     super.stop()
     client.stop()
     logInfo("Stopped")
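The backend now hands lifecycle handling to a single daemon thread that blocks inside `monitorApplication` and is torn down with `interrupt()`, instead of polling a `stopping` flag once a second. A minimal sketch of that pattern, with the blocking call and the shutdown action stubbed out as parameters (the names are illustrative):

```scala
object ApplicationMonitor {
  // Build (but do not start) a daemon thread that waits for the application to finish
  // and then runs a cleanup action; interrupting the thread cancels the wait.
  def monitorThread(waitForExit: () => String, onExit: String => Unit): Thread = {
    val t = new Thread {
      override def run(): Unit = {
        try {
          val finalState = waitForExit()   // e.g. client.monitorApplication(appId, ...)
          onExit(finalState)               // e.g. sc.stop()
        } catch {
          case _: InterruptedException =>  // stop() interrupted us; exit quietly
        }
      }
    }
    t.setName("Yarn application state monitor")
    t.setDaemon(true)
    t
  }
}

// start():  val monitor = ApplicationMonitor.monitorThread(wait, cleanup); monitor.start()
// stop():   monitor.interrupt()
```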
diff --git a/yarn/src/test/resources/log4j.properties b/yarn/src/test/resources/log4j.properties
index aab41fa49430f..6b8a5dbf6373e 100644
--- a/yarn/src/test/resources/log4j.properties
+++ b/yarn/src/test/resources/log4j.properties
@@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout
 log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
 
 # Ignore messages below warning level from Jetty, because it's a bit verbose
-log4j.logger.org.eclipse.jetty=WARN
+log4j.logger.org.spark-project.jetty=WARN
 log4j.logger.org.apache.hadoop=WARN
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
index 92f04b4b859b3..c1b94ac9c5bdd 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
@@ -232,19 +232,26 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll {
     testCode(conf)
   }
 
-  def newEnv = MutableHashMap[String, String]()
+  def newEnv: MutableHashMap[String, String] = MutableHashMap[String, String]()
 
-  def classpath(env: MutableHashMap[String, String]) = env(Environment.CLASSPATH.name).split(":|;|")
+  def classpath(env: MutableHashMap[String, String]): Array[String] =
+    env(Environment.CLASSPATH.name).split(":|;|")
 
-  def flatten(a: Option[Seq[String]], b: Option[Seq[String]]) = (a ++ b).flatten.toArray
+  def flatten(a: Option[Seq[String]], b: Option[Seq[String]]): Array[String] =
+    (a ++ b).flatten.toArray
 
-  def getFieldValue[A, B](clazz: Class[_], field: String, defaults: => B)(mapTo: A => B): B =
-    Try(clazz.getField(field)).map(_.get(null).asInstanceOf[A]).toOption.map(mapTo).getOrElse(defaults)
+  def getFieldValue[A, B](clazz: Class[_], field: String, defaults: => B)(mapTo: A => B): B = {
+    Try(clazz.getField(field))
+      .map(_.get(null).asInstanceOf[A])
+      .toOption
+      .map(mapTo)
+      .getOrElse(defaults)
+  }
 
   def getFieldValue2[A: ClassTag, A1: ClassTag, B](
       clazz: Class[_],
       field: String,
-      defaults: => B)(mapTo: A => B)(mapTo1: A1 => B) : B = {
+      defaults: => B)(mapTo: A => B)(mapTo1: A1 => B): B = {
     Try(clazz.getField(field)).map(_.get(null)).map {
       case v: A => mapTo(v)
       case v1: A1 => mapTo1(v1)
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
index c09b01bafce37..455f1019d86dd 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
@@ -79,7 +79,7 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach
   }
 
   class MockSplitInfo(host: String) extends SplitInfo(null, host, null, 1, null) {
-    override def equals(other: Any) = false
+    override def equals(other: Any): Boolean = false
   }
 
   def createAllocator(maxExecutors: Int = 5): YarnAllocator = {
@@ -118,7 +118,9 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach
     handler.getNumExecutorsRunning should be (1)
     handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1")
     handler.allocatedHostToContainersMap.get("host1").get should contain (container.getId)
-    rmClient.getMatchingRequests(container.getPriority, "host1", containerResource).size should be (0)
+
+    val size = rmClient.getMatchingRequests(container.getPriority, "host1", containerResource).size
+    size should be (0)
   }
 
   test("some containers allocated") {
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index 0e37276ba724b..a18c94d4ab4a8 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -33,7 +33,7 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
 
 import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, TestUtils}
 import org.apache.spark.scheduler.cluster.ExecutorInfo
-import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorAdded}
+import org.apache.spark.scheduler.{SparkListenerJobStart, SparkListener, SparkListenerExecutorAdded}
 import org.apache.spark.util.Utils
 
 /**
@@ -143,6 +143,7 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit
     }
   }
 
+  // Enable this once SPARK-6700 is fixed
   test("run Python application in yarn-cluster mode") {
     val primaryPyFile = new File(tempDir, "test.py")
     Files.write(TEST_PYFILE, primaryPyFile, UTF_8)
@@ -281,10 +282,10 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit
 }
 
-private class SaveExecutorInfo extends SparkListener {
+private[spark] class SaveExecutorInfo extends SparkListener {
   val addedExecutorInfos = mutable.Map[String, ExecutorInfo]()
 
-  override def onExecutorAdded(executor : SparkListenerExecutorAdded) {
+  override def onExecutorAdded(executor: SparkListenerExecutorAdded) {
     addedExecutorInfos(executor.executorId) = executor.executorInfo
   }
 }
 
@@ -292,7 +293,6 @@ private class SaveExecutorInfo extends SparkListener {
 private object YarnClusterDriver extends Logging with Matchers {
 
   val WAIT_TIMEOUT_MILLIS = 10000
-  var listener: SaveExecutorInfo = null
 
   def main(args: Array[String]): Unit = {
     if (args.length != 1) {
@@ -305,10 +305,9 @@ private object YarnClusterDriver extends Logging with Matchers {
       System.exit(1)
     }
 
-    listener = new SaveExecutorInfo
     val sc = new SparkContext(new SparkConf()
+      .set("spark.extraListeners", classOf[SaveExecutorInfo].getName)
       .setAppName("yarn \"test app\" 'with quotes' and \\back\\slashes and $dollarSigns"))
-    sc.addSparkListener(listener)
     val status = new File(args(0))
     var result = "failure"
     try {
@@ -322,7 +321,12 @@ private object YarnClusterDriver extends Logging with Matchers {
     }
 
     // verify log urls are present
-    listener.addedExecutorInfos.values.foreach { info =>
+    val listeners = sc.listenerBus.findListenersByClass[SaveExecutorInfo]
+    assert(listeners.size === 1)
+    val listener = listeners(0)
+    val executorInfos = listener.addedExecutorInfos.values
+    assert(executorInfos.nonEmpty)
+    executorInfos.foreach { info =>
       assert(info.logUrlMap.nonEmpty)
     }
   }
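YarnClusterDriver now registers `SaveExecutorInfo` through `spark.extraListeners`, so the listener is attached before any executor-added events are posted, and retrieves it from the listener bus afterwards instead of instantiating and adding it by hand. A minimal sketch of that registration pattern outside the test harness; the listener class and application below are illustrative.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorAdded}

// A listener named in spark.extraListeners needs a zero-argument constructor
// (or one taking a SparkConf), since Spark instantiates it by reflection at startup.
class ExecutorLogUrlListener extends SparkListener {
  @volatile var lastLogUrls: Map[String, String] = Map.empty
  override def onExecutorAdded(executor: SparkListenerExecutorAdded): Unit = {
    lastLogUrls = executor.executorInfo.logUrlMap
  }
}

object ExtraListenersExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("extra-listeners-example")
      .set("spark.extraListeners", classOf[ExecutorLogUrlListener].getName)
    val sc = new SparkContext(conf)  // listener is registered before events start flowing
    sc.stop()
  }
}
```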
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
index 4194f36499e66..9395316b71ff4 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
@@ -46,7 +46,7 @@ class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging {
     logWarning("Cannot execute bash, skipping bash tests.")
   }
 
-  def bashTest(name: String)(fn: => Unit) =
+  def bashTest(name: String)(fn: => Unit): Unit =
     if (hasBash) test(name)(fn) else ignore(name)(fn)
 
   bashTest("shell script escaping") {