diff --git a/.gitignore b/.gitignore index d162fa9cca994..d54d21b802be8 100644 --- a/.gitignore +++ b/.gitignore @@ -63,6 +63,8 @@ ec2/lib/ rat-results.txt scalastyle.txt scalastyle-output.xml +R-unit-tests.log +R/unit-tests.out # For Hive metastore_db/ diff --git a/.rat-excludes b/.rat-excludes index 8c61e67a0c7d1..8aca5a7f7a967 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -67,3 +67,5 @@ logs .*scalastyle-output.xml .*dependency-reduced-pom.xml known_translations +DESCRIPTION +NAMESPACE diff --git a/R/.gitignore b/R/.gitignore new file mode 100644 index 0000000000000..9a5889ba28b2a --- /dev/null +++ b/R/.gitignore @@ -0,0 +1,6 @@ +*.o +*.so +*.Rd +lib +pkg/man +pkg/html diff --git a/R/DOCUMENTATION.md b/R/DOCUMENTATION.md new file mode 100644 index 0000000000000..931d01549b265 --- /dev/null +++ b/R/DOCUMENTATION.md @@ -0,0 +1,12 @@ +# SparkR Documentation + +SparkR documentation is generated using in-source comments annotated using using +`roxygen2`. After making changes to the documentation, to generate man pages, +you can run the following from an R console in the SparkR home directory + + library(devtools) + devtools::document(pkg="./pkg", roclets=c("rd")) + +You can verify if your changes are good by running + + R CMD check pkg/ diff --git a/R/README.md b/R/README.md new file mode 100644 index 0000000000000..a6970e39b55f3 --- /dev/null +++ b/R/README.md @@ -0,0 +1,67 @@ +# R on Spark + +SparkR is an R package that provides a light-weight frontend to use Spark from R. + +### SparkR development + +#### Build Spark + +Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-PsparkR` profile to build the R package. For example to use the default Hadoop versions you can run +``` + build/mvn -DskipTests -Psparkr package +``` + +#### Running sparkR + +You can start using SparkR by launching the SparkR shell with + + ./bin/sparkR + +The `sparkR` script automatically creates a SparkContext with Spark by default in +local mode. To specify the Spark master of a cluster for the automatically created +SparkContext, you can run + + ./bin/sparkR --master "local[2]" + +To set other options like driver memory, executor memory etc. you can pass in the [spark-submit](http://spark.apache.org/docs/latest/submitting-applications.html) arguments to `./bin/sparkR` + +#### Using SparkR from RStudio + +If you wish to use SparkR from RStudio or other R frontends you will need to set some environment variables which point SparkR to your Spark installation. For example +``` +# Set this to where Spark is installed +Sys.setenv(SPARK_HOME="/Users/shivaram/spark") +# This line loads SparkR from the installed directory +.libPaths(c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib"), .libPaths())) +library(SparkR) +sc <- sparkR.init(master="local") +``` + +#### Making changes to SparkR + +The [instructions](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) for making contributions to Spark also apply to SparkR. +If you only make R file changes (i.e. no Scala changes) then you can just re-install the R package using `R/install-dev.sh` and test your changes. +Once you have made your changes, please include unit tests for them and run existing unit tests using the `run-tests.sh` script as described below. + +#### Generating documentation + +The SparkR documentation (Rd files and HTML files) are not a part of the source repository. To generate them you can run the script `R/create-docs.sh`. This script uses `devtools` and `knitr` to generate the docs and these packages need to be installed on the machine before using the script. + +### Examples, Unit tests + +SparkR comes with several sample programs in the `examples/src/main/r` directory. +To run one of them, use `./bin/sparkR `. For example: + + ./bin/sparkR examples/src/main/r/pi.R local[2] + +You can also run the unit-tests for SparkR by running (you need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first): + + R -e 'install.packages("testthat", repos="http://cran.us.r-project.org")' + ./R/run-tests.sh + +### Running on YARN +The `./bin/spark-submit` and `./bin/sparkR` can also be used to submit jobs to YARN clusters. You will need to set YARN conf dir before doing so. For example on CDH you can run +``` +export YARN_CONF_DIR=/etc/hadoop/conf +./bin/spark-submit --master yarn examples/src/main/r/pi.R 4 +``` diff --git a/R/WINDOWS.md b/R/WINDOWS.md new file mode 100644 index 0000000000000..3f889c0ca3d1e --- /dev/null +++ b/R/WINDOWS.md @@ -0,0 +1,13 @@ +## Building SparkR on Windows + +To build SparkR on Windows, the following steps are required + +1. Install R (>= 3.1) and [Rtools](http://cran.r-project.org/bin/windows/Rtools/). Make sure to +include Rtools and R in `PATH`. +2. Install +[JDK7](http://www.oracle.com/technetwork/java/javase/downloads/jdk7-downloads-1880260.html) and set +`JAVA_HOME` in the system environment variables. +3. Download and install [Maven](http://maven.apache.org/download.html). Also include the `bin` +directory in Maven in `PATH`. +4. Set `MAVEN_OPTS` as described in [Building Spark](http://spark.apache.org/docs/latest/building-spark.html). +5. Open a command shell (`cmd`) in the Spark directory and run `mvn -DskipTests -Psparkr package` diff --git a/R/create-docs.sh b/R/create-docs.sh new file mode 100755 index 0000000000000..4194172a2e115 --- /dev/null +++ b/R/create-docs.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Script to create API docs for SparkR +# This requires `devtools` and `knitr` to be installed on the machine. + +# After running this script the html docs can be found in +# $SPARK_HOME/R/pkg/html + +# Figure out where the script is +export FWDIR="$(cd "`dirname "$0"`"; pwd)" +pushd $FWDIR + +# Generate Rd file +Rscript -e 'library(devtools); devtools::document(pkg="./pkg", roclets=c("rd"))' + +# Install the package +./install-dev.sh + +# Now create HTML files + +# knit_rd puts html in current working directory +mkdir -p pkg/html +pushd pkg/html + +Rscript -e 'library(SparkR, lib.loc="../../lib"); library(knitr); knit_rd("SparkR")' + +popd + +popd diff --git a/R/install-dev.bat b/R/install-dev.bat new file mode 100644 index 0000000000000..008a5c668bc45 --- /dev/null +++ b/R/install-dev.bat @@ -0,0 +1,27 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem Install development version of SparkR +rem + +set SPARK_HOME=%~dp0.. + +MKDIR %SPARK_HOME%\R\lib + +R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\ diff --git a/R/install-dev.sh b/R/install-dev.sh new file mode 100755 index 0000000000000..55ed6f4be1a4a --- /dev/null +++ b/R/install-dev.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This scripts packages the SparkR source files (R and C files) and +# creates a package that can be loaded in R. The package is by default installed to +# $FWDIR/lib and the package can be loaded by using the following command in R: +# +# library(SparkR, lib.loc="$FWDIR/lib") +# +# NOTE(shivaram): Right now we use $SPARK_HOME/R/lib to be the installation directory +# to load the SparkR package on the worker nodes. + + +FWDIR="$(cd `dirname $0`; pwd)" +LIB_DIR="$FWDIR/lib" + +mkdir -p $LIB_DIR + +# Install R +R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ diff --git a/R/log4j.properties b/R/log4j.properties new file mode 100644 index 0000000000000..701adb2a3da1d --- /dev/null +++ b/R/log4j.properties @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=R-unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.eclipse.jetty=WARN +org.eclipse.jetty.LEVEL=WARN diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION new file mode 100644 index 0000000000000..1842b97d43651 --- /dev/null +++ b/R/pkg/DESCRIPTION @@ -0,0 +1,35 @@ +Package: SparkR +Type: Package +Title: R frontend for Spark +Version: 1.4.0 +Date: 2013-09-09 +Author: The Apache Software Foundation +Maintainer: Shivaram Venkataraman +Imports: + methods +Depends: + R (>= 3.0), + methods, +Suggests: + testthat +Description: R frontend for Spark +License: Apache License (== 2.0) +Collate: + 'generics.R' + 'jobj.R' + 'SQLTypes.R' + 'RDD.R' + 'pairRDD.R' + 'column.R' + 'group.R' + 'DataFrame.R' + 'SQLContext.R' + 'broadcast.R' + 'context.R' + 'deserialize.R' + 'serialize.R' + 'sparkR.R' + 'backend.R' + 'client.R' + 'utils.R' + 'zzz.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE new file mode 100644 index 0000000000000..a354cdce74afa --- /dev/null +++ b/R/pkg/NAMESPACE @@ -0,0 +1,182 @@ +#exportPattern("^[[:alpha:]]+") +exportClasses("RDD") +exportClasses("Broadcast") +exportMethods( + "aggregateByKey", + "aggregateRDD", + "cache", + "checkpoint", + "coalesce", + "cogroup", + "collect", + "collectAsMap", + "collectPartition", + "combineByKey", + "count", + "countByKey", + "countByValue", + "distinct", + "Filter", + "filterRDD", + "first", + "flatMap", + "flatMapValues", + "fold", + "foldByKey", + "foreach", + "foreachPartition", + "fullOuterJoin", + "glom", + "groupByKey", + "join", + "keyBy", + "keys", + "length", + "lapply", + "lapplyPartition", + "lapplyPartitionsWithIndex", + "leftOuterJoin", + "lookup", + "map", + "mapPartitions", + "mapPartitionsWithIndex", + "mapValues", + "maximum", + "minimum", + "numPartitions", + "partitionBy", + "persist", + "pipeRDD", + "reduce", + "reduceByKey", + "reduceByKeyLocally", + "repartition", + "rightOuterJoin", + "sampleRDD", + "saveAsTextFile", + "saveAsObjectFile", + "sortBy", + "sortByKey", + "sumRDD", + "take", + "takeOrdered", + "takeSample", + "top", + "unionRDD", + "unpersist", + "value", + "values", + "zipRDD", + "zipWithIndex", + "zipWithUniqueId" + ) + +# S3 methods exported +export( + "textFile", + "objectFile", + "parallelize", + "hashCode", + "includePackage", + "broadcast", + "setBroadcastValue", + "setCheckpointDir" + ) +export("sparkR.init") +export("sparkR.stop") +export("print.jobj") +useDynLib(SparkR, stringHashCode) +importFrom(methods, setGeneric, setMethod, setOldClass) + +# SparkRSQL + +exportClasses("DataFrame") + +exportMethods("columns", + "distinct", + "dtypes", + "explain", + "filter", + "groupBy", + "head", + "insertInto", + "intersect", + "isLocal", + "limit", + "orderBy", + "names", + "printSchema", + "registerTempTable", + "repartition", + "sampleDF", + "saveAsParquetFile", + "saveAsTable", + "saveDF", + "schema", + "select", + "selectExpr", + "show", + "showDF", + "sortDF", + "subtract", + "toJSON", + "toRDD", + "unionAll", + "where", + "withColumn", + "withColumnRenamed") + +exportClasses("Column") + +exportMethods("abs", + "alias", + "approxCountDistinct", + "asc", + "avg", + "cast", + "contains", + "countDistinct", + "desc", + "endsWith", + "getField", + "getItem", + "isNotNull", + "isNull", + "last", + "like", + "lower", + "max", + "mean", + "min", + "rlike", + "sqrt", + "startsWith", + "substr", + "sum", + "sumDistinct", + "upper") + +exportClasses("GroupedData") +exportMethods("agg") + +export("sparkRSQL.init", + "sparkRHive.init") + +export("cacheTable", + "clearCache", + "createDataFrame", + "createExternalTable", + "dropTempTable", + "jsonFile", + "jsonRDD", + "loadDF", + "parquetFile", + "sql", + "table", + "tableNames", + "tables", + "toDF", + "uncacheTable") + +export("print.structType", + "print.structField") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R new file mode 100644 index 0000000000000..feafd56909a67 --- /dev/null +++ b/R/pkg/R/DataFrame.R @@ -0,0 +1,1270 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# DataFrame.R - DataFrame class and methods implemented in S4 OO classes + +#' @include jobj.R SQLTypes.R RDD.R pairRDD.R column.R group.R +NULL + +setOldClass("jobj") + +#' @title S4 class that represents a DataFrame +#' @description DataFrames can be created using functions like +#' \code{jsonFile}, \code{table} etc. +#' @rdname DataFrame +#' @seealso jsonFile, table +#' +#' @param env An R environment that stores bookkeeping states of the DataFrame +#' @param sdf A Java object reference to the backing Scala DataFrame +#' @export +setClass("DataFrame", + slots = list(env = "environment", + sdf = "jobj")) + +setMethod("initialize", "DataFrame", function(.Object, sdf, isCached) { + .Object@env <- new.env() + .Object@env$isCached <- isCached + + .Object@sdf <- sdf + .Object +}) + +#' @rdname DataFrame +#' @export +dataFrame <- function(sdf, isCached = FALSE) { + new("DataFrame", sdf, isCached) +} + +############################ DataFrame Methods ############################################## + +#' Print Schema of a DataFrame +#' +#' Prints out the schema in tree format +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname printSchema +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' printSchema(df) +#'} +setMethod("printSchema", + signature(x = "DataFrame"), + function(x) { + schemaString <- callJMethod(schema(x)$jobj, "treeString") + cat(schemaString) + }) + +#' Get schema object +#' +#' Returns the schema of this DataFrame as a structType object. +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname schema +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' dfSchema <- schema(df) +#'} +setMethod("schema", + signature(x = "DataFrame"), + function(x) { + structType(callJMethod(x@sdf, "schema")) + }) + +#' Explain +#' +#' Print the logical and physical Catalyst plans to the console for debugging. +#' +#' @param x A SparkSQL DataFrame +#' @param extended Logical. If extended is False, explain() only prints the physical plan. +#' @rdname explain +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' explain(df, TRUE) +#'} +setMethod("explain", + signature(x = "DataFrame"), + function(x, extended = FALSE) { + queryExec <- callJMethod(x@sdf, "queryExecution") + if (extended) { + cat(callJMethod(queryExec, "toString")) + } else { + execPlan <- callJMethod(queryExec, "executedPlan") + cat(callJMethod(execPlan, "toString")) + } + }) + +#' isLocal +#' +#' Returns True if the `collect` and `take` methods can be run locally +#' (without any Spark executors). +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname isLocal +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' isLocal(df) +#'} +setMethod("isLocal", + signature(x = "DataFrame"), + function(x) { + callJMethod(x@sdf, "isLocal") + }) + +#' ShowDF +#' +#' Print the first numRows rows of a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' @param numRows The number of rows to print. Defaults to 20. +#' +#' @rdname showDF +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' showDF(df) +#'} +setMethod("showDF", + signature(x = "DataFrame"), + function(x, numRows = 20) { + cat(callJMethod(x@sdf, "showString", numToInt(numRows)), "\n") + }) + +#' show +#' +#' Print the DataFrame column names and types +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname show +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' show(df) +#'} +setMethod("show", "DataFrame", + function(object) { + cols <- lapply(dtypes(object), function(l) { + paste(l, collapse = ":") + }) + s <- paste(cols, collapse = ", ") + cat(paste("DataFrame[", s, "]\n", sep = "")) + }) + +#' DataTypes +#' +#' Return all column names and their data types as a list +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname dtypes +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' dtypes(df) +#'} +setMethod("dtypes", + signature(x = "DataFrame"), + function(x) { + lapply(schema(x)$fields(), function(f) { + c(f$name(), f$dataType.simpleString()) + }) + }) + +#' Column names +#' +#' Return all column names as a list +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname columns +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' columns(df) +#'} +setMethod("columns", + signature(x = "DataFrame"), + function(x) { + sapply(schema(x)$fields(), function(f) { + f$name() + }) + }) + +#' @rdname columns +#' @export +setMethod("names", + signature(x = "DataFrame"), + function(x) { + columns(x) + }) + +#' Register Temporary Table +#' +#' Registers a DataFrame as a Temporary Table in the SQLContext +#' +#' @param x A SparkSQL DataFrame +#' @param tableName A character vector containing the name of the table +#' +#' @rdname registerTempTable +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "json_df") +#' new_df <- sql(sqlCtx, "SELECT * FROM json_df") +#'} +setMethod("registerTempTable", + signature(x = "DataFrame", tableName = "character"), + function(x, tableName) { + callJMethod(x@sdf, "registerTempTable", tableName) + }) + +#' insertInto +#' +#' Insert the contents of a DataFrame into a table registered in the current SQL Context. +#' +#' @param x A SparkSQL DataFrame +#' @param tableName A character vector containing the name of the table +#' @param overwrite A logical argument indicating whether or not to overwrite +#' the existing rows in the table. +#' +#' @rdname insertInto +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- loadDF(sqlCtx, path, "parquet") +#' df2 <- loadDF(sqlCtx, path2, "parquet") +#' registerTempTable(df, "table1") +#' insertInto(df2, "table1", overwrite = TRUE) +#'} +setMethod("insertInto", + signature(x = "DataFrame", tableName = "character"), + function(x, tableName, overwrite = FALSE) { + callJMethod(x@sdf, "insertInto", tableName, overwrite) + }) + +#' Cache +#' +#' Persist with the default storage level (MEMORY_ONLY). +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname cache-methods +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' cache(df) +#'} +setMethod("cache", + signature(x = "DataFrame"), + function(x) { + cached <- callJMethod(x@sdf, "cache") + x@env$isCached <- TRUE + x + }) + +#' Persist +#' +#' Persist this DataFrame with the specified storage level. For details of the +#' supported storage levels, refer to +#' http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. +#' +#' @param x The DataFrame to persist +#' @rdname persist +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' persist(df, "MEMORY_AND_DISK") +#'} +setMethod("persist", + signature(x = "DataFrame", newLevel = "character"), + function(x, newLevel) { + callJMethod(x@sdf, "persist", getStorageLevel(newLevel)) + x@env$isCached <- TRUE + x + }) + +#' Unpersist +#' +#' Mark this DataFrame as non-persistent, and remove all blocks for it from memory and +#' disk. +#' +#' @param x The DataFrame to unpersist +#' @param blocking Whether to block until all blocks are deleted +#' @rdname unpersist-methods +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' persist(df, "MEMORY_AND_DISK") +#' unpersist(df) +#'} +setMethod("unpersist", + signature(x = "DataFrame"), + function(x, blocking = TRUE) { + callJMethod(x@sdf, "unpersist", blocking) + x@env$isCached <- FALSE + x + }) + +#' Repartition +#' +#' Return a new DataFrame that has exactly numPartitions partitions. +#' +#' @param x A SparkSQL DataFrame +#' @param numPartitions The number of partitions to use. +#' @rdname repartition +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newDF <- repartition(df, 2L) +#'} +setMethod("repartition", + signature(x = "DataFrame", numPartitions = "numeric"), + function(x, numPartitions) { + sdf <- callJMethod(x@sdf, "repartition", numToInt(numPartitions)) + dataFrame(sdf) + }) + +#' toJSON +#' +#' Convert the rows of a DataFrame into JSON objects and return an RDD where +#' each element contains a JSON string. +#' +#' @param x A SparkSQL DataFrame +#' @return A StringRRDD of JSON objects +#' @rdname tojson +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newRDD <- toJSON(df) +#'} +setMethod("toJSON", + signature(x = "DataFrame"), + function(x) { + rdd <- callJMethod(x@sdf, "toJSON") + jrdd <- callJMethod(rdd, "toJavaRDD") + RDD(jrdd, serializedMode = "string") + }) + +#' saveAsParquetFile +#' +#' Save the contents of a DataFrame as a Parquet file, preserving the schema. Files written out +#' with this method can be read back in as a DataFrame using parquetFile(). +#' +#' @param x A SparkSQL DataFrame +#' @param path The directory where the file is saved +#' @rdname saveAsParquetFile +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' saveAsParquetFile(df, "/tmp/sparkr-tmp/") +#'} +setMethod("saveAsParquetFile", + signature(x = "DataFrame", path = "character"), + function(x, path) { + invisible(callJMethod(x@sdf, "saveAsParquetFile", path)) + }) + +#' Distinct +#' +#' Return a new DataFrame containing the distinct rows in this DataFrame. +#' +#' @param x A SparkSQL DataFrame +#' @rdname distinct +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' distinctDF <- distinct(df) +#'} +setMethod("distinct", + signature(x = "DataFrame"), + function(x) { + sdf <- callJMethod(x@sdf, "distinct") + dataFrame(sdf) + }) + +#' SampleDF +#' +#' Return a sampled subset of this DataFrame using a random seed. +#' +#' @param x A SparkSQL DataFrame +#' @param withReplacement Sampling with replacement or not +#' @param fraction The (rough) sample target fraction +#' @rdname sampleDF +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' collect(sampleDF(df, FALSE, 0.5)) +#' collect(sampleDF(df, TRUE, 0.5)) +#'} +setMethod("sampleDF", + # TODO : Figure out how to send integer as java.lang.Long to JVM so + # we can send seed as an argument through callJMethod + signature(x = "DataFrame", withReplacement = "logical", + fraction = "numeric"), + function(x, withReplacement, fraction) { + if (fraction < 0.0) stop(cat("Negative fraction value:", fraction)) + sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction) + dataFrame(sdf) + }) + +#' Count +#' +#' Returns the number of rows in a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname count +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' count(df) +#' } +setMethod("count", + signature(x = "DataFrame"), + function(x) { + callJMethod(x@sdf, "count") + }) + +#' Collects all the elements of a Spark DataFrame and coerces them into an R data.frame. +#' +#' @param x A SparkSQL DataFrame +#' @param stringsAsFactors (Optional) A logical indicating whether or not string columns +#' should be converted to factors. FALSE by default. + +#' @rdname collect-methods +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' collected <- collect(df) +#' firstName <- collected[[1]]$name +#' } +setMethod("collect", + signature(x = "DataFrame"), + function(x, stringsAsFactors = FALSE) { + # listCols is a list of raw vectors, one per column + listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf) + cols <- lapply(listCols, function(col) { + objRaw <- rawConnection(col) + numRows <- readInt(objRaw) + col <- readCol(objRaw, numRows) + close(objRaw) + col + }) + names(cols) <- columns(x) + do.call(cbind.data.frame, list(cols, stringsAsFactors = stringsAsFactors)) + }) + +#' Limit +#' +#' Limit the resulting DataFrame to the number of rows specified. +#' +#' @param x A SparkSQL DataFrame +#' @param num The number of rows to return +#' @return A new DataFrame containing the number of rows specified. +#' +#' @rdname limit +#' @export +#' @examples +#' \dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' limitedDF <- limit(df, 10) +#' } +setMethod("limit", + signature(x = "DataFrame", num = "numeric"), + function(x, num) { + res <- callJMethod(x@sdf, "limit", as.integer(num)) + dataFrame(res) + }) + +# Take the first NUM rows of a DataFrame and return a the results as a data.frame + +#' @rdname take +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' take(df, 2) +#' } +setMethod("take", + signature(x = "DataFrame", num = "numeric"), + function(x, num) { + limited <- limit(x, num) + collect(limited) + }) + +#' Head +#' +#' Return the first NUM rows of a DataFrame as a data.frame. If NUM is NULL, +#' then head() returns the first 6 rows in keeping with the current data.frame +#' convention in R. +#' +#' @param x A SparkSQL DataFrame +#' @param num The number of rows to return. Default is 6. +#' @return A data.frame +#' +#' @rdname head +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' head(df) +#' } +setMethod("head", + signature(x = "DataFrame"), + function(x, num = 6L) { + # Default num is 6L in keeping with R's data.frame convention + take(x, num) + }) + +#' Return the first row of a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname first +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' first(df) +#' } +setMethod("first", + signature(x = "DataFrame"), + function(x) { + take(x, 1) + }) + +#' toRDD() +#' +#' Converts a Spark DataFrame to an RDD while preserving column names. +#' +#' @param x A Spark DataFrame +#' +#' @rdname DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' rdd <- toRDD(df) +#' } +setMethod("toRDD", + signature(x = "DataFrame"), + function(x) { + jrdd <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToRowRDD", x@sdf) + colNames <- callJMethod(x@sdf, "columns") + rdd <- RDD(jrdd, serializedMode = "row") + lapply(rdd, function(row) { + names(row) <- colNames + row + }) + }) + +#' GroupBy +#' +#' Groups the DataFrame using the specified columns, so we can run aggregation on them. +#' +#' @param x a DataFrame +#' @return a GroupedData +#' @seealso GroupedData +#' @rdname DataFrame +#' @export +#' @examples +#' \dontrun{ +#' # Compute the average for all numeric columns grouped by department. +#' avg(groupBy(df, "department")) +#' +#' # Compute the max age and average salary, grouped by department and gender. +#' agg(groupBy(df, "department", "gender"), salary="avg", "age" -> "max") +#' } +setMethod("groupBy", + signature(x = "DataFrame"), + function(x, ...) { + cols <- list(...) + if (length(cols) >= 1 && class(cols[[1]]) == "character") { + sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], listToSeq(cols[-1])) + } else { + jcol <- lapply(cols, function(c) { c@jc }) + sgd <- callJMethod(x@sdf, "groupBy", listToSeq(jcol)) + } + groupedData(sgd) + }) + +#' Agg +#' +#' Compute aggregates by specifying a list of columns +#' +#' @rdname DataFrame +#' @export +setMethod("agg", + signature(x = "DataFrame"), + function(x, ...) { + agg(groupBy(x), ...) + }) + + +############################## RDD Map Functions ################################## +# All of the following functions mirror the existing RDD map functions, # +# but allow for use with DataFrames by first converting to an RRDD before calling # +# the requested map function. # +################################################################################### + +#' @rdname lapply +setMethod("lapply", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + rdd <- toRDD(X) + lapply(rdd, FUN) + }) + +#' @rdname lapply +setMethod("map", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + lapply(X, FUN) + }) + +#' @rdname flatMap +setMethod("flatMap", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + rdd <- toRDD(X) + flatMap(rdd, FUN) + }) + +#' @rdname lapplyPartition +setMethod("lapplyPartition", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + rdd <- toRDD(X) + lapplyPartition(rdd, FUN) + }) + +#' @rdname lapplyPartition +setMethod("mapPartitions", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + lapplyPartition(X, FUN) + }) + +#' @rdname foreach +setMethod("foreach", + signature(x = "DataFrame", func = "function"), + function(x, func) { + rdd <- toRDD(x) + foreach(rdd, func) + }) + +#' @rdname foreach +setMethod("foreachPartition", + signature(x = "DataFrame", func = "function"), + function(x, func) { + rdd <- toRDD(x) + foreachPartition(rdd, func) + }) + + +############################## SELECT ################################## + +getColumn <- function(x, c) { + column(callJMethod(x@sdf, "col", c)) +} + +#' @rdname select +setMethod("$", signature(x = "DataFrame"), + function(x, name) { + getColumn(x, name) + }) + +setMethod("$<-", signature(x = "DataFrame"), + function(x, name, value) { + stopifnot(class(value) == "Column") + cols <- columns(x) + if (name %in% cols) { + cols <- lapply(cols, function(c) { + if (c == name) { + alias(value, name) + } else { + col(c) + } + }) + nx <- select(x, cols) + } else { + nx <- withColumn(x, name, value) + } + x@sdf <- nx@sdf + x + }) + +#' @rdname select +setMethod("[[", signature(x = "DataFrame"), + function(x, i) { + if (is.numeric(i)) { + cols <- columns(x) + i <- cols[[i]] + } + getColumn(x, i) + }) + +#' @rdname select +setMethod("[", signature(x = "DataFrame", i = "missing"), + function(x, i, j, ...) { + if (is.numeric(j)) { + cols <- columns(x) + j <- cols[j] + } + if (length(j) > 1) { + j <- as.list(j) + } + select(x, j) + }) + +#' Select +#' +#' Selects a set of columns with names or Column expressions. +#' @param x A DataFrame +#' @param col A list of columns or single Column or name +#' @return A new DataFrame with selected columns +#' @export +#' @rdname select +#' @examples +#' \dontrun{ +#' select(df, "*") +#' select(df, "col1", "col2") +#' select(df, df$name, df$age + 1) +#' select(df, c("col1", "col2")) +#' select(df, list(df$name, df$age + 1)) +#' # Columns can also be selected using `[[` and `[` +#' df[[2]] == df[["age"]] +#' df[,2] == df[,"age"] +#' # Similar to R data frames columns can also be selected using `$` +#' df$age +#' } +setMethod("select", signature(x = "DataFrame", col = "character"), + function(x, col, ...) { + sdf <- callJMethod(x@sdf, "select", col, toSeq(...)) + dataFrame(sdf) + }) + +#' @rdname select +#' @export +setMethod("select", signature(x = "DataFrame", col = "Column"), + function(x, col, ...) { + jcols <- lapply(list(col, ...), function(c) { + c@jc + }) + sdf <- callJMethod(x@sdf, "select", listToSeq(jcols)) + dataFrame(sdf) + }) + +#' @rdname select +#' @export +setMethod("select", + signature(x = "DataFrame", col = "list"), + function(x, col) { + cols <- lapply(col, function(c) { + if (class(c)== "Column") { + c@jc + } else { + col(c)@jc + } + }) + sdf <- callJMethod(x@sdf, "select", listToSeq(cols)) + dataFrame(sdf) + }) + +#' SelectExpr +#' +#' Select from a DataFrame using a set of SQL expressions. +#' +#' @param x A DataFrame to be selected from. +#' @param expr A string containing a SQL expression +#' @param ... Additional expressions +#' @return A DataFrame +#' @rdname selectExpr +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' selectExpr(df, "col1", "(col2 * 5) as newCol") +#' } +setMethod("selectExpr", + signature(x = "DataFrame", expr = "character"), + function(x, expr, ...) { + exprList <- list(expr, ...) + sdf <- callJMethod(x@sdf, "selectExpr", listToSeq(exprList)) + dataFrame(sdf) + }) + +#' WithColumn +#' +#' Return a new DataFrame with the specified column added. +#' +#' @param x A DataFrame +#' @param colName A string containing the name of the new column. +#' @param col A Column expression. +#' @return A DataFrame with the new column added. +#' @rdname withColumn +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newDF <- withColumn(df, "newCol", df$col1 * 5) +#' } +setMethod("withColumn", + signature(x = "DataFrame", colName = "character", col = "Column"), + function(x, colName, col) { + select(x, x$"*", alias(col, colName)) + }) + +#' WithColumnRenamed +#' +#' Rename an existing column in a DataFrame. +#' +#' @param x A DataFrame +#' @param existingCol The name of the column you want to change. +#' @param newCol The new column name. +#' @return A DataFrame with the column name changed. +#' @rdname withColumnRenamed +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newDF <- withColumnRenamed(df, "col1", "newCol1") +#' } +setMethod("withColumnRenamed", + signature(x = "DataFrame", existingCol = "character", newCol = "character"), + function(x, existingCol, newCol) { + cols <- lapply(columns(x), function(c) { + if (c == existingCol) { + alias(col(c), newCol) + } else { + col(c) + } + }) + select(x, cols) + }) + +setClassUnion("characterOrColumn", c("character", "Column")) + +#' SortDF +#' +#' Sort a DataFrame by the specified column(s). +#' +#' @param x A DataFrame to be sorted. +#' @param col Either a Column object or character vector indicating the field to sort on +#' @param ... Additional sorting fields +#' @return A DataFrame where all elements are sorted. +#' @rdname sortDF +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' sortDF(df, df$col1) +#' sortDF(df, "col1") +#' sortDF(df, asc(df$col1), desc(abs(df$col2))) +#' } +setMethod("sortDF", + signature(x = "DataFrame", col = "characterOrColumn"), + function(x, col, ...) { + if (class(col) == "character") { + sdf <- callJMethod(x@sdf, "sort", col, toSeq(...)) + } else if (class(col) == "Column") { + jcols <- lapply(list(col, ...), function(c) { + c@jc + }) + sdf <- callJMethod(x@sdf, "sort", listToSeq(jcols)) + } + dataFrame(sdf) + }) + +#' @rdname sortDF +#' @export +setMethod("orderBy", + signature(x = "DataFrame", col = "characterOrColumn"), + function(x, col) { + sortDF(x, col) + }) + +#' Filter +#' +#' Filter the rows of a DataFrame according to a given condition. +#' +#' @param x A DataFrame to be sorted. +#' @param condition The condition to sort on. This may either be a Column expression +#' or a string containing a SQL statement +#' @return A DataFrame containing only the rows that meet the condition. +#' @rdname filter +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' filter(df, "col1 > 0") +#' filter(df, df$col2 != "abcdefg") +#' } +setMethod("filter", + signature(x = "DataFrame", condition = "characterOrColumn"), + function(x, condition) { + if (class(condition) == "Column") { + condition <- condition@jc + } + sdf <- callJMethod(x@sdf, "filter", condition) + dataFrame(sdf) + }) + +#' @rdname filter +#' @export +setMethod("where", + signature(x = "DataFrame", condition = "characterOrColumn"), + function(x, condition) { + filter(x, condition) + }) + +#' Join +#' +#' Join two DataFrames based on the given join expression. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a +#' Column expression. If joinExpr is omitted, join() wil perform a Cartesian join +#' @param joinType The type of join to perform. The following join types are available: +#' 'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'. The default joinType is "inner". +#' @return A DataFrame containing the result of the join operation. +#' @rdname join +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' join(df1, df2) # Performs a Cartesian +#' join(df1, df2, df1$col1 == df2$col2) # Performs an inner join based on expression +#' join(df1, df2, df1$col1 == df2$col2, "right_outer") +#' } +setMethod("join", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y, joinExpr = NULL, joinType = NULL) { + if (is.null(joinExpr)) { + sdf <- callJMethod(x@sdf, "join", y@sdf) + } else { + if (class(joinExpr) != "Column") stop("joinExpr must be a Column") + if (is.null(joinType)) { + sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc) + } else { + if (joinType %in% c("inner", "outer", "left_outer", "right_outer", "semijoin")) { + sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc, joinType) + } else { + stop("joinType must be one of the following types: ", + "'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'") + } + } + } + dataFrame(sdf) + }) + +#' UnionAll +#' +#' Return a new DataFrame containing the union of rows in this DataFrame +#' and another DataFrame. This is equivalent to `UNION ALL` in SQL. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @return A DataFrame containing the result of the union. +#' @rdname unionAll +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' unioned <- unionAll(df, df2) +#' } +setMethod("unionAll", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y) { + unioned <- callJMethod(x@sdf, "unionAll", y@sdf) + dataFrame(unioned) + }) + +#' Intersect +#' +#' Return a new DataFrame containing rows only in both this DataFrame +#' and another DataFrame. This is equivalent to `INTERSECT` in SQL. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @return A DataFrame containing the result of the intersect. +#' @rdname intersect +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' intersectDF <- intersect(df, df2) +#' } +setMethod("intersect", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y) { + intersected <- callJMethod(x@sdf, "intersect", y@sdf) + dataFrame(intersected) + }) + +#' Subtract +#' +#' Return a new DataFrame containing rows in this DataFrame +#' but not in another DataFrame. This is equivalent to `EXCEPT` in SQL. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @return A DataFrame containing the result of the subtract operation. +#' @rdname subtract +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' subtractDF <- subtract(df, df2) +#' } +setMethod("subtract", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y) { + subtracted <- callJMethod(x@sdf, "except", y@sdf) + dataFrame(subtracted) + }) + +#' Save the contents of the DataFrame to a data source +#' +#' The data source is specified by the `source` and a set of options (...). +#' If `source` is not specified, the default data source configured by +#' spark.sql.sources.default will be used. +#' +#' Additionally, mode is used to specify the behavior of the save operation when +#' data already exists in the data source. There are four modes: +#' append: Contents of this DataFrame are expected to be appended to existing data. +#' overwrite: Existing data is expected to be overwritten by the contents of +# this DataFrame. +#' error: An exception is expected to be thrown. +#' ignore: The save operation is expected to not save the contents of the DataFrame +# and to not change the existing data. +#' +#' @param df A SparkSQL DataFrame +#' @param path A name for the table +#' @param source A name for external data source +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' +#' +#' @rdname saveAsTable +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' saveAsTable(df, "myfile") +#' } +setMethod("saveDF", + signature(df = "DataFrame", path = 'character', source = 'character', + mode = 'character'), + function(df, path = NULL, source = NULL, mode = "append", ...){ + if (is.null(source)) { + sqlCtx <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlCtx, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") + } + allModes <- c("append", "overwrite", "error", "ignore") + if (!(mode %in% allModes)) { + stop('mode should be one of "append", "overwrite", "error", "ignore"') + } + jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) + options <- varargsToEnv(...) + if (!is.null(path)) { + options[['path']] = path + } + callJMethod(df@sdf, "save", source, jmode, options) + }) + + +#' saveAsTable +#' +#' Save the contents of the DataFrame to a data source as a table +#' +#' The data source is specified by the `source` and a set of options (...). +#' If `source` is not specified, the default data source configured by +#' spark.sql.sources.default will be used. +#' +#' Additionally, mode is used to specify the behavior of the save operation when +#' data already exists in the data source. There are four modes: +#' append: Contents of this DataFrame are expected to be appended to existing data. +#' overwrite: Existing data is expected to be overwritten by the contents of +# this DataFrame. +#' error: An exception is expected to be thrown. +#' ignore: The save operation is expected to not save the contents of the DataFrame +# and to not change the existing data. +#' +#' @param df A SparkSQL DataFrame +#' @param tableName A name for the table +#' @param source A name for external data source +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' +#' +#' @rdname saveAsTable +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' saveAsTable(df, "myfile") +#' } +setMethod("saveAsTable", + signature(df = "DataFrame", tableName = 'character', source = 'character', + mode = 'character'), + function(df, tableName, source = NULL, mode="append", ...){ + if (is.null(source)) { + sqlCtx <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlCtx, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") + } + allModes <- c("append", "overwrite", "error", "ignore") + if (!(mode %in% allModes)) { + stop('mode should be one of "append", "overwrite", "error", "ignore"') + } + jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) + options <- varargsToEnv(...) + callJMethod(df@sdf, "saveAsTable", tableName, source, jmode, options) + }) + diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R new file mode 100644 index 0000000000000..820027ef67e3b --- /dev/null +++ b/R/pkg/R/RDD.R @@ -0,0 +1,1531 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# RDD in R implemented in S4 OO system. + +setOldClass("jobj") + +#' @title S4 class that represents an RDD +#' @description RDD can be created using functions like +#' \code{parallelize}, \code{textFile} etc. +#' @rdname RDD +#' @seealso parallelize, textFile +#' +#' @slot env An R environment that stores bookkeeping states of the RDD +#' @slot jrdd Java object reference to the backing JavaRDD +#' to an RDD +#' @export +setClass("RDD", + slots = list(env = "environment", + jrdd = "jobj")) + +setClass("PipelinedRDD", + slots = list(prev = "RDD", + func = "function", + prev_jrdd = "jobj"), + contains = "RDD") + +setMethod("initialize", "RDD", function(.Object, jrdd, serializedMode, + isCached, isCheckpointed) { + # Check that RDD constructor is using the correct version of serializedMode + stopifnot(class(serializedMode) == "character") + stopifnot(serializedMode %in% c("byte", "string", "row")) + # RDD has three serialization types: + # byte: The RDD stores data serialized in R. + # string: The RDD stores data as strings. + # row: The RDD stores the serialized rows of a DataFrame. + + # We use an environment to store mutable states inside an RDD object. + # Note that R's call-by-value semantics makes modifying slots inside an + # object (passed as an argument into a function, such as cache()) difficult: + # i.e. one needs to make a copy of the RDD object and sets the new slot value + # there. + + # The slots are inheritable from superclass. Here, both `env' and `jrdd' are + # inherited from RDD, but only the former is used. + .Object@env <- new.env() + .Object@env$isCached <- isCached + .Object@env$isCheckpointed <- isCheckpointed + .Object@env$serializedMode <- serializedMode + + .Object@jrdd <- jrdd + .Object +}) + +setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) { + .Object@env <- new.env() + .Object@env$isCached <- FALSE + .Object@env$isCheckpointed <- FALSE + .Object@env$jrdd_val <- jrdd_val + if (!is.null(jrdd_val)) { + # This tracks the serialization mode for jrdd_val + .Object@env$serializedMode <- prev@env$serializedMode + } + + .Object@prev <- prev + + isPipelinable <- function(rdd) { + e <- rdd@env + !(e$isCached || e$isCheckpointed) + } + + if (!inherits(prev, "PipelinedRDD") || !isPipelinable(prev)) { + # This transformation is the first in its stage: + .Object@func <- cleanClosure(func) + .Object@prev_jrdd <- getJRDD(prev) + .Object@env$prev_serializedMode <- prev@env$serializedMode + # NOTE: We use prev_serializedMode to track the serialization mode of prev_JRDD + # prev_serializedMode is used during the delayed computation of JRDD in getJRDD + } else { + pipelinedFunc <- function(split, iterator) { + func(split, prev@func(split, iterator)) + } + .Object@func <- cleanClosure(pipelinedFunc) + .Object@prev_jrdd <- prev@prev_jrdd # maintain the pipeline + # Get the serialization mode of the parent RDD + .Object@env$prev_serializedMode <- prev@env$prev_serializedMode + } + + .Object +}) + +#' @rdname RDD +#' @export +#' +#' @param jrdd Java object reference to the backing JavaRDD +#' @param serializedMode Use "byte" if the RDD stores data serialized in R, "string" if the RDD +#' stores strings, and "row" if the RDD stores the rows of a DataFrame +#' @param isCached TRUE if the RDD is cached +#' @param isCheckpointed TRUE if the RDD has been checkpointed +RDD <- function(jrdd, serializedMode = "byte", isCached = FALSE, + isCheckpointed = FALSE) { + new("RDD", jrdd, serializedMode, isCached, isCheckpointed) +} + +PipelinedRDD <- function(prev, func) { + new("PipelinedRDD", prev, func, NULL) +} + +# Return the serialization mode for an RDD. +setGeneric("getSerializedMode", function(rdd, ...) { standardGeneric("getSerializedMode") }) +# For normal RDDs we can directly read the serializedMode +setMethod("getSerializedMode", signature(rdd = "RDD"), function(rdd) rdd@env$serializedMode ) +# For pipelined RDDs if jrdd_val is set then serializedMode should exist +# if not we return the defaultSerialization mode of "byte" as we don't know the serialization +# mode at this point in time. +setMethod("getSerializedMode", signature(rdd = "PipelinedRDD"), + function(rdd) { + if (!is.null(rdd@env$jrdd_val)) { + return(rdd@env$serializedMode) + } else { + return("byte") + } + }) + +# The jrdd accessor function. +setMethod("getJRDD", signature(rdd = "RDD"), function(rdd) rdd@jrdd ) +setMethod("getJRDD", signature(rdd = "PipelinedRDD"), + function(rdd, serializedMode = "byte") { + if (!is.null(rdd@env$jrdd_val)) { + return(rdd@env$jrdd_val) + } + + packageNamesArr <- serialize(.sparkREnv[[".packages"]], + connection = NULL) + + broadcastArr <- lapply(ls(.broadcastNames), + function(name) { get(name, .broadcastNames) }) + + serializedFuncArr <- serialize(rdd@func, connection = NULL) + + prev_jrdd <- rdd@prev_jrdd + + if (serializedMode == "string") { + rddRef <- newJObject("org.apache.spark.api.r.StringRRDD", + callJMethod(prev_jrdd, "rdd"), + serializedFuncArr, + rdd@env$prev_serializedMode, + packageNamesArr, + as.character(.sparkREnv[["libname"]]), + broadcastArr, + callJMethod(prev_jrdd, "classTag")) + } else { + rddRef <- newJObject("org.apache.spark.api.r.RRDD", + callJMethod(prev_jrdd, "rdd"), + serializedFuncArr, + rdd@env$prev_serializedMode, + serializedMode, + packageNamesArr, + as.character(.sparkREnv[["libname"]]), + broadcastArr, + callJMethod(prev_jrdd, "classTag")) + } + # Save the serialization flag after we create a RRDD + rdd@env$serializedMode <- serializedMode + rdd@env$jrdd_val <- callJMethod(rddRef, "asJavaRDD") # rddRef$asJavaRDD() + rdd@env$jrdd_val + }) + +setValidity("RDD", + function(object) { + jrdd <- getJRDD(object) + cls <- callJMethod(jrdd, "getClass") + className <- callJMethod(cls, "getName") + if (grep("spark.api.java.*RDD*", className) == 1) { + TRUE + } else { + paste("Invalid RDD class ", className) + } + }) + + +############ Actions and Transformations ############ + +#' Persist an RDD +#' +#' Persist this RDD with the default storage level (MEMORY_ONLY). +#' +#' @param x The RDD to cache +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' cache(rdd) +#'} +#' @rdname cache-methods +#' @aliases cache,RDD-method +setMethod("cache", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "cache") + x@env$isCached <- TRUE + x + }) + +#' Persist an RDD +#' +#' Persist this RDD with the specified storage level. For details of the +#' supported storage levels, refer to +#' http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. +#' +#' @param x The RDD to persist +#' @param newLevel The new storage level to be assigned +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' persist(rdd, "MEMORY_AND_DISK") +#'} +#' @rdname persist +#' @aliases persist,RDD-method +setMethod("persist", + signature(x = "RDD", newLevel = "character"), + function(x, newLevel) { + callJMethod(getJRDD(x), "persist", getStorageLevel(newLevel)) + x@env$isCached <- TRUE + x + }) + +#' Unpersist an RDD +#' +#' Mark the RDD as non-persistent, and remove all blocks for it from memory and +#' disk. +#' +#' @param x The RDD to unpersist +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' cache(rdd) # rdd@@env$isCached == TRUE +#' unpersist(rdd) # rdd@@env$isCached == FALSE +#'} +#' @rdname unpersist-methods +#' @aliases unpersist,RDD-method +setMethod("unpersist", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "unpersist") + x@env$isCached <- FALSE + x + }) + +#' Checkpoint an RDD +#' +#' Mark this RDD for checkpointing. It will be saved to a file inside the +#' checkpoint directory set with setCheckpointDir() and all references to its +#' parent RDDs will be removed. This function must be called before any job has +#' been executed on this RDD. It is strongly recommended that this RDD is +#' persisted in memory, otherwise saving it on a file will require recomputation. +#' +#' @param x The RDD to checkpoint +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' setCheckpointDir(sc, "checkpoint") +#' rdd <- parallelize(sc, 1:10, 2L) +#' checkpoint(rdd) +#'} +#' @rdname checkpoint-methods +#' @aliases checkpoint,RDD-method +setMethod("checkpoint", + signature(x = "RDD"), + function(x) { + jrdd <- getJRDD(x) + callJMethod(jrdd, "checkpoint") + x@env$isCheckpointed <- TRUE + x + }) + +#' Gets the number of partitions of an RDD +#' +#' @param x A RDD. +#' @return the number of partitions of rdd as an integer. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' numPartitions(rdd) # 2L +#'} +#' @rdname numPartitions +#' @aliases numPartitions,RDD-method +setMethod("numPartitions", + signature(x = "RDD"), + function(x) { + jrdd <- getJRDD(x) + partitions <- callJMethod(jrdd, "splits") + callJMethod(partitions, "size") + }) + +#' Collect elements of an RDD +#' +#' @description +#' \code{collect} returns a list that contains all of the elements in this RDD. +#' +#' @param x The RDD to collect +#' @param ... Other optional arguments to collect +#' @param flatten FALSE if the list should not flattened +#' @return a list containing elements in the RDD +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' collect(rdd) # list from 1 to 10 +#' collectPartition(rdd, 0L) # list from 1 to 5 +#'} +#' @rdname collect-methods +#' @aliases collect,RDD-method +setMethod("collect", + signature(x = "RDD"), + function(x, flatten = TRUE) { + # Assumes a pairwise RDD is backed by a JavaPairRDD. + collected <- callJMethod(getJRDD(x), "collect") + convertJListToRList(collected, flatten, + serializedMode = getSerializedMode(x)) + }) + + +#' @description +#' \code{collectPartition} returns a list that contains all of the elements +#' in the specified partition of the RDD. +#' @param partitionId the partition to collect (starts from 0) +#' @rdname collect-methods +#' @aliases collectPartition,integer,RDD-method +setMethod("collectPartition", + signature(x = "RDD", partitionId = "integer"), + function(x, partitionId) { + jPartitionsList <- callJMethod(getJRDD(x), + "collectPartitions", + as.list(as.integer(partitionId))) + + jList <- jPartitionsList[[1]] + convertJListToRList(jList, flatten = TRUE, + serializedMode = getSerializedMode(x)) + }) + +#' @description +#' \code{collectAsMap} returns a named list as a map that contains all of the elements +#' in a key-value pair RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4)), 2L) +#' collectAsMap(rdd) # list(`1` = 2, `3` = 4) +#'} +#' @rdname collect-methods +#' @aliases collectAsMap,RDD-method +setMethod("collectAsMap", + signature(x = "RDD"), + function(x) { + pairList <- collect(x) + map <- new.env() + lapply(pairList, function(i) { assign(as.character(i[[1]]), i[[2]], envir = map) }) + as.list(map) + }) + +#' Return the number of elements in the RDD. +#' +#' @param x The RDD to count +#' @return number of elements in the RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' count(rdd) # 10 +#' length(rdd) # Same as count +#'} +#' @rdname count +#' @aliases count,RDD-method +setMethod("count", + signature(x = "RDD"), + function(x) { + countPartition <- function(part) { + as.integer(length(part)) + } + valsRDD <- lapplyPartition(x, countPartition) + vals <- collect(valsRDD) + sum(as.integer(vals)) + }) + +#' Return the number of elements in the RDD +#' @export +#' @rdname count +setMethod("length", + signature(x = "RDD"), + function(x) { + count(x) + }) + +#' Return the count of each unique value in this RDD as a list of +#' (value, count) pairs. +#' +#' Same as countByValue in Spark. +#' +#' @param x The RDD to count +#' @return list of (value, count) pairs, where count is number of each unique +#' value in rdd. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, c(1,2,3,2,1)) +#' countByValue(rdd) # (1,2L), (2,2L), (3,1L) +#'} +#' @rdname countByValue +#' @aliases countByValue,RDD-method +setMethod("countByValue", + signature(x = "RDD"), + function(x) { + ones <- lapply(x, function(item) { list(item, 1L) }) + collect(reduceByKey(ones, `+`, numPartitions(x))) + }) + +#' Apply a function to all elements +#' +#' This function creates a new RDD by applying the given transformation to all +#' elements of the given RDD +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each element +#' @return a new RDD created by the transformation. +#' @rdname lapply +#' @aliases lapply +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' multiplyByTwo <- lapply(rdd, function(x) { x * 2 }) +#' collect(multiplyByTwo) # 2,4,6... +#'} +setMethod("lapply", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + func <- function(split, iterator) { + lapply(iterator, FUN) + } + lapplyPartitionsWithIndex(X, func) + }) + +#' @rdname lapply +#' @aliases map,RDD,function-method +setMethod("map", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapply(X, FUN) + }) + +#' Flatten results after apply a function to all elements +#' +#' This function return a new RDD by first applying a function to all +#' elements of this RDD, and then flattening the results. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each element +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' multiplyByTwo <- flatMap(rdd, function(x) { list(x*2, x*10) }) +#' collect(multiplyByTwo) # 2,20,4,40,6,60... +#'} +#' @rdname flatMap +#' @aliases flatMap,RDD,function-method +setMethod("flatMap", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + partitionFunc <- function(part) { + unlist( + lapply(part, FUN), + recursive = F + ) + } + lapplyPartition(X, partitionFunc) + }) + +#' Apply a function to each partition of an RDD +#' +#' Return a new RDD by applying a function to each partition of this RDD. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each partition. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' partitionSum <- lapplyPartition(rdd, function(part) { Reduce("+", part) }) +#' collect(partitionSum) # 15, 40 +#'} +#' @rdname lapplyPartition +#' @aliases lapplyPartition,RDD,function-method +setMethod("lapplyPartition", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapplyPartitionsWithIndex(X, function(s, part) { FUN(part) }) + }) + +#' mapPartitions is the same as lapplyPartition. +#' +#' @rdname lapplyPartition +#' @aliases mapPartitions,RDD,function-method +setMethod("mapPartitions", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapplyPartition(X, FUN) + }) + +#' Return a new RDD by applying a function to each partition of this RDD, while +#' tracking the index of the original partition. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each partition; takes the partition +#' index and a list of elements in the particular partition. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 5L) +#' prod <- lapplyPartitionsWithIndex(rdd, function(split, part) { +#' split * Reduce("+", part) }) +#' collect(prod, flatten = FALSE) # 0, 7, 22, 45, 76 +#'} +#' @rdname lapplyPartitionsWithIndex +#' @aliases lapplyPartitionsWithIndex,RDD,function-method +setMethod("lapplyPartitionsWithIndex", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + PipelinedRDD(X, FUN) + }) + +#' @rdname lapplyPartitionsWithIndex +#' @aliases mapPartitionsWithIndex,RDD,function-method +setMethod("mapPartitionsWithIndex", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapplyPartitionsWithIndex(X, FUN) + }) + +#' This function returns a new RDD containing only the elements that satisfy +#' a predicate (i.e. returning TRUE in a given logical function). +#' The same as `filter()' in Spark. +#' +#' @param x The RDD to be filtered. +#' @param f A unary predicate function. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' unlist(collect(filterRDD(rdd, function (x) { x < 3 }))) # c(1, 2) +#'} +#' @rdname filterRDD +#' @aliases filterRDD,RDD,function-method +setMethod("filterRDD", + signature(x = "RDD", f = "function"), + function(x, f) { + filter.func <- function(part) { + Filter(f, part) + } + lapplyPartition(x, filter.func) + }) + +#' @rdname filterRDD +#' @aliases Filter +setMethod("Filter", + signature(f = "function", x = "RDD"), + function(f, x) { + filterRDD(x, f) + }) + +#' Reduce across elements of an RDD. +#' +#' This function reduces the elements of this RDD using the +#' specified commutative and associative binary operator. +#' +#' @param x The RDD to reduce +#' @param func Commutative and associative function to apply on elements +#' of the RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' reduce(rdd, "+") # 55 +#'} +#' @rdname reduce +#' @aliases reduce,RDD,ANY-method +setMethod("reduce", + signature(x = "RDD", func = "ANY"), + function(x, func) { + + reducePartition <- function(part) { + Reduce(func, part) + } + + partitionList <- collect(lapplyPartition(x, reducePartition), + flatten = FALSE) + Reduce(func, partitionList) + }) + +#' Get the maximum element of an RDD. +#' +#' @param x The RDD to get the maximum element from +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' maximum(rdd) # 10 +#'} +#' @rdname maximum +#' @aliases maximum,RDD +setMethod("maximum", + signature(x = "RDD"), + function(x) { + reduce(x, max) + }) + +#' Get the minimum element of an RDD. +#' +#' @param x The RDD to get the minimum element from +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' minimum(rdd) # 1 +#'} +#' @rdname minimum +#' @aliases minimum,RDD +setMethod("minimum", + signature(x = "RDD"), + function(x) { + reduce(x, min) + }) + +#' Add up the elements in an RDD. +#' +#' @param x The RDD to add up the elements in +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' sumRDD(rdd) # 55 +#'} +#' @rdname sumRDD +#' @aliases sumRDD,RDD +setMethod("sumRDD", + signature(x = "RDD"), + function(x) { + reduce(x, "+") + }) + +#' Applies a function to all elements in an RDD, and force evaluation. +#' +#' @param x The RDD to apply the function +#' @param func The function to be applied. +#' @return invisible NULL. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' foreach(rdd, function(x) { save(x, file=...) }) +#'} +#' @rdname foreach +#' @aliases foreach,RDD,function-method +setMethod("foreach", + signature(x = "RDD", func = "function"), + function(x, func) { + partition.func <- function(x) { + lapply(x, func) + NULL + } + invisible(collect(mapPartitions(x, partition.func))) + }) + +#' Applies a function to each partition in an RDD, and force evaluation. +#' +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' foreachPartition(rdd, function(part) { save(part, file=...); NULL }) +#'} +#' @rdname foreach +#' @aliases foreachPartition,RDD,function-method +setMethod("foreachPartition", + signature(x = "RDD", func = "function"), + function(x, func) { + invisible(collect(mapPartitions(x, func))) + }) + +#' Take elements from an RDD. +#' +#' This function takes the first NUM elements in the RDD and +#' returns them in a list. +#' +#' @param x The RDD to take elements from +#' @param num Number of elements to take +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' take(rdd, 2L) # list(1, 2) +#'} +#' @rdname take +#' @aliases take,RDD,numeric-method +setMethod("take", + signature(x = "RDD", num = "numeric"), + function(x, num) { + resList <- list() + index <- -1 + jrdd <- getJRDD(x) + numPartitions <- numPartitions(x) + + # TODO(shivaram): Collect more than one partition based on size + # estimates similar to the scala version of `take`. + while (TRUE) { + index <- index + 1 + + if (length(resList) >= num || index >= numPartitions) + break + + # a JList of byte arrays + partitionArr <- callJMethod(jrdd, "collectPartitions", as.list(as.integer(index))) + partition <- partitionArr[[1]] + + size <- num - length(resList) + # elems is capped to have at most `size` elements + elems <- convertJListToRList(partition, + flatten = TRUE, + logicalUpperBound = size, + serializedMode = getSerializedMode(x)) + # TODO: Check if this append is O(n^2)? + resList <- append(resList, elems) + } + resList + }) + +#' First +#' +#' Return the first element of an RDD +#' +#' @rdname first +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' first(rdd) +#' } +setMethod("first", + signature(x = "RDD"), + function(x) { + take(x, 1)[[1]] + }) + +#' Removes the duplicates from RDD. +#' +#' This function returns a new RDD containing the distinct elements in the +#' given RDD. The same as `distinct()' in Spark. +#' +#' @param x The RDD to remove duplicates from. +#' @param numPartitions Number of partitions to create. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, c(1,2,2,3,3,3)) +#' sort(unlist(collect(distinct(rdd)))) # c(1, 2, 3) +#'} +#' @rdname distinct +#' @aliases distinct,RDD-method +setMethod("distinct", + signature(x = "RDD"), + function(x, numPartitions = SparkR::numPartitions(x)) { + identical.mapped <- lapply(x, function(x) { list(x, NULL) }) + reduced <- reduceByKey(identical.mapped, + function(x, y) { x }, + numPartitions) + resRDD <- lapply(reduced, function(x) { x[[1]] }) + resRDD + }) + +#' Return an RDD that is a sampled subset of the given RDD. +#' +#' The same as `sample()' in Spark. (We rename it due to signature +#' inconsistencies with the `sample()' function in R's base package.) +#' +#' @param x The RDD to sample elements from +#' @param withReplacement Sampling with replacement or not +#' @param fraction The (rough) sample target fraction +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) # ensure each num is in its own split +#' collect(sampleRDD(rdd, FALSE, 0.5, 1618L)) # ~5 distinct elements +#' collect(sampleRDD(rdd, TRUE, 0.5, 9L)) # ~5 elements possibly with duplicates +#'} +#' @rdname sampleRDD +#' @aliases sampleRDD,RDD +setMethod("sampleRDD", + signature(x = "RDD", withReplacement = "logical", + fraction = "numeric", seed = "integer"), + function(x, withReplacement, fraction, seed) { + + # The sampler: takes a partition and returns its sampled version. + samplingFunc <- function(split, part) { + set.seed(seed) + res <- vector("list", length(part)) + len <- 0 + + # Discards some random values to ensure each partition has a + # different random seed. + runif(split) + + for (elem in part) { + if (withReplacement) { + count <- rpois(1, fraction) + if (count > 0) { + res[(len + 1):(len + count)] <- rep(list(elem), count) + len <- len + count + } + } else { + if (runif(1) < fraction) { + len <- len + 1 + res[[len]] <- elem + } + } + } + + # TODO(zongheng): look into the performance of the current + # implementation. Look into some iterator package? Note that + # Scala avoids many calls to creating an empty list and PySpark + # similarly achieves this using `yield'. + if (len > 0) + res[1:len] + else + list() + } + + lapplyPartitionsWithIndex(x, samplingFunc) + }) + +#' Return a list of the elements that are a sampled subset of the given RDD. +#' +#' @param x The RDD to sample elements from +#' @param withReplacement Sampling with replacement or not +#' @param num Number of elements to return +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:100) +#' # exactly 5 elements sampled, which may not be distinct +#' takeSample(rdd, TRUE, 5L, 1618L) +#' # exactly 5 distinct elements sampled +#' takeSample(rdd, FALSE, 5L, 16181618L) +#'} +#' @rdname takeSample +#' @aliases takeSample,RDD +setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", + num = "integer", seed = "integer"), + function(x, withReplacement, num, seed) { + # This function is ported from RDD.scala. + fraction <- 0.0 + total <- 0 + multiplier <- 3.0 + initialCount <- count(x) + maxSelected <- 0 + MAXINT <- .Machine$integer.max + + if (num < 0) + stop(paste("Negative number of elements requested")) + + if (initialCount > MAXINT - 1) { + maxSelected <- MAXINT - 1 + } else { + maxSelected <- initialCount + } + + if (num > initialCount && !withReplacement) { + total <- maxSelected + fraction <- multiplier * (maxSelected + 1) / initialCount + } else { + total <- num + fraction <- multiplier * (num + 1) / initialCount + } + + set.seed(seed) + samples <- collect(sampleRDD(x, withReplacement, fraction, + as.integer(ceiling(runif(1, + -MAXINT, + MAXINT))))) + # If the first sample didn't turn out large enough, keep trying to + # take samples; this shouldn't happen often because we use a big + # multiplier for thei initial size + while (length(samples) < total) + samples <- collect(sampleRDD(x, withReplacement, fraction, + as.integer(ceiling(runif(1, + -MAXINT, + MAXINT))))) + + # TODO(zongheng): investigate if this call is an in-place shuffle? + sample(samples)[1:total] + }) + +#' Creates tuples of the elements in this RDD by applying a function. +#' +#' @param x The RDD. +#' @param func The function to be applied. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3)) +#' collect(keyBy(rdd, function(x) { x*x })) # list(list(1, 1), list(4, 2), list(9, 3)) +#'} +#' @rdname keyBy +#' @aliases keyBy,RDD +setMethod("keyBy", + signature(x = "RDD", func = "function"), + function(x, func) { + apply.func <- function(x) { + list(func(x), x) + } + lapply(x, apply.func) + }) + +#' Return a new RDD that has exactly numPartitions partitions. +#' Can increase or decrease the level of parallelism in this RDD. Internally, +#' this uses a shuffle to redistribute data. +#' If you are decreasing the number of partitions in this RDD, consider using +#' coalesce, which can avoid performing a shuffle. +#' +#' @param x The RDD. +#' @param numPartitions Number of partitions to create. +#' @seealso coalesce +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5, 6, 7), 4L) +#' numPartitions(rdd) # 4 +#' numPartitions(repartition(rdd, 2L)) # 2 +#'} +#' @rdname repartition +#' @aliases repartition,RDD +setMethod("repartition", + signature(x = "RDD", numPartitions = "numeric"), + function(x, numPartitions) { + coalesce(x, numToInt(numPartitions), TRUE) + }) + +#' Return a new RDD that is reduced into numPartitions partitions. +#' +#' @param x The RDD. +#' @param numPartitions Number of partitions to create. +#' @seealso repartition +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5), 3L) +#' numPartitions(rdd) # 3 +#' numPartitions(coalesce(rdd, 1L)) # 1 +#'} +#' @rdname coalesce +#' @aliases coalesce,RDD +setMethod("coalesce", + signature(x = "RDD", numPartitions = "numeric"), + function(x, numPartitions, shuffle = FALSE) { + numPartitions <- numToInt(numPartitions) + if (shuffle || numPartitions > SparkR::numPartitions(x)) { + func <- function(s, part) { + set.seed(s) # split as seed + start <- as.integer(sample(numPartitions, 1) - 1) + lapply(seq_along(part), + function(i) { + pos <- (start + i) %% numPartitions + list(pos, part[[i]]) + }) + } + shuffled <- lapplyPartitionsWithIndex(x, func) + repartitioned <- partitionBy(shuffled, numPartitions) + values(repartitioned) + } else { + jrdd <- callJMethod(getJRDD(x), "coalesce", numPartitions, shuffle) + RDD(jrdd) + } + }) + +#' Save this RDD as a SequenceFile of serialized objects. +#' +#' @param x The RDD to save +#' @param path The directory where the file is saved +#' @seealso objectFile +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' saveAsObjectFile(rdd, "/tmp/sparkR-tmp") +#'} +#' @rdname saveAsObjectFile +#' @aliases saveAsObjectFile,RDD +setMethod("saveAsObjectFile", + signature(x = "RDD", path = "character"), + function(x, path) { + # If serializedMode == "string" we need to serialize the data before saving it since + # objectFile() assumes serializedMode == "byte". + if (getSerializedMode(x) != "byte") { + x <- serializeToBytes(x) + } + # Return nothing + invisible(callJMethod(getJRDD(x), "saveAsObjectFile", path)) + }) + +#' Save this RDD as a text file, using string representations of elements. +#' +#' @param x The RDD to save +#' @param path The directory where the splits of the text file are saved +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' saveAsTextFile(rdd, "/tmp/sparkR-tmp") +#'} +#' @rdname saveAsTextFile +#' @aliases saveAsTextFile,RDD +setMethod("saveAsTextFile", + signature(x = "RDD", path = "character"), + function(x, path) { + func <- function(str) { + toString(str) + } + stringRdd <- lapply(x, func) + # Return nothing + invisible( + callJMethod(getJRDD(stringRdd, serializedMode = "string"), "saveAsTextFile", path)) + }) + +#' Sort an RDD by the given key function. +#' +#' @param x An RDD to be sorted. +#' @param func A function used to compute the sort key for each element. +#' @param ascending A flag to indicate whether the sorting is ascending or descending. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where all elements are sorted. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(3, 2, 1)) +#' collect(sortBy(rdd, function(x) { x })) # list (1, 2, 3) +#'} +#' @rdname sortBy +#' @aliases sortBy,RDD,RDD-method +setMethod("sortBy", + signature(x = "RDD", func = "function"), + function(x, func, ascending = TRUE, numPartitions = SparkR::numPartitions(x)) { + values(sortByKey(keyBy(x, func), ascending, numPartitions)) + }) + +# Helper function to get first N elements from an RDD in the specified order. +# Param: +# x An RDD. +# num Number of elements to return. +# ascending A flag to indicate whether the sorting is ascending or descending. +# Return: +# A list of the first N elements from the RDD in the specified order. +# +takeOrderedElem <- function(x, num, ascending = TRUE) { + if (num <= 0L) { + return(list()) + } + + partitionFunc <- function(part) { + if (num < length(part)) { + # R limitation: order works only on primitive types! + ord <- order(unlist(part, recursive = FALSE), decreasing = !ascending) + list(part[ord[1:num]]) + } else { + list(part) + } + } + + reduceFunc <- function(elems, part) { + newElems <- append(elems, part) + # R limitation: order works only on primitive types! + ord <- order(unlist(newElems, recursive = FALSE), decreasing = !ascending) + newElems[ord[1:num]] + } + + newRdd <- mapPartitions(x, partitionFunc) + reduce(newRdd, reduceFunc) +} + +#' Returns the first N elements from an RDD in ascending order. +#' +#' @param x An RDD. +#' @param num Number of elements to return. +#' @return The first N elements from the RDD in ascending order. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) +#' takeOrdered(rdd, 6L) # list(1, 2, 3, 4, 5, 6) +#'} +#' @rdname takeOrdered +#' @aliases takeOrdered,RDD,RDD-method +setMethod("takeOrdered", + signature(x = "RDD", num = "integer"), + function(x, num) { + takeOrderedElem(x, num) + }) + +#' Returns the top N elements from an RDD. +#' +#' @param x An RDD. +#' @param num Number of elements to return. +#' @return The top N elements from the RDD. +#' @rdname top +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) +#' top(rdd, 6L) # list(10, 9, 7, 6, 5, 4) +#'} +#' @rdname top +#' @aliases top,RDD,RDD-method +setMethod("top", + signature(x = "RDD", num = "integer"), + function(x, num) { + takeOrderedElem(x, num, FALSE) + }) + +#' Fold an RDD using a given associative function and a neutral "zero value". +#' +#' Aggregate the elements of each partition, and then the results for all the +#' partitions, using a given associative function and a neutral "zero value". +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param op An associative function for the folding operation. +#' @return The folding result. +#' @rdname fold +#' @seealso reduce +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5)) +#' fold(rdd, 0, "+") # 15 +#'} +#' @rdname fold +#' @aliases fold,RDD,RDD-method +setMethod("fold", + signature(x = "RDD", zeroValue = "ANY", op = "ANY"), + function(x, zeroValue, op) { + aggregateRDD(x, zeroValue, op, op) + }) + +#' Aggregate an RDD using the given combine functions and a neutral "zero value". +#' +#' Aggregate the elements of each partition, and then the results for all the +#' partitions, using given combine functions and a neutral "zero value". +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param seqOp A function to aggregate the RDD elements. It may return a different +#' result type from the type of the RDD elements. +#' @param combOp A function to aggregate results of seqOp. +#' @return The aggregation result. +#' @rdname aggregateRDD +#' @seealso reduce +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4)) +#' zeroValue <- list(0, 0) +#' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } +#' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } +#' aggregateRDD(rdd, zeroValue, seqOp, combOp) # list(10, 4) +#'} +#' @rdname aggregateRDD +#' @aliases aggregateRDD,RDD,RDD-method +setMethod("aggregateRDD", + signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", combOp = "ANY"), + function(x, zeroValue, seqOp, combOp) { + partitionFunc <- function(part) { + Reduce(seqOp, part, zeroValue) + } + + partitionList <- collect(lapplyPartition(x, partitionFunc), + flatten = FALSE) + Reduce(combOp, partitionList, zeroValue) + }) + +#' Pipes elements to a forked external process. +#' +#' The same as 'pipe()' in Spark. +#' +#' @param x The RDD whose elements are piped to the forked external process. +#' @param command The command to fork an external process. +#' @param env A named list to set environment variables of the external process. +#' @return A new RDD created by piping all elements to a forked external process. +#' @rdname pipeRDD +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' collect(pipeRDD(rdd, "more") +#' Output: c("1", "2", ..., "10") +#'} +#' @rdname pipeRDD +#' @aliases pipeRDD,RDD,character-method +setMethod("pipeRDD", + signature(x = "RDD", command = "character"), + function(x, command, env = list()) { + func <- function(part) { + trim.trailing.func <- function(x) { + sub("[\r\n]*$", "", toString(x)) + } + input <- unlist(lapply(part, trim.trailing.func)) + res <- system2(command, stdout = TRUE, input = input, env = env) + lapply(res, trim.trailing.func) + } + lapplyPartition(x, func) + }) + +# TODO: Consider caching the name in the RDD's environment +#' Return an RDD's name. +#' +#' @param x The RDD whose name is returned. +#' @rdname name +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1,2,3)) +#' name(rdd) # NULL (if not set before) +#'} +#' @rdname name +#' @aliases name,RDD +setMethod("name", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "name") + }) + +#' Set an RDD's name. +#' +#' @param x The RDD whose name is to be set. +#' @param name The RDD name to be set. +#' @return a new RDD renamed. +#' @rdname setName +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1,2,3)) +#' setName(rdd, "myRDD") +#' name(rdd) # "myRDD" +#'} +#' @rdname setName +#' @aliases setName,RDD +setMethod("setName", + signature(x = "RDD", name = "character"), + function(x, name) { + callJMethod(getJRDD(x), "setName", name) + x + }) + +#' Zip an RDD with generated unique Long IDs. +#' +#' Items in the kth partition will get ids k, n+k, 2*n+k, ..., where +#' n is the number of partitions. So there may exist gaps, but this +#' method won't trigger a spark job, which is different from +#' zipWithIndex. +#' +#' @param x An RDD to be zipped. +#' @return An RDD with zipped items. +#' @seealso zipWithIndex +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) +#' collect(zipWithUniqueId(rdd)) +#' # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) +#'} +#' @rdname zipWithUniqueId +#' @aliases zipWithUniqueId,RDD +setMethod("zipWithUniqueId", + signature(x = "RDD"), + function(x) { + n <- numPartitions(x) + + partitionFunc <- function(split, part) { + mapply( + function(item, index) { + list(item, (index - 1) * n + split) + }, + part, + seq_along(part), + SIMPLIFY = FALSE) + } + + lapplyPartitionsWithIndex(x, partitionFunc) + }) + +#' Zip an RDD with its element indices. +#' +#' The ordering is first based on the partition index and then the +#' ordering of items within each partition. So the first item in +#' the first partition gets index 0, and the last item in the last +#' partition receives the largest index. +#' +#' This method needs to trigger a Spark job when this RDD contains +#' more than one partition. +#' +#' @param x An RDD to be zipped. +#' @return An RDD with zipped items. +#' @seealso zipWithUniqueId +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) +#' collect(zipWithIndex(rdd)) +#' # list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) +#'} +#' @rdname zipWithIndex +#' @aliases zipWithIndex,RDD +setMethod("zipWithIndex", + signature(x = "RDD"), + function(x) { + n <- numPartitions(x) + if (n > 1) { + nums <- collect(lapplyPartition(x, + function(part) { + list(length(part)) + })) + startIndices <- Reduce("+", nums, accumulate = TRUE) + } + + partitionFunc <- function(split, part) { + if (split == 0) { + startIndex <- 0 + } else { + startIndex <- startIndices[[split]] + } + + mapply( + function(item, index) { + list(item, index - 1 + startIndex) + }, + part, + seq_along(part), + SIMPLIFY = FALSE) + } + + lapplyPartitionsWithIndex(x, partitionFunc) + }) + +#' Coalesce all elements within each partition of an RDD into a list. +#' +#' @param x An RDD. +#' @return An RDD created by coalescing all elements within +#' each partition into a list. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, as.list(1:4), 2L) +#' collect(glom(rdd)) +#' # list(list(1, 2), list(3, 4)) +#'} +#' @rdname glom +#' @aliases glom,RDD +setMethod("glom", + signature(x = "RDD"), + function(x) { + partitionFunc <- function(part) { + list(part) + } + + lapplyPartition(x, partitionFunc) + }) + +############ Binary Functions ############# + +#' Return the union RDD of two RDDs. +#' The same as union() in Spark. +#' +#' @param x An RDD. +#' @param y An RDD. +#' @return a new RDD created by performing the simple union (witout removing +#' duplicates) of two input RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' unionRDD(rdd, rdd) # 1, 2, 3, 1, 2, 3 +#'} +#' @rdname unionRDD +#' @aliases unionRDD,RDD,RDD-method +setMethod("unionRDD", + signature(x = "RDD", y = "RDD"), + function(x, y) { + if (getSerializedMode(x) == getSerializedMode(y)) { + jrdd <- callJMethod(getJRDD(x), "union", getJRDD(y)) + union.rdd <- RDD(jrdd, getSerializedMode(x)) + } else { + # One of the RDDs is not serialized, we need to serialize it first. + if (getSerializedMode(x) != "byte") x <- serializeToBytes(x) + if (getSerializedMode(y) != "byte") y <- serializeToBytes(y) + jrdd <- callJMethod(getJRDD(x), "union", getJRDD(y)) + union.rdd <- RDD(jrdd, "byte") + } + union.rdd + }) + +#' Zip an RDD with another RDD. +#' +#' Zips this RDD with another one, returning key-value pairs with the +#' first element in each RDD second element in each RDD, etc. Assumes +#' that the two RDDs have the same number of partitions and the same +#' number of elements in each partition (e.g. one was made through +#' a map on the other). +#' +#' @param x An RDD to be zipped. +#' @param other Another RDD to be zipped. +#' @return An RDD zipped from the two RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, 0:4) +#' rdd2 <- parallelize(sc, 1000:1004) +#' collect(zipRDD(rdd1, rdd2)) +#' # list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)) +#'} +#' @rdname zipRDD +#' @aliases zipRDD,RDD +setMethod("zipRDD", + signature(x = "RDD", other = "RDD"), + function(x, other) { + n1 <- numPartitions(x) + n2 <- numPartitions(other) + if (n1 != n2) { + stop("Can only zip RDDs which have the same number of partitions.") + } + + if (getSerializedMode(x) != getSerializedMode(other) || + getSerializedMode(x) == "byte") { + # Append the number of elements in each partition to that partition so that we can later + # check if corresponding partitions of both RDDs have the same number of elements. + # + # Note that this appending also serves the purpose of reserialization, because even if + # any RDD is serialized, we need to reserialize it to make sure its partitions are encoded + # as a single byte array. For example, partitions of an RDD generated from partitionBy() + # may be encoded as multiple byte arrays. + appendLength <- function(part) { + part[[length(part) + 1]] <- length(part) + 1 + part + } + x <- lapplyPartition(x, appendLength) + other <- lapplyPartition(other, appendLength) + } + + zippedJRDD <- callJMethod(getJRDD(x), "zip", getJRDD(other)) + # The zippedRDD's elements are of scala Tuple2 type. The serialized + # flag Here is used for the elements inside the tuples. + serializerMode <- getSerializedMode(x) + zippedRDD <- RDD(zippedJRDD, serializerMode) + + partitionFunc <- function(split, part) { + len <- length(part) + if (len > 0) { + if (serializerMode == "byte") { + lengthOfValues <- part[[len]] + lengthOfKeys <- part[[len - lengthOfValues]] + stopifnot(len == lengthOfKeys + lengthOfValues) + + # check if corresponding partitions of both RDDs have the same number of elements. + if (lengthOfKeys != lengthOfValues) { + stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.") + } + + if (lengthOfKeys > 1) { + keys <- part[1 : (lengthOfKeys - 1)] + values <- part[(lengthOfKeys + 1) : (len - 1)] + } else { + keys <- list() + values <- list() + } + } else { + # Keys, values must have same length here, because this has + # been validated inside the JavaRDD.zip() function. + keys <- part[c(TRUE, FALSE)] + values <- part[c(FALSE, TRUE)] + } + mapply( + function(k, v) { + list(k, v) + }, + keys, + values, + SIMPLIFY = FALSE, + USE.NAMES = FALSE) + } else { + part + } + } + + PipelinedRDD(zippedRDD, partitionFunc) + }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R new file mode 100644 index 0000000000000..930ada22f4c38 --- /dev/null +++ b/R/pkg/R/SQLContext.R @@ -0,0 +1,520 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# SQLcontext.R: SQLContext-driven functions + +#' infer the SQL type +infer_type <- function(x) { + if (is.null(x)) { + stop("can not infer type from NULL") + } + + # class of POSIXlt is c("POSIXlt" "POSIXt") + type <- switch(class(x)[[1]], + integer = "integer", + character = "string", + logical = "boolean", + double = "double", + numeric = "double", + raw = "binary", + list = "array", + environment = "map", + Date = "date", + POSIXlt = "timestamp", + POSIXct = "timestamp", + stop(paste("Unsupported type for DataFrame:", class(x)))) + + if (type == "map") { + stopifnot(length(x) > 0) + key <- ls(x)[[1]] + list(type = "map", + keyType = "string", + valueType = infer_type(get(key, x)), + valueContainsNull = TRUE) + } else if (type == "array") { + stopifnot(length(x) > 0) + names <- names(x) + if (is.null(names)) { + list(type = "array", elementType = infer_type(x[[1]]), containsNull = TRUE) + } else { + # StructType + types <- lapply(x, infer_type) + fields <- lapply(1:length(x), function(i) { + list(name = names[[i]], type = types[[i]], nullable = TRUE) + }) + list(type = "struct", fields = fields) + } + } else if (length(x) > 1) { + list(type = "array", elementType = type, containsNull = TRUE) + } else { + type + } +} + +#' dump the schema into JSON string +tojson <- function(x) { + if (is.list(x)) { + names <- names(x) + if (!is.null(names)) { + items <- lapply(names, function(n) { + safe_n <- gsub('"', '\\"', n) + paste(tojson(safe_n), ':', tojson(x[[n]]), sep = '') + }) + d <- paste(items, collapse = ', ') + paste('{', d, '}', sep = '') + } else { + l <- paste(lapply(x, tojson), collapse = ', ') + paste('[', l, ']', sep = '') + } + } else if (is.character(x)) { + paste('"', x, '"', sep = '') + } else if (is.logical(x)) { + if (x) "true" else "false" + } else { + stop(paste("unexpected type:", class(x))) + } +} + +#' Create a DataFrame from an RDD +#' +#' Converts an RDD to a DataFrame by infer the types. +#' +#' @param sqlCtx A SQLContext +#' @param data An RDD or list or data.frame +#' @param schema a list of column names or named list (StructType), optional +#' @return an DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) +#' df <- createDataFrame(sqlCtx, rdd) +#' } + +# TODO(davies): support sampling and infer type from NA +createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { + if (is.data.frame(data)) { + # get the names of columns, they will be put into RDD + schema <- names(data) + n <- nrow(data) + m <- ncol(data) + # get rid of factor type + dropFactor <- function(x) { + if (is.factor(x)) { + as.character(x) + } else { + x + } + } + data <- lapply(1:n, function(i) { + lapply(1:m, function(j) { dropFactor(data[i,j]) }) + }) + } + if (is.list(data)) { + sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlCtx) + rdd <- parallelize(sc, data) + } else if (inherits(data, "RDD")) { + rdd <- data + } else { + stop(paste("unexpected type:", class(data))) + } + + if (is.null(schema) || is.null(names(schema))) { + row <- first(rdd) + names <- if (is.null(schema)) { + names(row) + } else { + as.list(schema) + } + if (is.null(names)) { + names <- lapply(1:length(row), function(x) { + paste("_", as.character(x), sep = "") + }) + } + + # SPAKR-SQL does not support '.' in column name, so replace it with '_' + # TODO(davies): remove this once SPARK-2775 is fixed + names <- lapply(names, function(n) { + nn <- gsub("[.]", "_", n) + if (nn != n) { + warning(paste("Use", nn, "instead of", n, " as column name")) + } + nn + }) + + types <- lapply(row, infer_type) + fields <- lapply(1:length(row), function(i) { + list(name = names[[i]], type = types[[i]], nullable = TRUE) + }) + schema <- list(type = "struct", fields = fields) + } + + stopifnot(class(schema) == "list") + stopifnot(schema$type == "struct") + stopifnot(class(schema$fields) == "list") + schemaString <- tojson(schema) + + jrdd <- getJRDD(lapply(rdd, function(x) x), "row") + srdd <- callJMethod(jrdd, "rdd") + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF", + srdd, schemaString, sqlCtx) + dataFrame(sdf) +} + +#' toDF +#' +#' Converts an RDD to a DataFrame by infer the types. +#' +#' @param x An RDD +#' +#' @rdname DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) +#' df <- toDF(rdd) +#' } + +setGeneric("toDF", function(x, ...) { standardGeneric("toDF") }) + +setMethod("toDF", signature(x = "RDD"), + function(x, ...) { + sqlCtx <- if (exists(".sparkRHivesc", envir = .sparkREnv)) { + get(".sparkRHivesc", envir = .sparkREnv) + } else if (exists(".sparkRSQLsc", envir = .sparkREnv)) { + get(".sparkRSQLsc", envir = .sparkREnv) + } else { + stop("no SQL context available") + } + createDataFrame(sqlCtx, x, ...) + }) + +#' Create a DataFrame from a JSON file. +#' +#' Loads a JSON file (one object per line), returning the result as a DataFrame +#' It goes through the entire dataset once to determine the schema. +#' +#' @param sqlCtx SQLContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' } + +jsonFile <- function(sqlCtx, path) { + # Allow the user to have a more flexible definiton of the text file path + path <- normalizePath(path) + # Convert a string vector of paths to a string containing comma separated paths + path <- paste(path, collapse = ",") + sdf <- callJMethod(sqlCtx, "jsonFile", path) + dataFrame(sdf) +} + + +#' JSON RDD +#' +#' Loads an RDD storing one JSON object per string as a DataFrame. +#' +#' @param sqlCtx SQLContext to use +#' @param rdd An RDD of JSON string +#' @param schema A StructType object to use as schema +#' @param samplingRatio The ratio of simpling used to infer the schema +#' @return A DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- texFile(sc, "path/to/json") +#' df <- jsonRDD(sqlCtx, rdd) +#' } + +# TODO: support schema +jsonRDD <- function(sqlCtx, rdd, schema = NULL, samplingRatio = 1.0) { + rdd <- serializeToString(rdd) + if (is.null(schema)) { + sdf <- callJMethod(sqlCtx, "jsonRDD", callJMethod(getJRDD(rdd), "rdd"), samplingRatio) + dataFrame(sdf) + } else { + stop("not implemented") + } +} + + +#' Create a DataFrame from a Parquet file. +#' +#' Loads a Parquet file, returning the result as a DataFrame. +#' +#' @param sqlCtx SQLContext to use +#' @param ... Path(s) of parquet file(s) to read. +#' @return DataFrame +#' @export + +# TODO: Implement saveasParquetFile and write examples for both +parquetFile <- function(sqlCtx, ...) { + # Allow the user to have a more flexible definiton of the text file path + paths <- lapply(list(...), normalizePath) + sdf <- callJMethod(sqlCtx, "parquetFile", paths) + dataFrame(sdf) +} + +#' SQL Query +#' +#' Executes a SQL query using Spark, returning the result as a DataFrame. +#' +#' @param sqlCtx SQLContext to use +#' @param sqlQuery A character vector containing the SQL query +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "table") +#' new_df <- sql(sqlCtx, "SELECT * FROM table") +#' } + +sql <- function(sqlCtx, sqlQuery) { + sdf <- callJMethod(sqlCtx, "sql", sqlQuery) + dataFrame(sdf) +} + + +#' Create a DataFrame from a SparkSQL Table +#' +#' Returns the specified Table as a DataFrame. The Table must have already been registered +#' in the SQLContext. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName The SparkSQL Table to convert to a DataFrame. +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "table") +#' new_df <- table(sqlCtx, "table") +#' } + +table <- function(sqlCtx, tableName) { + sdf <- callJMethod(sqlCtx, "table", tableName) + dataFrame(sdf) +} + + +#' Tables +#' +#' Returns a DataFrame containing names of tables in the given database. +#' +#' @param sqlCtx SQLContext to use +#' @param databaseName name of the database +#' @return a DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' tables(sqlCtx, "hive") +#' } + +tables <- function(sqlCtx, databaseName = NULL) { + jdf <- if (is.null(databaseName)) { + callJMethod(sqlCtx, "tables") + } else { + callJMethod(sqlCtx, "tables", databaseName) + } + dataFrame(jdf) +} + + +#' Table Names +#' +#' Returns the names of tables in the given database as an array. +#' +#' @param sqlCtx SQLContext to use +#' @param databaseName name of the database +#' @return a list of table names +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' tableNames(sqlCtx, "hive") +#' } + +tableNames <- function(sqlCtx, databaseName = NULL) { + if (is.null(databaseName)) { + callJMethod(sqlCtx, "tableNames") + } else { + callJMethod(sqlCtx, "tableNames", databaseName) + } +} + + +#' Cache Table +#' +#' Caches the specified table in-memory. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName The name of the table being cached +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "table") +#' cacheTable(sqlCtx, "table") +#' } + +cacheTable <- function(sqlCtx, tableName) { + callJMethod(sqlCtx, "cacheTable", tableName) +} + +#' Uncache Table +#' +#' Removes the specified table from the in-memory cache. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName The name of the table being uncached +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "table") +#' uncacheTable(sqlCtx, "table") +#' } + +uncacheTable <- function(sqlCtx, tableName) { + callJMethod(sqlCtx, "uncacheTable", tableName) +} + +#' Clear Cache +#' +#' Removes all cached tables from the in-memory cache. +#' +#' @param sqlCtx SQLContext to use +#' @examples +#' \dontrun{ +#' clearCache(sqlCtx) +#' } + +clearCache <- function(sqlCtx) { + callJMethod(sqlCtx, "clearCache") +} + +#' Drop Temporary Table +#' +#' Drops the temporary table with the given table name in the catalog. +#' If the table has been cached/persisted before, it's also unpersisted. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName The name of the SparkSQL table to be dropped. +#' @examples +#' \dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- loadDF(sqlCtx, path, "parquet") +#' registerTempTable(df, "table") +#' dropTempTable(sqlCtx, "table") +#' } + +dropTempTable <- function(sqlCtx, tableName) { + if (class(tableName) != "character") { + stop("tableName must be a string.") + } + callJMethod(sqlCtx, "dropTempTable", tableName) +} + +#' Load an DataFrame +#' +#' Returns the dataset in a data source as a DataFrame +#' +#' The data source is specified by the `source` and a set of options(...). +#' If `source` is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. +#' +#' @param sqlCtx SQLContext to use +#' @param path The path of files to load +#' @param source the name of external data source +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- load(sqlCtx, "path/to/file.json", source = "json") +#' } + +loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) { + options <- varargsToEnv(...) + if (!is.null(path)) { + options[['path']] <- path + } + sdf <- callJMethod(sqlCtx, "load", source, options) + dataFrame(sdf) +} + +#' Create an external table +#' +#' Creates an external table based on the dataset in a data source, +#' Returns the DataFrame associated with the external table. +#' +#' The data source is specified by the `source` and a set of options(...). +#' If `source` is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName A name of the table +#' @param path The path of files to load +#' @param source the name of external data source +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- sparkRSQL.createExternalTable(sqlCtx, "myjson", path="path/to/json", source="json") +#' } + +createExternalTable <- function(sqlCtx, tableName, path = NULL, source = NULL, ...) { + options <- varargsToEnv(...) + if (!is.null(path)) { + options[['path']] <- path + } + sdf <- callJMethod(sqlCtx, "createExternalTable", tableName, source, options) + dataFrame(sdf) +} diff --git a/R/pkg/R/SQLTypes.R b/R/pkg/R/SQLTypes.R new file mode 100644 index 0000000000000..962fba5b3cf03 --- /dev/null +++ b/R/pkg/R/SQLTypes.R @@ -0,0 +1,64 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility functions for handling SparkSQL DataTypes. + +# Handler for StructType +structType <- function(st) { + obj <- structure(new.env(parent = emptyenv()), class = "structType") + obj$jobj <- st + obj$fields <- function() { lapply(callJMethod(st, "fields"), structField) } + obj +} + +#' Print a Spark StructType. +#' +#' This function prints the contents of a StructType returned from the +#' SparkR JVM backend. +#' +#' @param x A StructType object +#' @param ... further arguments passed to or from other methods +print.structType <- function(x, ...) { + fieldsList <- lapply(x$fields(), function(i) { i$print() }) + print(fieldsList) +} + +# Handler for StructField +structField <- function(sf) { + obj <- structure(new.env(parent = emptyenv()), class = "structField") + obj$jobj <- sf + obj$name <- function() { callJMethod(sf, "name") } + obj$dataType <- function() { callJMethod(sf, "dataType") } + obj$dataType.toString <- function() { callJMethod(obj$dataType(), "toString") } + obj$dataType.simpleString <- function() { callJMethod(obj$dataType(), "simpleString") } + obj$nullable <- function() { callJMethod(sf, "nullable") } + obj$print <- function() { paste("StructField(", + paste(obj$name(), obj$dataType.toString(), obj$nullable(), sep = ", "), + ")", sep = "") } + obj +} + +#' Print a Spark StructField. +#' +#' This function prints the contents of a StructField returned from the +#' SparkR JVM backend. +#' +#' @param x A StructField object +#' @param ... further arguments passed to or from other methods +print.structField <- function(x, ...) { + cat(x$print()) +} diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R new file mode 100644 index 0000000000000..2fb6fae55f28c --- /dev/null +++ b/R/pkg/R/backend.R @@ -0,0 +1,115 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Methods to call into SparkRBackend. + + +# Returns TRUE if object is an instance of given class +isInstanceOf <- function(jobj, className) { + stopifnot(class(jobj) == "jobj") + cls <- callJStatic("java.lang.Class", "forName", className) + callJMethod(cls, "isInstance", jobj) +} + +# Call a Java method named methodName on the object +# specified by objId. objId should be a "jobj" returned +# from the SparkRBackend. +callJMethod <- function(objId, methodName, ...) { + stopifnot(class(objId) == "jobj") + if (!isValidJobj(objId)) { + stop("Invalid jobj ", objId$id, + ". If SparkR was restarted, Spark operations need to be re-executed.") + } + invokeJava(isStatic = FALSE, objId$id, methodName, ...) +} + +# Call a static method on a specified className +callJStatic <- function(className, methodName, ...) { + invokeJava(isStatic = TRUE, className, methodName, ...) +} + +# Create a new object of the specified class name +newJObject <- function(className, ...) { + invokeJava(isStatic = TRUE, className, methodName = "", ...) +} + +# Remove an object from the SparkR backend. This is done +# automatically when a jobj is garbage collected. +removeJObject <- function(objId) { + invokeJava(isStatic = TRUE, "SparkRHandler", "rm", objId) +} + +isRemoveMethod <- function(isStatic, objId, methodName) { + isStatic == TRUE && objId == "SparkRHandler" && methodName == "rm" +} + +# Invoke a Java method on the SparkR backend. Users +# should typically use one of the higher level methods like +# callJMethod, callJStatic etc. instead of using this. +# +# isStatic - TRUE if the method to be called is static +# objId - String that refers to the object on which method is invoked +# Should be a jobj id for non-static methods and the classname +# for static methods +# methodName - name of method to be invoked +invokeJava <- function(isStatic, objId, methodName, ...) { + if (!exists(".sparkRCon", .sparkREnv)) { + stop("No connection to backend found. Please re-run sparkR.init") + } + + # If this isn't a removeJObject call + if (!isRemoveMethod(isStatic, objId, methodName)) { + objsToRemove <- ls(.toRemoveJobjs) + if (length(objsToRemove) > 0) { + sapply(objsToRemove, + function(e) { + removeJObject(e) + }) + rm(list = objsToRemove, envir = .toRemoveJobjs) + } + } + + + rc <- rawConnection(raw(0), "r+") + + writeBoolean(rc, isStatic) + writeString(rc, objId) + writeString(rc, methodName) + + args <- list(...) + writeInt(rc, length(args)) + writeArgs(rc, args) + + # Construct the whole request message to send it once, + # avoiding write-write-read pattern in case of Nagle's algorithm. + # Refer to http://en.wikipedia.org/wiki/Nagle%27s_algorithm for the details. + bytesToSend <- rawConnectionValue(rc) + close(rc) + rc <- rawConnection(raw(0), "r+") + writeInt(rc, length(bytesToSend)) + writeBin(bytesToSend, rc) + requestMessage <- rawConnectionValue(rc) + close(rc) + + conn <- get(".sparkRCon", .sparkREnv) + writeBin(requestMessage, conn) + + # TODO: check the status code to output error information + returnStatus <- readInt(conn) + stopifnot(returnStatus == 0) + readObject(conn) +} diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R new file mode 100644 index 0000000000000..583fa2e7fdcfd --- /dev/null +++ b/R/pkg/R/broadcast.R @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# S4 class representing Broadcast variables + +# Hidden environment that holds values for broadcast variables +# This will not be serialized / shipped by default +.broadcastNames <- new.env() +.broadcastValues <- new.env() +.broadcastIdToName <- new.env() + +#' @title S4 class that represents a Broadcast variable +#' @description Broadcast variables can be created using the broadcast +#' function from a \code{SparkContext}. +#' @rdname broadcast-class +#' @seealso broadcast +#' +#' @param id Id of the backing Spark broadcast variable +#' @export +setClass("Broadcast", slots = list(id = "character")) + +#' @rdname broadcast-class +#' @param value Value of the broadcast variable +#' @param jBroadcastRef reference to the backing Java broadcast object +#' @param objName name of broadcasted object +#' @export +Broadcast <- function(id, value, jBroadcastRef, objName) { + .broadcastValues[[id]] <- value + .broadcastNames[[as.character(objName)]] <- jBroadcastRef + .broadcastIdToName[[id]] <- as.character(objName) + new("Broadcast", id = id) +} + +#' @description +#' \code{value} can be used to get the value of a broadcast variable inside +#' a distributed function. +#' +#' @param bcast The broadcast variable to get +#' @rdname broadcast +#' @aliases value,Broadcast-method +setMethod("value", + signature(bcast = "Broadcast"), + function(bcast) { + if (exists(bcast@id, envir = .broadcastValues)) { + get(bcast@id, envir = .broadcastValues) + } else { + NULL + } + }) + +#' Internal function to set values of a broadcast variable. +#' +#' This function is used internally by Spark to set the value of a broadcast +#' variable on workers. Not intended for use outside the package. +#' +#' @rdname broadcast-internal +#' @seealso broadcast, value + +#' @param bcastId The id of broadcast variable to set +#' @param value The value to be set +#' @export +setBroadcastValue <- function(bcastId, value) { + bcastIdStr <- as.character(bcastId) + .broadcastValues[[bcastIdStr]] <- value +} + +#' Helper function to clear the list of broadcast variables we know about +#' Should be called when the SparkR JVM backend is shutdown +clearBroadcastVariables <- function() { + bcasts <- ls(.broadcastNames) + rm(list = bcasts, envir = .broadcastNames) +} diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R new file mode 100644 index 0000000000000..1281c41213e32 --- /dev/null +++ b/R/pkg/R/client.R @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Client code to connect to SparkRBackend + +# Creates a SparkR client connection object +# if one doesn't already exist +connectBackend <- function(hostname, port, timeout = 6000) { + if (exists(".sparkRcon", envir = .sparkREnv)) { + if (isOpen(.sparkREnv[[".sparkRCon"]])) { + cat("SparkRBackend client connection already exists\n") + return(get(".sparkRcon", envir = .sparkREnv)) + } + } + + con <- socketConnection(host = hostname, port = port, server = FALSE, + blocking = TRUE, open = "wb", timeout = timeout) + + assign(".sparkRCon", con, envir = .sparkREnv) + con +} + +launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts) { + if (.Platform$OS.type == "unix") { + sparkSubmitBinName = "spark-submit" + } else { + sparkSubmitBinName = "spark-submit.cmd" + } + + if (sparkHome != "") { + sparkSubmitBin <- file.path(sparkHome, "bin", sparkSubmitBinName) + } else { + sparkSubmitBin <- sparkSubmitBinName + } + + if (jars != "") { + jars <- paste("--jars", jars) + } + + combinedArgs <- paste(jars, sparkSubmitOpts, args, sep = " ") + cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n") + invisible(system2(sparkSubmitBin, combinedArgs, wait = F)) +} diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R new file mode 100644 index 0000000000000..e196305186b9a --- /dev/null +++ b/R/pkg/R/column.R @@ -0,0 +1,199 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Column Class + +#' @include generics.R jobj.R +NULL + +setOldClass("jobj") + +#' @title S4 class that represents a DataFrame column +#' @description The column class supports unary, binary operations on DataFrame columns + +#' @rdname column +#' +#' @param jc reference to JVM DataFrame column +#' @export +setClass("Column", + slots = list(jc = "jobj")) + +setMethod("initialize", "Column", function(.Object, jc) { + .Object@jc <- jc + .Object +}) + +column <- function(jc) { + new("Column", jc) +} + +col <- function(x) { + column(callJStatic("org.apache.spark.sql.functions", "col", x)) +} + +#' @rdname show +setMethod("show", "Column", + function(object) { + cat("Column", callJMethod(object@jc, "toString"), "\n") + }) + +operators <- list( + "+" = "plus", "-" = "minus", "*" = "multiply", "/" = "divide", "%%" = "mod", + "==" = "equalTo", ">" = "gt", "<" = "lt", "!=" = "notEqual", "<=" = "leq", ">=" = "geq", + # we can not override `&&` and `||`, so use `&` and `|` instead + "&" = "and", "|" = "or" #, "!" = "unary_$bang" +) +column_functions1 <- c("asc", "desc", "isNull", "isNotNull") +column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains") +functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt", + "first", "last", "lower", "upper", "sumDistinct") + +createOperator <- function(op) { + setMethod(op, + signature(e1 = "Column"), + function(e1, e2) { + jc <- if (missing(e2)) { + if (op == "-") { + callJMethod(e1@jc, "unary_$minus") + } else { + callJMethod(e1@jc, operators[[op]]) + } + } else { + if (class(e2) == "Column") { + e2 <- e2@jc + } + callJMethod(e1@jc, operators[[op]], e2) + } + column(jc) + }) +} + +createColumnFunction1 <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x) { + column(callJMethod(x@jc, name)) + }) +} + +createColumnFunction2 <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x, data) { + if (class(data) == "Column") { + data <- data@jc + } + jc <- callJMethod(x@jc, name, data) + column(jc) + }) +} + +createStaticFunction <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc) + column(jc) + }) +} + +createMethods <- function() { + for (op in names(operators)) { + createOperator(op) + } + for (name in column_functions1) { + createColumnFunction1(name) + } + for (name in column_functions2) { + createColumnFunction2(name) + } + for (x in functions) { + createStaticFunction(x) + } +} + +createMethods() + +#' alias +#' +#' Set a new name for a column +setMethod("alias", + signature(object = "Column"), + function(object, data) { + if (is.character(data)) { + column(callJMethod(object@jc, "as", data)) + } else { + stop("data should be character") + } + }) + +#' An expression that returns a substring. +#' +#' @param start starting position +#' @param stop ending position +setMethod("substr", signature(x = "Column"), + function(x, start, stop) { + jc <- callJMethod(x@jc, "substr", as.integer(start - 1), as.integer(stop - start + 1)) + column(jc) + }) + +#' Casts the column to a different data type. +#' @examples +#' \dontrun{ +#' cast(df$age, "string") +#' cast(df$name, list(type="array", elementType="byte", containsNull = TRUE)) +#' } +setMethod("cast", + signature(x = "Column"), + function(x, dataType) { + if (is.character(dataType)) { + column(callJMethod(x@jc, "cast", dataType)) + } else if (is.list(dataType)) { + json <- tojson(dataType) + jdataType <- callJStatic("org.apache.spark.sql.types.DataType", "fromJson", json) + column(callJMethod(x@jc, "cast", jdataType)) + } else { + stop("dataType should be character or list") + } + }) + +#' Approx Count Distinct +#' +#' Returns the approximate number of distinct items in a group. +#' +setMethod("approxCountDistinct", + signature(x = "Column"), + function(x, rsd = 0.95) { + jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) + column(jc) + }) + +#' Count Distinct +#' +#' returns the number of distinct items in a group. +#' +setMethod("countDistinct", + signature(x = "Column"), + function(x, ...) { + jcol <- lapply(list(...), function (x) { + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, + listToSeq(jcol)) + column(jc) + }) + diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R new file mode 100644 index 0000000000000..ebbb8fba1052d --- /dev/null +++ b/R/pkg/R/context.R @@ -0,0 +1,225 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# context.R: SparkContext driven functions + +getMinSplits <- function(sc, minSplits) { + if (is.null(minSplits)) { + defaultParallelism <- callJMethod(sc, "defaultParallelism") + minSplits <- min(defaultParallelism, 2) + } + as.integer(minSplits) +} + +#' Create an RDD from a text file. +#' +#' This function reads a text file from HDFS, a local file system (available on all +#' nodes), or any Hadoop-supported file system URI, and creates an +#' RDD of strings from it. +#' +#' @param sc SparkContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param minSplits Minimum number of splits to be created. If NULL, the default +#' value is chosen based on available parallelism. +#' @return RDD where each item is of type \code{character} +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' lines <- textFile(sc, "myfile.txt") +#'} +textFile <- function(sc, path, minSplits = NULL) { + # Allow the user to have a more flexible definiton of the text file path + path <- suppressWarnings(normalizePath(path)) + #' Convert a string vector of paths to a string containing comma separated paths + path <- paste(path, collapse = ",") + + jrdd <- callJMethod(sc, "textFile", path, getMinSplits(sc, minSplits)) + # jrdd is of type JavaRDD[String] + RDD(jrdd, "string") +} + +#' Load an RDD saved as a SequenceFile containing serialized objects. +#' +#' The file to be loaded should be one that was previously generated by calling +#' saveAsObjectFile() of the RDD class. +#' +#' @param sc SparkContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param minSplits Minimum number of splits to be created. If NULL, the default +#' value is chosen based on available parallelism. +#' @return RDD containing serialized R objects. +#' @seealso saveAsObjectFile +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- objectFile(sc, "myfile") +#'} +objectFile <- function(sc, path, minSplits = NULL) { + # Allow the user to have a more flexible definiton of the text file path + path <- suppressWarnings(normalizePath(path)) + #' Convert a string vector of paths to a string containing comma separated paths + path <- paste(path, collapse = ",") + + jrdd <- callJMethod(sc, "objectFile", path, getMinSplits(sc, minSplits)) + # Assume the RDD contains serialized R objects. + RDD(jrdd, "byte") +} + +#' Create an RDD from a homogeneous list or vector. +#' +#' This function creates an RDD from a local homogeneous list in R. The elements +#' in the list are split into \code{numSlices} slices and distributed to nodes +#' in the cluster. +#' +#' @param sc SparkContext to use +#' @param coll collection to parallelize +#' @param numSlices number of partitions to create in the RDD +#' @return an RDD created from this collection +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2) +#' # The RDD should contain 10 elements +#' length(rdd) +#'} +parallelize <- function(sc, coll, numSlices = 1) { + # TODO: bound/safeguard numSlices + # TODO: unit tests for if the split works for all primitives + # TODO: support matrix, data frame, etc + if ((!is.list(coll) && !is.vector(coll)) || is.data.frame(coll)) { + if (is.data.frame(coll)) { + message(paste("context.R: A data frame is parallelized by columns.")) + } else { + if (is.matrix(coll)) { + message(paste("context.R: A matrix is parallelized by elements.")) + } else { + message(paste("context.R: parallelize() currently only supports lists and vectors.", + "Calling as.list() to coerce coll into a list.")) + } + } + coll <- as.list(coll) + } + + if (numSlices > length(coll)) + numSlices <- length(coll) + + sliceLen <- ceiling(length(coll) / numSlices) + slices <- split(coll, rep(1:(numSlices + 1), each = sliceLen)[1:length(coll)]) + + # Serialize each slice: obtain a list of raws, or a list of lists (slices) of + # 2-tuples of raws + serializedSlices <- lapply(slices, serialize, connection = NULL) + + jrdd <- callJStatic("org.apache.spark.api.r.RRDD", + "createRDDFromArray", sc, serializedSlices) + + RDD(jrdd, "byte") +} + +#' Include this specified package on all workers +#' +#' This function can be used to include a package on all workers before the +#' user's code is executed. This is useful in scenarios where other R package +#' functions are used in a function passed to functions like \code{lapply}. +#' NOTE: The package is assumed to be installed on every node in the Spark +#' cluster. +#' +#' @param sc SparkContext to use +#' @param pkg Package name +#' +#' @export +#' @examples +#'\dontrun{ +#' library(Matrix) +#' +#' sc <- sparkR.init() +#' # Include the matrix library we will be using +#' includePackage(sc, Matrix) +#' +#' generateSparse <- function(x) { +#' sparseMatrix(i=c(1, 2, 3), j=c(1, 2, 3), x=c(1, 2, 3)) +#' } +#' +#' rdd <- lapplyPartition(parallelize(sc, 1:2, 2L), generateSparse) +#' collect(rdd) +#'} +includePackage <- function(sc, pkg) { + pkg <- as.character(substitute(pkg)) + if (exists(".packages", .sparkREnv)) { + packages <- .sparkREnv$.packages + } else { + packages <- list() + } + packages <- c(packages, pkg) + .sparkREnv$.packages <- packages +} + +#' @title Broadcast a variable to all workers +#' +#' @description +#' Broadcast a read-only variable to the cluster, returning a \code{Broadcast} +#' object for reading it in distributed functions. +#' +#' @param sc Spark Context to use +#' @param object Object to be broadcast +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:2, 2L) +#' +#' # Large Matrix object that we want to broadcast +#' randomMat <- matrix(nrow=100, ncol=10, data=rnorm(1000)) +#' randomMatBr <- broadcast(sc, randomMat) +#' +#' # Use the broadcast variable inside the function +#' useBroadcast <- function(x) { +#' sum(value(randomMatBr) * x) +#' } +#' sumRDD <- lapply(rdd, useBroadcast) +#'} +broadcast <- function(sc, object) { + objName <- as.character(substitute(object)) + serializedObj <- serialize(object, connection = NULL) + + jBroadcast <- callJMethod(sc, "broadcast", serializedObj) + id <- as.character(callJMethod(jBroadcast, "id")) + + Broadcast(id, object, jBroadcast, objName) +} + +#' @title Set the checkpoint directory +#' +#' Set the directory under which RDDs are going to be checkpointed. The +#' directory must be a HDFS path if running on a cluster. +#' +#' @param sc Spark Context to use +#' @param dirName Directory path +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' setCheckpointDir(sc, "~/checkpoint") +#' rdd <- parallelize(sc, 1:2, 2L) +#' checkpoint(rdd) +#'} +setCheckpointDir <- function(sc, dirName) { + invisible(callJMethod(sc, "setCheckpointDir", suppressWarnings(normalizePath(dirName)))) +} diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R new file mode 100644 index 0000000000000..257b435607ce8 --- /dev/null +++ b/R/pkg/R/deserialize.R @@ -0,0 +1,184 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility functions to deserialize objects from Java. + +# Type mapping from Java to R +# +# void -> NULL +# Int -> integer +# String -> character +# Boolean -> logical +# Double -> double +# Long -> double +# Array[Byte] -> raw +# Date -> Date +# Time -> POSIXct +# +# Array[T] -> list() +# Object -> jobj + +readObject <- function(con) { + # Read type first + type <- readType(con) + readTypedObject(con, type) +} + +readTypedObject <- function(con, type) { + switch (type, + "i" = readInt(con), + "c" = readString(con), + "b" = readBoolean(con), + "d" = readDouble(con), + "r" = readRaw(con), + "D" = readDate(con), + "t" = readTime(con), + "l" = readList(con), + "n" = NULL, + "j" = getJobj(readString(con)), + stop(paste("Unsupported type for deserialization", type))) +} + +readString <- function(con) { + stringLen <- readInt(con) + string <- readBin(con, raw(), stringLen, endian = "big") + rawToChar(string) +} + +readInt <- function(con) { + readBin(con, integer(), n = 1, endian = "big") +} + +readDouble <- function(con) { + readBin(con, double(), n = 1, endian = "big") +} + +readBoolean <- function(con) { + as.logical(readInt(con)) +} + +readType <- function(con) { + rawToChar(readBin(con, "raw", n = 1L)) +} + +readDate <- function(con) { + as.Date(readString(con)) +} + +readTime <- function(con) { + t <- readDouble(con) + as.POSIXct(t, origin = "1970-01-01") +} + +# We only support lists where all elements are of same type +readList <- function(con) { + type <- readType(con) + len <- readInt(con) + if (len > 0) { + l <- vector("list", len) + for (i in 1:len) { + l[[i]] <- readTypedObject(con, type) + } + l + } else { + list() + } +} + +readRaw <- function(con) { + dataLen <- readInt(con) + data <- readBin(con, raw(), as.integer(dataLen), endian = "big") +} + +readRawLen <- function(con, dataLen) { + data <- readBin(con, raw(), as.integer(dataLen), endian = "big") +} + +readDeserialize <- function(con) { + # We have two cases that are possible - In one, the entire partition is + # encoded as a byte array, so we have only one value to read. If so just + # return firstData + dataLen <- readInt(con) + firstData <- unserialize( + readBin(con, raw(), as.integer(dataLen), endian = "big")) + + # Else, read things into a list + dataLen <- readInt(con) + if (length(dataLen) > 0 && dataLen > 0) { + data <- list(firstData) + while (length(dataLen) > 0 && dataLen > 0) { + data[[length(data) + 1L]] <- unserialize( + readBin(con, raw(), as.integer(dataLen), endian = "big")) + dataLen <- readInt(con) + } + unlist(data, recursive = FALSE) + } else { + firstData + } +} + +readDeserializeRows <- function(inputCon) { + # readDeserializeRows will deserialize a DataOutputStream composed of + # a list of lists. Since the DOS is one continuous stream and + # the number of rows varies, we put the readRow function in a while loop + # that termintates when the next row is empty. + data <- list() + while(TRUE) { + row <- readRow(inputCon) + if (length(row) == 0) { + break + } + data[[length(data) + 1L]] <- row + } + data # this is a list of named lists now +} + +readRowList <- function(obj) { + # readRowList is meant for use inside an lapply. As a result, it is + # necessary to open a standalone connection for the row and consume + # the numCols bytes inside the read function in order to correctly + # deserialize the row. + rawObj <- rawConnection(obj, "r+") + on.exit(close(rawObj)) + readRow(rawObj) +} + +readRow <- function(inputCon) { + numCols <- readInt(inputCon) + if (length(numCols) > 0 && numCols > 0) { + lapply(1:numCols, function(x) { + obj <- readObject(inputCon) + if (is.null(obj)) { + NA + } else { + obj + } + }) # each row is a list now + } else { + list() + } +} + +# Take a single column as Array[Byte] and deserialize it into an atomic vector +readCol <- function(inputCon, numRows) { + # sapply can not work with POSIXlt + do.call(c, lapply(1:numRows, function(x) { + value <- readObject(inputCon) + # Replace NULL with NA so we can coerce to vectors + if (is.null(value)) NA else value + })) +} diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R new file mode 100644 index 0000000000000..5fb1ccaa84ee2 --- /dev/null +++ b/R/pkg/R/generics.R @@ -0,0 +1,543 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############ RDD Actions and Transformations ############ + +#' @rdname aggregateRDD +#' @seealso reduce +#' @export +setGeneric("aggregateRDD", function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) + +#' @rdname cache-methods +#' @export +setGeneric("cache", function(x) { standardGeneric("cache") }) + +#' @rdname coalesce +#' @seealso repartition +#' @export +setGeneric("coalesce", function(x, numPartitions, ...) { standardGeneric("coalesce") }) + +#' @rdname checkpoint-methods +#' @export +setGeneric("checkpoint", function(x) { standardGeneric("checkpoint") }) + +#' @rdname collect-methods +#' @export +setGeneric("collect", function(x, ...) { standardGeneric("collect") }) + +#' @rdname collect-methods +#' @export +setGeneric("collectAsMap", function(x) { standardGeneric("collectAsMap") }) + +#' @rdname collect-methods +#' @export +setGeneric("collectPartition", + function(x, partitionId) { + standardGeneric("collectPartition") + }) + +#' @rdname count +#' @export +setGeneric("count", function(x) { standardGeneric("count") }) + +#' @rdname countByValue +#' @export +setGeneric("countByValue", function(x) { standardGeneric("countByValue") }) + +#' @rdname distinct +#' @export +setGeneric("distinct", function(x, numPartitions = 1L) { standardGeneric("distinct") }) + +#' @rdname filterRDD +#' @export +setGeneric("filterRDD", function(x, f) { standardGeneric("filterRDD") }) + +#' @rdname first +#' @export +setGeneric("first", function(x) { standardGeneric("first") }) + +#' @rdname flatMap +#' @export +setGeneric("flatMap", function(X, FUN) { standardGeneric("flatMap") }) + +#' @rdname fold +#' @seealso reduce +#' @export +setGeneric("fold", function(x, zeroValue, op) { standardGeneric("fold") }) + +#' @rdname foreach +#' @export +setGeneric("foreach", function(x, func) { standardGeneric("foreach") }) + +#' @rdname foreach +#' @export +setGeneric("foreachPartition", function(x, func) { standardGeneric("foreachPartition") }) + +# The jrdd accessor function. +setGeneric("getJRDD", function(rdd, ...) { standardGeneric("getJRDD") }) + +#' @rdname glom +#' @export +setGeneric("glom", function(x) { standardGeneric("glom") }) + +#' @rdname keyBy +#' @export +setGeneric("keyBy", function(x, func) { standardGeneric("keyBy") }) + +#' @rdname lapplyPartition +#' @export +setGeneric("lapplyPartition", function(X, FUN) { standardGeneric("lapplyPartition") }) + +#' @rdname lapplyPartitionsWithIndex +#' @export +setGeneric("lapplyPartitionsWithIndex", + function(X, FUN) { + standardGeneric("lapplyPartitionsWithIndex") + }) + +#' @rdname lapply +#' @export +setGeneric("map", function(X, FUN) { standardGeneric("map") }) + +#' @rdname lapplyPartition +#' @export +setGeneric("mapPartitions", function(X, FUN) { standardGeneric("mapPartitions") }) + +#' @rdname lapplyPartitionsWithIndex +#' @export +setGeneric("mapPartitionsWithIndex", + function(X, FUN) { standardGeneric("mapPartitionsWithIndex") }) + +#' @rdname maximum +#' @export +setGeneric("maximum", function(x) { standardGeneric("maximum") }) + +#' @rdname minimum +#' @export +setGeneric("minimum", function(x) { standardGeneric("minimum") }) + +#' @rdname sumRDD +#' @export +setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") }) + +#' @rdname name +#' @export +setGeneric("name", function(x) { standardGeneric("name") }) + +#' @rdname numPartitions +#' @export +setGeneric("numPartitions", function(x) { standardGeneric("numPartitions") }) + +#' @rdname persist +#' @export +setGeneric("persist", function(x, newLevel) { standardGeneric("persist") }) + +#' @rdname pipeRDD +#' @export +setGeneric("pipeRDD", function(x, command, env = list()) { standardGeneric("pipeRDD")}) + +#' @rdname reduce +#' @export +setGeneric("reduce", function(x, func) { standardGeneric("reduce") }) + +#' @rdname repartition +#' @seealso coalesce +#' @export +setGeneric("repartition", function(x, numPartitions) { standardGeneric("repartition") }) + +#' @rdname sampleRDD +#' @export +setGeneric("sampleRDD", + function(x, withReplacement, fraction, seed) { + standardGeneric("sampleRDD") + }) + +#' @rdname saveAsObjectFile +#' @seealso objectFile +#' @export +setGeneric("saveAsObjectFile", function(x, path) { standardGeneric("saveAsObjectFile") }) + +#' @rdname saveAsTextFile +#' @export +setGeneric("saveAsTextFile", function(x, path) { standardGeneric("saveAsTextFile") }) + +#' @rdname setName +#' @export +setGeneric("setName", function(x, name) { standardGeneric("setName") }) + +#' @rdname sortBy +#' @export +setGeneric("sortBy", + function(x, func, ascending = TRUE, numPartitions = 1L) { + standardGeneric("sortBy") + }) + +#' @rdname take +#' @export +setGeneric("take", function(x, num) { standardGeneric("take") }) + +#' @rdname takeOrdered +#' @export +setGeneric("takeOrdered", function(x, num) { standardGeneric("takeOrdered") }) + +#' @rdname takeSample +#' @export +setGeneric("takeSample", + function(x, withReplacement, num, seed) { + standardGeneric("takeSample") + }) + +#' @rdname top +#' @export +setGeneric("top", function(x, num) { standardGeneric("top") }) + +#' @rdname unionRDD +#' @export +setGeneric("unionRDD", function(x, y) { standardGeneric("unionRDD") }) + +#' @rdname unpersist-methods +#' @export +setGeneric("unpersist", function(x, ...) { standardGeneric("unpersist") }) + +#' @rdname zipRDD +#' @export +setGeneric("zipRDD", function(x, other) { standardGeneric("zipRDD") }) + +#' @rdname zipWithIndex +#' @seealso zipWithUniqueId +#' @export +setGeneric("zipWithIndex", function(x) { standardGeneric("zipWithIndex") }) + +#' @rdname zipWithUniqueId +#' @seealso zipWithIndex +#' @export +setGeneric("zipWithUniqueId", function(x) { standardGeneric("zipWithUniqueId") }) + + +############ Binary Functions ############# + +#' @rdname countByKey +#' @export +setGeneric("countByKey", function(x) { standardGeneric("countByKey") }) + +#' @rdname flatMapValues +#' @export +setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") }) + +#' @rdname keys +#' @export +setGeneric("keys", function(x) { standardGeneric("keys") }) + +#' @rdname lookup +#' @export +setGeneric("lookup", function(x, key) { standardGeneric("lookup") }) + +#' @rdname mapValues +#' @export +setGeneric("mapValues", function(X, FUN) { standardGeneric("mapValues") }) + +#' @rdname values +#' @export +setGeneric("values", function(x) { standardGeneric("values") }) + + + +############ Shuffle Functions ############ + +#' @rdname aggregateByKey +#' @seealso foldByKey, combineByKey +#' @export +setGeneric("aggregateByKey", + function(x, zeroValue, seqOp, combOp, numPartitions) { + standardGeneric("aggregateByKey") + }) + +#' @rdname cogroup +#' @export +setGeneric("cogroup", + function(..., numPartitions) { + standardGeneric("cogroup") + }, + signature = "...") + +#' @rdname combineByKey +#' @seealso groupByKey, reduceByKey +#' @export +setGeneric("combineByKey", + function(x, createCombiner, mergeValue, mergeCombiners, numPartitions) { + standardGeneric("combineByKey") + }) + +#' @rdname foldByKey +#' @seealso aggregateByKey, combineByKey +#' @export +setGeneric("foldByKey", + function(x, zeroValue, func, numPartitions) { + standardGeneric("foldByKey") + }) + +#' @rdname join-methods +#' @export +setGeneric("fullOuterJoin", function(x, y, numPartitions) { standardGeneric("fullOuterJoin") }) + +#' @rdname groupByKey +#' @seealso reduceByKey +#' @export +setGeneric("groupByKey", function(x, numPartitions) { standardGeneric("groupByKey") }) + +#' @rdname join-methods +#' @export +setGeneric("join", function(x, y, ...) { standardGeneric("join") }) + +#' @rdname join-methods +#' @export +setGeneric("leftOuterJoin", function(x, y, numPartitions) { standardGeneric("leftOuterJoin") }) + +#' @rdname partitionBy +#' @export +setGeneric("partitionBy", function(x, numPartitions, ...) { standardGeneric("partitionBy") }) + +#' @rdname reduceByKey +#' @seealso groupByKey +#' @export +setGeneric("reduceByKey", function(x, combineFunc, numPartitions) { standardGeneric("reduceByKey")}) + +#' @rdname reduceByKeyLocally +#' @seealso reduceByKey +#' @export +setGeneric("reduceByKeyLocally", + function(x, combineFunc) { + standardGeneric("reduceByKeyLocally") + }) + +#' @rdname join-methods +#' @export +setGeneric("rightOuterJoin", function(x, y, numPartitions) { standardGeneric("rightOuterJoin") }) + +#' @rdname sortByKey +#' @export +setGeneric("sortByKey", function(x, ascending = TRUE, numPartitions = 1L) { + standardGeneric("sortByKey") +}) + + +################### Broadcast Variable Methods ################# + +#' @rdname broadcast +#' @export +setGeneric("value", function(bcast) { standardGeneric("value") }) + + + +#################### DataFrame Methods ######################## + +#' @rdname schema +#' @export +setGeneric("columns", function(x) {standardGeneric("columns") }) + +#' @rdname schema +#' @export +setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) + +#' @rdname explain +#' @export +setGeneric("explain", function(x, ...) { standardGeneric("explain") }) + +#' @rdname filter +#' @export +setGeneric("filter", function(x, condition) { standardGeneric("filter") }) + +#' @rdname DataFrame +#' @export +setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") }) + +#' @rdname insertInto +#' @export +setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertInto") }) + +#' @rdname intersect +#' @export +setGeneric("intersect", function(x, y) { standardGeneric("intersect") }) + +#' @rdname isLocal +#' @export +setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) + +#' @rdname limit +#' @export +setGeneric("limit", function(x, num) {standardGeneric("limit") }) + +#' @rdname sortDF +#' @export +setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") }) + +#' @rdname schema +#' @export +setGeneric("printSchema", function(x) { standardGeneric("printSchema") }) + +#' @rdname registerTempTable +#' @export +setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") }) + +#' @rdname sampleDF +#' @export +setGeneric("sampleDF", + function(x, withReplacement, fraction, seed) { + standardGeneric("sampleDF") + }) + +#' @rdname saveAsParquetFile +#' @export +setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) + +#' @rdname saveAsTable +#' @export +setGeneric("saveAsTable", function(df, tableName, source, mode, ...) { + standardGeneric("saveAsTable") +}) + +#' @rdname saveAsTable +#' @export +setGeneric("saveDF", function(df, path, source, mode, ...) { standardGeneric("saveDF") }) + +#' @rdname schema +#' @export +setGeneric("schema", function(x) { standardGeneric("schema") }) + +#' @rdname select +#' @export +setGeneric("select", function(x, col, ...) { standardGeneric("select") } ) + +#' @rdname select +#' @export +setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") }) + +#' @rdname showDF +#' @export +setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) + +#' @rdname sortDF +#' @export +setGeneric("sortDF", function(x, col, ...) { standardGeneric("sortDF") }) + +#' @rdname subtract +#' @export +setGeneric("subtract", function(x, y) { standardGeneric("subtract") }) + +#' @rdname tojson +#' @export +setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) + +#' @rdname DataFrame +#' @export +setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) + +#' @rdname unionAll +#' @export +setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) + +#' @rdname filter +#' @export +setGeneric("where", function(x, condition) { standardGeneric("where") }) + +#' @rdname withColumn +#' @export +setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn") }) + +#' @rdname withColumnRenamed +#' @export +setGeneric("withColumnRenamed", function(x, existingCol, newCol) { + standardGeneric("withColumnRenamed") }) + + +###################### Column Methods ########################## + +#' @rdname column +#' @export +setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) + +#' @rdname column +#' @export +setGeneric("asc", function(x) { standardGeneric("asc") }) + +#' @rdname column +#' @export +setGeneric("avg", function(x, ...) { standardGeneric("avg") }) + +#' @rdname column +#' @export +setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) + +#' @rdname column +#' @export +setGeneric("contains", function(x, ...) { standardGeneric("contains") }) +#' @rdname column +#' @export +setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) + +#' @rdname column +#' @export +setGeneric("desc", function(x) { standardGeneric("desc") }) + +#' @rdname column +#' @export +setGeneric("endsWith", function(x, ...) { standardGeneric("endsWith") }) + +#' @rdname column +#' @export +setGeneric("getField", function(x, ...) { standardGeneric("getField") }) + +#' @rdname column +#' @export +setGeneric("getItem", function(x, ...) { standardGeneric("getItem") }) + +#' @rdname column +#' @export +setGeneric("isNull", function(x) { standardGeneric("isNull") }) + +#' @rdname column +#' @export +setGeneric("isNotNull", function(x) { standardGeneric("isNotNull") }) + +#' @rdname column +#' @export +setGeneric("last", function(x) { standardGeneric("last") }) + +#' @rdname column +#' @export +setGeneric("like", function(x, ...) { standardGeneric("like") }) + +#' @rdname column +#' @export +setGeneric("lower", function(x) { standardGeneric("lower") }) + +#' @rdname column +#' @export +setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) + +#' @rdname column +#' @export +setGeneric("startsWith", function(x, ...) { standardGeneric("startsWith") }) + +#' @rdname column +#' @export +setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) + +#' @rdname column +#' @export +setGeneric("upper", function(x) { standardGeneric("upper") }) + diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R new file mode 100644 index 0000000000000..09fc0a7abe48a --- /dev/null +++ b/R/pkg/R/group.R @@ -0,0 +1,132 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# group.R - GroupedData class and methods implemented in S4 OO classes + +setOldClass("jobj") + +#' @title S4 class that represents a GroupedData +#' @description GroupedDatas can be created using groupBy() on a DataFrame +#' @rdname GroupedData +#' @seealso groupBy +#' +#' @param sgd A Java object reference to the backing Scala GroupedData +#' @export +setClass("GroupedData", + slots = list(sgd = "jobj")) + +setMethod("initialize", "GroupedData", function(.Object, sgd) { + .Object@sgd <- sgd + .Object +}) + +#' @rdname DataFrame +groupedData <- function(sgd) { + new("GroupedData", sgd) +} + + +#' @rdname show +setMethod("show", "GroupedData", + function(object) { + cat("GroupedData\n") + }) + +#' Count +#' +#' Count the number of rows for each group. +#' The resulting DataFrame will also contain the grouping columns. +#' +#' @param x a GroupedData +#' @return a DataFrame +#' @export +#' @examples +#' \dontrun{ +#' count(groupBy(df, "name")) +#' } +setMethod("count", + signature(x = "GroupedData"), + function(x) { + dataFrame(callJMethod(x@sgd, "count")) + }) + +#' Agg +#' +#' Aggregates on the entire DataFrame without groups. +#' The resulting DataFrame will also contain the grouping columns. +#' +#' df2 <- agg(df, = ) +#' df2 <- agg(df, newColName = aggFunction(column)) +#' +#' @param x a GroupedData +#' @return a DataFrame +#' @rdname agg +#' @examples +#' \dontrun{ +#' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)' +#' df2 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum +#' } +setGeneric("agg", function (x, ...) { standardGeneric("agg") }) + +setMethod("agg", + signature(x = "GroupedData"), + function(x, ...) { + cols = list(...) + stopifnot(length(cols) > 0) + if (is.character(cols[[1]])) { + cols <- varargsToEnv(...) + sdf <- callJMethod(x@sgd, "agg", cols) + } else if (class(cols[[1]]) == "Column") { + ns <- names(cols) + if (!is.null(ns)) { + for (n in ns) { + if (n != "") { + cols[[n]] = alias(cols[[n]], n) + } + } + } + jcols <- lapply(cols, function(c) { c@jc }) + # the GroupedData.agg(col, cols*) API does not contain grouping Column + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "aggWithGrouping", + x@sgd, listToSeq(jcols)) + } else { + stop("agg can only support Column or character") + } + dataFrame(sdf) + }) + + +# sum/mean/avg/min/max +methods <- c("sum", "mean", "avg", "min", "max") + +createMethod <- function(name) { + setMethod(name, + signature(x = "GroupedData"), + function(x, ...) { + sdf <- callJMethod(x@sgd, name, toSeq(...)) + dataFrame(sdf) + }) +} + +createMethods <- function() { + for (name in methods) { + createMethod(name) + } +} + +createMethods() + diff --git a/R/pkg/R/jobj.R b/R/pkg/R/jobj.R new file mode 100644 index 0000000000000..4180f146b7fbc --- /dev/null +++ b/R/pkg/R/jobj.R @@ -0,0 +1,101 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# References to objects that exist on the JVM backend +# are maintained using the jobj. + +# Maintain a reference count of Java object references +# This allows us to GC the java object when it is safe +.validJobjs <- new.env(parent = emptyenv()) + +# List of object ids to be removed +.toRemoveJobjs <- new.env(parent = emptyenv()) + +# Check if jobj was created with the current SparkContext +isValidJobj <- function(jobj) { + if (exists(".scStartTime", envir = .sparkREnv)) { + jobj$appId == get(".scStartTime", envir = .sparkREnv) + } else { + FALSE + } +} + +getJobj <- function(objId) { + newObj <- jobj(objId) + if (exists(objId, .validJobjs)) { + .validJobjs[[objId]] <- .validJobjs[[objId]] + 1 + } else { + .validJobjs[[objId]] <- 1 + } + newObj +} + +# Handler for a java object that exists on the backend. +jobj <- function(objId) { + if (!is.character(objId)) { + stop("object id must be a character") + } + # NOTE: We need a new env for a jobj as we can only register + # finalizers for environments or external references pointers. + obj <- structure(new.env(parent = emptyenv()), class = "jobj") + obj$id <- objId + obj$appId <- get(".scStartTime", envir = .sparkREnv) + + # Register a finalizer to remove the Java object when this reference + # is garbage collected in R + reg.finalizer(obj, cleanup.jobj) + obj +} + +#' Print a JVM object reference. +#' +#' This function prints the type and id for an object stored +#' in the SparkR JVM backend. +#' +#' @param x The JVM object reference +#' @param ... further arguments passed to or from other methods +print.jobj <- function(x, ...) { + cls <- callJMethod(x, "getClass") + name <- callJMethod(cls, "getName") + cat("Java ref type", name, "id", x$id, "\n", sep = " ") +} + +cleanup.jobj <- function(jobj) { + if (isValidJobj(jobj)) { + objId <- jobj$id + # If we don't know anything about this jobj, ignore it + if (exists(objId, envir = .validJobjs)) { + .validJobjs[[objId]] <- .validJobjs[[objId]] - 1 + + if (.validJobjs[[objId]] == 0) { + rm(list = objId, envir = .validJobjs) + # NOTE: We cannot call removeJObject here as the finalizer may be run + # in the middle of another RPC. Thus we queue up this object Id to be removed + # and then run all the removeJObject when the next RPC is called. + .toRemoveJobjs[[objId]] <- 1 + } + } + } +} + +clearJobjs <- function() { + valid <- ls(.validJobjs) + rm(list = valid, envir = .validJobjs) + + removeList <- ls(.toRemoveJobjs) + rm(list = removeList, envir = .toRemoveJobjs) +} diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R new file mode 100644 index 0000000000000..739d399f0820f --- /dev/null +++ b/R/pkg/R/pairRDD.R @@ -0,0 +1,785 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Operations supported on RDDs contains pairs (i.e key, value) + +############ Actions and Transformations ############ + +#' Look up elements of a key in an RDD +#' +#' @description +#' \code{lookup} returns a list of values in this RDD for key key. +#' +#' @param x The RDD to collect +#' @param key The key to look up for +#' @return a list of values in this RDD for key key +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(c(1, 1), c(2, 2), c(1, 3)) +#' rdd <- parallelize(sc, pairs) +#' lookup(rdd, 1) # list(1, 3) +#'} +#' @rdname lookup +#' @aliases lookup,RDD-method +setMethod("lookup", + signature(x = "RDD", key = "ANY"), + function(x, key) { + partitionFunc <- function(part) { + filtered <- part[unlist(lapply(part, function(i) { identical(key, i[[1]]) }))] + lapply(filtered, function(i) { i[[2]] }) + } + valsRDD <- lapplyPartition(x, partitionFunc) + collect(valsRDD) + }) + +#' Count the number of elements for each key, and return the result to the +#' master as lists of (key, count) pairs. +#' +#' Same as countByKey in Spark. +#' +#' @param x The RDD to count keys. +#' @return list of (key, count) pairs, where count is number of each key in rdd. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(c("a", 1), c("b", 1), c("a", 1))) +#' countByKey(rdd) # ("a", 2L), ("b", 1L) +#'} +#' @rdname countByKey +#' @aliases countByKey,RDD-method +setMethod("countByKey", + signature(x = "RDD"), + function(x) { + keys <- lapply(x, function(item) { item[[1]] }) + countByValue(keys) + }) + +#' Return an RDD with the keys of each tuple. +#' +#' @param x The RDD from which the keys of each tuple is returned. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) +#' collect(keys(rdd)) # list(1, 3) +#'} +#' @rdname keys +#' @aliases keys,RDD +setMethod("keys", + signature(x = "RDD"), + function(x) { + func <- function(k) { + k[[1]] + } + lapply(x, func) + }) + +#' Return an RDD with the values of each tuple. +#' +#' @param x The RDD from which the values of each tuple is returned. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) +#' collect(values(rdd)) # list(2, 4) +#'} +#' @rdname values +#' @aliases values,RDD +setMethod("values", + signature(x = "RDD"), + function(x) { + func <- function(v) { + v[[2]] + } + lapply(x, func) + }) + +#' Applies a function to all values of the elements, without modifying the keys. +#' +#' The same as `mapValues()' in Spark. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on the value of each element. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' makePairs <- lapply(rdd, function(x) { list(x, x) }) +#' collect(mapValues(makePairs, function(x) { x * 2) }) +#' Output: list(list(1,2), list(2,4), list(3,6), ...) +#'} +#' @rdname mapValues +#' @aliases mapValues,RDD,function-method +setMethod("mapValues", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + func <- function(x) { + list(x[[1]], FUN(x[[2]])) + } + lapply(X, func) + }) + +#' Pass each value in the key-value pair RDD through a flatMap function without +#' changing the keys; this also retains the original RDD's partitioning. +#' +#' The same as 'flatMapValues()' in Spark. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on the value of each element. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) +#' collect(flatMapValues(rdd, function(x) { x })) +#' Output: list(list(1,1), list(1,2), list(2,3), list(2,4)) +#'} +#' @rdname flatMapValues +#' @aliases flatMapValues,RDD,function-method +setMethod("flatMapValues", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + flatMapFunc <- function(x) { + lapply(FUN(x[[2]]), function(v) { list(x[[1]], v) }) + } + flatMap(X, flatMapFunc) + }) + +############ Shuffle Functions ############ + +#' Partition an RDD by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' For each element of this RDD, the partitioner is used to compute a hash +#' function and the RDD is partitioned using this hash value. +#' +#' @param x The RDD to partition. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param numPartitions Number of partitions to create. +#' @param ... Other optional arguments to partitionBy. +#' +#' @param partitionFunc The partition function to use. Uses a default hashCode +#' function if not provided +#' @return An RDD partitioned using the specified partitioner. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- partitionBy(rdd, 2L) +#' collectPartition(parts, 0L) # First partition should contain list(1, 2) and list(1, 4) +#'} +#' @rdname partitionBy +#' @aliases partitionBy,RDD,integer-method +setMethod("partitionBy", + signature(x = "RDD", numPartitions = "integer"), + function(x, numPartitions, partitionFunc = hashCode) { + + #if (missing(partitionFunc)) { + # partitionFunc <- hashCode + #} + + partitionFunc <- cleanClosure(partitionFunc) + serializedHashFuncBytes <- serialize(partitionFunc, connection = NULL) + + packageNamesArr <- serialize(.sparkREnv$.packages, + connection = NULL) + broadcastArr <- lapply(ls(.broadcastNames), function(name) { + get(name, .broadcastNames) }) + jrdd <- getJRDD(x) + + # We create a PairwiseRRDD that extends RDD[(Array[Byte], + # Array[Byte])], where the key is the hashed split, the value is + # the content (key-val pairs). + pairwiseRRDD <- newJObject("org.apache.spark.api.r.PairwiseRRDD", + callJMethod(jrdd, "rdd"), + as.integer(numPartitions), + serializedHashFuncBytes, + getSerializedMode(x), + packageNamesArr, + as.character(.sparkREnv$libname), + broadcastArr, + callJMethod(jrdd, "classTag")) + + # Create a corresponding partitioner. + rPartitioner <- newJObject("org.apache.spark.HashPartitioner", + as.integer(numPartitions)) + + # Call partitionBy on the obtained PairwiseRDD. + javaPairRDD <- callJMethod(pairwiseRRDD, "asJavaPairRDD") + javaPairRDD <- callJMethod(javaPairRDD, "partitionBy", rPartitioner) + + # Call .values() on the result to get back the final result, the + # shuffled acutal content key-val pairs. + r <- callJMethod(javaPairRDD, "values") + + RDD(r, serializedMode = "byte") + }) + +#' Group values by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and group values for each key in the RDD into a single sequence. +#' +#' @param x The RDD to group. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, list(V)) +#' @seealso reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- groupByKey(rdd, 2L) +#' grouped <- collect(parts) +#' grouped[[1]] # Should be a list(1, list(2, 4)) +#'} +#' @rdname groupByKey +#' @aliases groupByKey,RDD,integer-method +setMethod("groupByKey", + signature(x = "RDD", numPartitions = "integer"), + function(x, numPartitions) { + shuffled <- partitionBy(x, numPartitions) + groupVals <- function(part) { + vals <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + appendList <- function(acc, i) { + addItemToAccumulator(acc, i) + acc + } + makeList <- function(i) { + acc <- initAccumulator() + addItemToAccumulator(acc, i) + acc + } + # Each item in the partition is list of (K, V) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, vals, pred, + appendList, makeList) + }) + # extract out data field + vals <- eapply(vals, + function(i) { + length(i$data) <- i$counter + i$data + }) + # Every key in the environment contains a list + # Convert that to list(K, Seq[V]) + convertEnvsToList(keys, vals) + } + lapplyPartition(shuffled, groupVals) + }) + +#' Merge values by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and merges the values for each key using an associative reduce function. +#' +#' @param x The RDD to reduce by key. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param combineFunc The associative reduce function to use. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, V') where V' is the merged +#' value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- reduceByKey(rdd, "+", 2L) +#' reduced <- collect(parts) +#' reduced[[1]] # Should be a list(1, 6) +#'} +#' @rdname reduceByKey +#' @aliases reduceByKey,RDD,integer-method +setMethod("reduceByKey", + signature(x = "RDD", combineFunc = "ANY", numPartitions = "integer"), + function(x, combineFunc, numPartitions) { + reduceVals <- function(part) { + vals <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, vals, pred, combineFunc, identity) + }) + convertEnvsToList(keys, vals) + } + locallyReduced <- lapplyPartition(x, reduceVals) + shuffled <- partitionBy(locallyReduced, numPartitions) + lapplyPartition(shuffled, reduceVals) + }) + +#' Merge values by key locally +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and merges the values for each key using an associative reduce function, but return the +#' results immediately to the driver as an R list. +#' +#' @param x The RDD to reduce by key. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param combineFunc The associative reduce function to use. +#' @return A list of elements of type list(K, V') where V' is the merged value for each key +#' @seealso reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' reduced <- reduceByKeyLocally(rdd, "+") +#' reduced # list(list(1, 6), list(1.1, 3)) +#'} +#' @rdname reduceByKeyLocally +#' @aliases reduceByKeyLocally,RDD,integer-method +setMethod("reduceByKeyLocally", + signature(x = "RDD", combineFunc = "ANY"), + function(x, combineFunc) { + reducePart <- function(part) { + vals <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, vals, pred, combineFunc, identity) + }) + list(list(keys, vals)) # return hash to avoid re-compute in merge + } + mergeParts <- function(accum, x) { + pred <- function(item) { + exists(item$hash, accum[[1]]) + } + lapply(ls(x[[1]]), + function(name) { + item <- list(x[[1]][[name]], x[[2]][[name]]) + item$hash <- name + updateOrCreatePair(item, accum[[1]], accum[[2]], pred, combineFunc, identity) + }) + accum + } + reduced <- mapPartitions(x, reducePart) + merged <- reduce(reduced, mergeParts) + convertEnvsToList(merged[[1]], merged[[2]]) + }) + +#' Combine values by key +#' +#' Generic function to combine the elements for each key using a custom set of +#' aggregation functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], +#' for a "combined type" C. Note that V and C can be different -- for example, one +#' might group an RDD of type (Int, Int) into an RDD of type (Int, Seq[Int]). + +#' Users provide three functions: +#' \itemize{ +#' \item createCombiner, which turns a V into a C (e.g., creates a one-element list) +#' \item mergeValue, to merge a V into a C (e.g., adds it to the end of a list) - +#' \item mergeCombiners, to combine two C's into a single one (e.g., concatentates +#' two lists). +#' } +#' +#' @param x The RDD to combine. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param createCombiner Create a combiner (C) given a value (V) +#' @param mergeValue Merge the given value (V) with an existing combiner (C) +#' @param mergeCombiners Merge two combiners and return a new combiner +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, C) where C is the combined type +#' +#' @seealso groupByKey, reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- combineByKey(rdd, function(x) { x }, "+", "+", 2L) +#' combined <- collect(parts) +#' combined[[1]] # Should be a list(1, 6) +#'} +#' @rdname combineByKey +#' @aliases combineByKey,RDD,ANY,ANY,ANY,integer-method +setMethod("combineByKey", + signature(x = "RDD", createCombiner = "ANY", mergeValue = "ANY", + mergeCombiners = "ANY", numPartitions = "integer"), + function(x, createCombiner, mergeValue, mergeCombiners, numPartitions) { + combineLocally <- function(part) { + combiners <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(item[[1]]) + updateOrCreatePair(item, keys, combiners, pred, mergeValue, createCombiner) + }) + convertEnvsToList(keys, combiners) + } + locallyCombined <- lapplyPartition(x, combineLocally) + shuffled <- partitionBy(locallyCombined, numPartitions) + mergeAfterShuffle <- function(part) { + combiners <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(item[[1]]) + updateOrCreatePair(item, keys, combiners, pred, mergeCombiners, identity) + }) + convertEnvsToList(keys, combiners) + } + lapplyPartition(shuffled, mergeAfterShuffle) + }) + +#' Aggregate a pair RDD by each key. +#' +#' Aggregate the values of each key in an RDD, using given combine functions +#' and a neutral "zero value". This function can return a different result type, +#' U, than the type of the values in this RDD, V. Thus, we need one operation +#' for merging a V into a U and one operation for merging two U's, The former +#' operation is used for merging values within a partition, and the latter is +#' used for merging values between partitions. To avoid memory allocation, both +#' of these functions are allowed to modify and return their first argument +#' instead of creating a new U. +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param seqOp A function to aggregate the values of each key. It may return +#' a different result type from the type of the values. +#' @param combOp A function to aggregate results of seqOp. +#' @return An RDD containing the aggregation result. +#' @seealso foldByKey, combineByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) +#' zeroValue <- list(0, 0) +#' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } +#' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } +#' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) +#' # list(list(1, list(3, 2)), list(2, list(7, 2))) +#'} +#' @rdname aggregateByKey +#' @aliases aggregateByKey,RDD,ANY,ANY,ANY,integer-method +setMethod("aggregateByKey", + signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", + combOp = "ANY", numPartitions = "integer"), + function(x, zeroValue, seqOp, combOp, numPartitions) { + createCombiner <- function(v) { + do.call(seqOp, list(zeroValue, v)) + } + + combineByKey(x, createCombiner, seqOp, combOp, numPartitions) + }) + +#' Fold a pair RDD by each key. +#' +#' Aggregate the values of each key in an RDD, using an associative function "func" +#' and a neutral "zero value" which may be added to the result an arbitrary +#' number of times, and must not change the result (e.g., 0 for addition, or +#' 1 for multiplication.). +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param func An associative function for folding values of each key. +#' @return An RDD containing the aggregation result. +#' @seealso aggregateByKey, combineByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) +#' foldByKey(rdd, 0, "+", 2L) # list(list(1, 3), list(2, 7)) +#'} +#' @rdname foldByKey +#' @aliases foldByKey,RDD,ANY,ANY,integer-method +setMethod("foldByKey", + signature(x = "RDD", zeroValue = "ANY", + func = "ANY", numPartitions = "integer"), + function(x, zeroValue, func, numPartitions) { + aggregateByKey(x, zeroValue, func, func, numPartitions) + }) + +############ Binary Functions ############# + +#' Join two RDDs +#' +#' @description +#' \code{join} This function joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return a new RDD containing all pairs of elements with matching keys in +#' two input RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' join(rdd1, rdd2, 2L) # list(list(1, list(1, 2)), list(1, list(1, 3)) +#'} +#' @rdname join-methods +#' @aliases join,RDD,RDD-method +setMethod("join", + signature(x = "RDD", y = "RDD"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(FALSE, FALSE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numToInt(numPartitions)), + doJoin) + }) + +#' Left outer join two RDDs +#' +#' @description +#' \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, v) in x, the resulting RDD will either contain +#' all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL)) +#' if no elements in rdd2 have key k. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' leftOuterJoin(rdd1, rdd2, 2L) +#' # list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) +#'} +#' @rdname join-methods +#' @aliases leftOuterJoin,RDD,RDD-method +setMethod("leftOuterJoin", + signature(x = "RDD", y = "RDD", numPartitions = "integer"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(FALSE, TRUE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) + }) + +#' Right outer join two RDDs +#' +#' @description +#' \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, w) in y, the resulting RDD will either contain +#' all pairs (k, (v, w)) for (k, v) in x, or the pair (k, (NULL, w)) +#' if no elements in x have key k. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rightOuterJoin(rdd1, rdd2, 2L) +#' # list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) +#'} +#' @rdname join-methods +#' @aliases rightOuterJoin,RDD,RDD-method +setMethod("rightOuterJoin", + signature(x = "RDD", y = "RDD", numPartitions = "integer"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(TRUE, FALSE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) + }) + +#' Full outer join two RDDs +#' +#' @description +#' \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, v) in x and (k, w) in y, the resulting RDD +#' will contain all pairs (k, (v, w)) for both (k, v) in x and +#' (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements +#' in x/y have key k. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) +#' rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' fullOuterJoin(rdd1, rdd2, 2L) # list(list(1, list(2, 1)), +#' # list(1, list(3, 1)), +#' # list(2, list(NULL, 4))) +#' # list(3, list(3, NULL)), +#'} +#' @rdname join-methods +#' @aliases fullOuterJoin,RDD,RDD-method +setMethod("fullOuterJoin", + signature(x = "RDD", y = "RDD", numPartitions = "integer"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(TRUE, TRUE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) + }) + +#' For each key k in several RDDs, return a resulting RDD that +#' whose values are a list of values for the key in all RDDs. +#' +#' @param ... Several RDDs. +#' @param numPartitions Number of partitions to create. +#' @return a new RDD containing all pairs of elements with values in a list +#' in all RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' cogroup(rdd1, rdd2, numPartitions = 2L) +#' # list(list(1, list(1, list(2, 3))), list(2, list(list(4), list())) +#'} +#' @rdname cogroup +#' @aliases cogroup,RDD-method +setMethod("cogroup", + "RDD", + function(..., numPartitions) { + rdds <- list(...) + rddsLen <- length(rdds) + for (i in 1:rddsLen) { + rdds[[i]] <- lapply(rdds[[i]], + function(x) { list(x[[1]], list(i, x[[2]])) }) + } + union.rdd <- Reduce(unionRDD, rdds) + group.func <- function(vlist) { + res <- list() + length(res) <- rddsLen + for (x in vlist) { + i <- x[[1]] + acc <- res[[i]] + # Create an accumulator. + if (is.null(acc)) { + acc <- initAccumulator() + } + addItemToAccumulator(acc, x[[2]]) + res[[i]] <- acc + } + lapply(res, function(acc) { + if (is.null(acc)) { + list() + } else { + acc$data + } + }) + } + cogroup.rdd <- mapValues(groupByKey(union.rdd, numPartitions), + group.func) + }) + +#' Sort a (k, v) pair RDD by k. +#' +#' @param x A (k, v) pair RDD to be sorted. +#' @param ascending A flag to indicate whether the sorting is ascending or descending. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where all (k, v) pair elements are sorted. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(3, 1), list(2, 2), list(1, 3))) +#' collect(sortByKey(rdd)) # list (list(1, 3), list(2, 2), list(3, 1)) +#'} +#' @rdname sortByKey +#' @aliases sortByKey,RDD,RDD-method +setMethod("sortByKey", + signature(x = "RDD"), + function(x, ascending = TRUE, numPartitions = SparkR::numPartitions(x)) { + rangeBounds <- list() + + if (numPartitions > 1) { + rddSize <- count(x) + # constant from Spark's RangePartitioner + maxSampleSize <- numPartitions * 20 + fraction <- min(maxSampleSize / max(rddSize, 1), 1.0) + + samples <- collect(keys(sampleRDD(x, FALSE, fraction, 1L))) + + # Note: the built-in R sort() function only works on atomic vectors + samples <- sort(unlist(samples, recursive = FALSE), decreasing = !ascending) + + if (length(samples) > 0) { + rangeBounds <- lapply(seq_len(numPartitions - 1), + function(i) { + j <- ceiling(length(samples) * i / numPartitions) + samples[j] + }) + } + } + + rangePartitionFunc <- function(key) { + partition <- 0 + + # TODO: Use binary search instead of linear search, similar with Spark + while (partition < length(rangeBounds) && key > rangeBounds[[partition + 1]]) { + partition <- partition + 1 + } + + if (ascending) { + partition + } else { + numPartitions - partition - 1 + } + } + + partitionFunc <- function(part) { + sortKeyValueList(part, decreasing = !ascending) + } + + newRDD <- partitionBy(x, numPartitions, rangePartitionFunc) + lapplyPartition(newRDD, partitionFunc) + }) + diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R new file mode 100644 index 0000000000000..8a9c0c652ce24 --- /dev/null +++ b/R/pkg/R/serialize.R @@ -0,0 +1,195 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility functions to serialize R objects so they can be read in Java. + +# Type mapping from R to Java +# +# NULL -> Void +# integer -> Int +# character -> String +# logical -> Boolean +# double, numeric -> Double +# raw -> Array[Byte] +# Date -> Date +# POSIXct,POSIXlt -> Time +# +# list[T] -> Array[T], where T is one of above mentioned types +# environment -> Map[String, T], where T is a native type +# jobj -> Object, where jobj is an object created in the backend + +writeObject <- function(con, object, writeType = TRUE) { + # NOTE: In R vectors have same type as objects. So we don't support + # passing in vectors as arrays and instead require arrays to be passed + # as lists. + type <- class(object)[[1]] # class of POSIXlt is c("POSIXlt", "POSIXt") + if (writeType) { + writeType(con, type) + } + switch(type, + NULL = writeVoid(con), + integer = writeInt(con, object), + character = writeString(con, object), + logical = writeBoolean(con, object), + double = writeDouble(con, object), + numeric = writeDouble(con, object), + raw = writeRaw(con, object), + list = writeList(con, object), + jobj = writeJobj(con, object), + environment = writeEnv(con, object), + Date = writeDate(con, object), + POSIXlt = writeTime(con, object), + POSIXct = writeTime(con, object), + stop(paste("Unsupported type for serialization", type))) +} + +writeVoid <- function(con) { + # no value for NULL +} + +writeJobj <- function(con, value) { + if (!isValidJobj(value)) { + stop("invalid jobj ", value$id) + } + writeString(con, value$id) +} + +writeString <- function(con, value) { + writeInt(con, as.integer(nchar(value) + 1)) + writeBin(value, con, endian = "big") +} + +writeInt <- function(con, value) { + writeBin(as.integer(value), con, endian = "big") +} + +writeDouble <- function(con, value) { + writeBin(value, con, endian = "big") +} + +writeBoolean <- function(con, value) { + # TRUE becomes 1, FALSE becomes 0 + writeInt(con, as.integer(value)) +} + +writeRawSerialize <- function(outputCon, batch) { + outputSer <- serialize(batch, ascii = FALSE, connection = NULL) + writeRaw(outputCon, outputSer) +} + +writeRowSerialize <- function(outputCon, rows) { + invisible(lapply(rows, function(r) { + bytes <- serializeRow(r) + writeRaw(outputCon, bytes) + })) +} + +serializeRow <- function(row) { + rawObj <- rawConnection(raw(0), "wb") + on.exit(close(rawObj)) + writeRow(rawObj, row) + rawConnectionValue(rawObj) +} + +writeRow <- function(con, row) { + numCols <- length(row) + writeInt(con, numCols) + for (i in 1:numCols) { + writeObject(con, row[[i]]) + } +} + +writeRaw <- function(con, batch) { + writeInt(con, length(batch)) + writeBin(batch, con, endian = "big") +} + +writeType <- function(con, class) { + type <- switch(class, + NULL = "n", + integer = "i", + character = "c", + logical = "b", + double = "d", + numeric = "d", + raw = "r", + list = "l", + jobj = "j", + environment = "e", + Date = "D", + POSIXlt = 't', + POSIXct = 't', + stop(paste("Unsupported type for serialization", class))) + writeBin(charToRaw(type), con) +} + +# Used to pass arrays where all the elements are of the same type +writeList <- function(con, arr) { + # All elements should be of same type + elemType <- unique(sapply(arr, function(elem) { class(elem) })) + stopifnot(length(elemType) <= 1) + + # TODO: Empty lists are given type "character" right now. + # This may not work if the Java side expects array of any other type. + if (length(elemType) == 0) { + elemType <- class("somestring") + } + + writeType(con, elemType) + writeInt(con, length(arr)) + + if (length(arr) > 0) { + for (a in arr) { + writeObject(con, a, FALSE) + } + } +} + +# Used to pass in hash maps required on Java side. +writeEnv <- function(con, env) { + len <- length(env) + + writeInt(con, len) + if (len > 0) { + writeList(con, as.list(ls(env))) + vals <- lapply(ls(env), function(x) { env[[x]] }) + writeList(con, as.list(vals)) + } +} + +writeDate <- function(con, date) { + writeString(con, as.character(date)) +} + +writeTime <- function(con, time) { + writeDouble(con, as.double(time)) +} + +# Used to serialize in a list of objects where each +# object can be of a different type. Serialization format is +# for each object +writeArgs <- function(con, args) { + if (length(args) > 0) { + for (a in args) { + writeObject(con, a) + } + } +} + +writeStrings <- function(con, stringList) { + writeLines(unlist(stringList), con) +} diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R new file mode 100644 index 0000000000000..bc82df01f0fff --- /dev/null +++ b/R/pkg/R/sparkR.R @@ -0,0 +1,266 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.sparkREnv <- new.env() + +sparkR.onLoad <- function(libname, pkgname) { + .sparkREnv$libname <- libname +} + +# Utility function that returns TRUE if we have an active connection to the +# backend and FALSE otherwise +connExists <- function(env) { + tryCatch({ + exists(".sparkRCon", envir = env) && isOpen(env[[".sparkRCon"]]) + }, error = function(err) { + return(FALSE) + }) +} + +#' Stop the Spark context. +#' +#' Also terminates the backend this R session is connected to +sparkR.stop <- function() { + env <- .sparkREnv + if (exists(".sparkRCon", envir = env)) { + # cat("Stopping SparkR\n") + if (exists(".sparkRjsc", envir = env)) { + sc <- get(".sparkRjsc", envir = env) + callJMethod(sc, "stop") + rm(".sparkRjsc", envir = env) + } + + if (exists(".backendLaunched", envir = env)) { + callJStatic("SparkRHandler", "stopBackend") + } + + # Also close the connection and remove it from our env + conn <- get(".sparkRCon", envir = env) + close(conn) + + rm(".sparkRCon", envir = env) + rm(".scStartTime", envir = env) + } + + if (exists(".monitorConn", envir = env)) { + conn <- get(".monitorConn", envir = env) + close(conn) + rm(".monitorConn", envir = env) + } + + # Clear all broadcast variables we have + # as the jobj will not be valid if we restart the JVM + clearBroadcastVariables() + + # Clear jobj maps + clearJobjs() +} + +#' Initialize a new Spark Context. +#' +#' This function initializes a new SparkContext. +#' +#' @param master The Spark master URL. +#' @param appName Application name to register with cluster manager +#' @param sparkHome Spark Home directory +#' @param sparkEnvir Named list of environment variables to set on worker nodes. +#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors. +#' @param sparkJars Character string vector of jar files to pass to the worker nodes. +#' @param sparkRLibDir The path where R is installed on the worker nodes. +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init("local[2]", "SparkR", "/home/spark") +#' sc <- sparkR.init("local[2]", "SparkR", "/home/spark", +#' list(spark.executor.memory="1g")) +#' sc <- sparkR.init("yarn-client", "SparkR", "/home/spark", +#' list(spark.executor.memory="1g"), +#' list(LD_LIBRARY_PATH="/directory of JVM libraries (libjvm.so) on workers/"), +#' c("jarfile1.jar","jarfile2.jar")) +#'} + +sparkR.init <- function( + master = "", + appName = "SparkR", + sparkHome = Sys.getenv("SPARK_HOME"), + sparkEnvir = list(), + sparkExecutorEnv = list(), + sparkJars = "", + sparkRLibDir = "") { + + if (exists(".sparkRjsc", envir = .sparkREnv)) { + cat("Re-using existing Spark Context. Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n") + return(get(".sparkRjsc", envir = .sparkREnv)) + } + + sparkMem <- Sys.getenv("SPARK_MEM", "512m") + jars <- suppressWarnings(normalizePath(as.character(sparkJars))) + + # Classpath separator is ";" on Windows + # URI needs four /// as from http://stackoverflow.com/a/18522792 + if (.Platform$OS.type == "unix") { + collapseChar <- ":" + uriSep <- "//" + } else { + collapseChar <- ";" + uriSep <- "////" + } + + existingPort <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "") + if (existingPort != "") { + backendPort <- existingPort + } else { + path <- tempfile(pattern = "backend_port") + launchBackend( + args = path, + sparkHome = sparkHome, + jars = jars, + sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell")) + # wait atmost 100 seconds for JVM to launch + wait <- 0.1 + for (i in 1:25) { + Sys.sleep(wait) + if (file.exists(path)) { + break + } + wait <- wait * 1.25 + } + if (!file.exists(path)) { + stop("JVM is not ready after 10 seconds") + } + f <- file(path, open='rb') + backendPort <- readInt(f) + monitorPort <- readInt(f) + close(f) + file.remove(path) + if (length(backendPort) == 0 || backendPort == 0 || + length(monitorPort) == 0 || monitorPort == 0) { + stop("JVM failed to launch") + } + assign(".monitorConn", socketConnection(port = monitorPort), envir = .sparkREnv) + assign(".backendLaunched", 1, envir = .sparkREnv) + } + + .sparkREnv$backendPort <- backendPort + tryCatch({ + connectBackend("localhost", backendPort) + }, error = function(err) { + stop("Failed to connect JVM\n") + }) + + if (nchar(sparkHome) != 0) { + sparkHome <- normalizePath(sparkHome) + } + + if (nchar(sparkRLibDir) != 0) { + .sparkREnv$libname <- sparkRLibDir + } + + sparkEnvirMap <- new.env() + for (varname in names(sparkEnvir)) { + sparkEnvirMap[[varname]] <- sparkEnvir[[varname]] + } + + sparkExecutorEnvMap <- new.env() + if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) { + sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) + } + for (varname in names(sparkExecutorEnv)) { + sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]] + } + + nonEmptyJars <- Filter(function(x) { x != "" }, jars) + localJarPaths <- sapply(nonEmptyJars, function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) + + # Set the start time to identify jobjs + # Seconds resolution is good enough for this purpose, so use ints + assign(".scStartTime", as.integer(Sys.time()), envir = .sparkREnv) + + assign( + ".sparkRjsc", + callJStatic( + "org.apache.spark.api.r.RRDD", + "createSparkContext", + master, + appName, + as.character(sparkHome), + as.list(localJarPaths), + sparkEnvirMap, + sparkExecutorEnvMap), + envir = .sparkREnv + ) + + sc <- get(".sparkRjsc", envir = .sparkREnv) + + # Register a finalizer to sleep 1 seconds on R exit to make RStudio happy + reg.finalizer(.sparkREnv, function(x) { Sys.sleep(1) }, onexit = TRUE) + + sc +} + +#' Initialize a new SQLContext. +#' +#' This function creates a SparkContext from an existing JavaSparkContext and +#' then uses it to initialize a new SQLContext +#' +#' @param jsc The existing JavaSparkContext created with SparkR.init() +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#'} + +sparkRSQL.init <- function(jsc) { + if (exists(".sparkRSQLsc", envir = .sparkREnv)) { + return(get(".sparkRSQLsc", envir = .sparkREnv)) + } + + sqlCtx <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "createSQLContext", + jsc) + assign(".sparkRSQLsc", sqlCtx, envir = .sparkREnv) + sqlCtx +} + +#' Initialize a new HiveContext. +#' +#' This function creates a HiveContext from an existing JavaSparkContext +#' +#' @param jsc The existing JavaSparkContext created with SparkR.init() +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRHive.init(sc) +#'} + +sparkRHive.init <- function(jsc) { + if (exists(".sparkRHivesc", envir = .sparkREnv)) { + return(get(".sparkRHivesc", envir = .sparkREnv)) + } + + ssc <- callJMethod(jsc, "sc") + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.HiveContext", ssc) + }, error = function(err) { + stop("Spark SQL is not built with Hive support") + }) + + assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv) + hiveCtx +} diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R new file mode 100644 index 0000000000000..c337fb0751e72 --- /dev/null +++ b/R/pkg/R/utils.R @@ -0,0 +1,467 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utilities and Helpers + +# Given a JList, returns an R list containing the same elements, the number +# of which is optionally upper bounded by `logicalUpperBound` (by default, +# return all elements). Takes care of deserializations and type conversions. +convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL, + serializedMode = "byte") { + arrSize <- callJMethod(jList, "size") + + # Datasets with serializedMode == "string" (such as an RDD directly generated by textFile()): + # each partition is not dense-packed into one Array[Byte], and `arrSize` + # here corresponds to number of logical elements. Thus we can prune here. + if (serializedMode == "string" && !is.null(logicalUpperBound)) { + arrSize <- min(arrSize, logicalUpperBound) + } + + results <- if (arrSize > 0) { + lapply(0:(arrSize - 1), + function(index) { + obj <- callJMethod(jList, "get", as.integer(index)) + + # Assume it is either an R object or a Java obj ref. + if (inherits(obj, "jobj")) { + if (isInstanceOf(obj, "scala.Tuple2")) { + # JavaPairRDD[Array[Byte], Array[Byte]]. + + keyBytes = callJMethod(obj, "_1") + valBytes = callJMethod(obj, "_2") + res <- list(unserialize(keyBytes), + unserialize(valBytes)) + } else { + stop(paste("utils.R: convertJListToRList only supports", + "RDD[Array[Byte]] and", + "JavaPairRDD[Array[Byte], Array[Byte]] for now")) + } + } else { + if (inherits(obj, "raw")) { + if (serializedMode == "byte") { + # RDD[Array[Byte]]. `obj` is a whole partition. + res <- unserialize(obj) + # For serialized datasets, `obj` (and `rRaw`) here corresponds to + # one whole partition dense-packed together. We deserialize the + # whole partition first, then cap the number of elements to be returned. + } else if (serializedMode == "row") { + res <- readRowList(obj) + # For DataFrames that have been converted to RRDDs, we call readRowList + # which will read in each row of the RRDD as a list and deserialize + # each element. + flatten <<- FALSE + # Use global assignment to change the flatten flag. This means + # we don't have to worry about the default argument in other functions + # e.g. collect + } + # TODO: is it possible to distinguish element boundary so that we can + # unserialize only what we need? + if (!is.null(logicalUpperBound)) { + res <- head(res, n = logicalUpperBound) + } + } else { + # obj is of a primitive Java type, is simplified to R's + # corresponding type. + res <- list(obj) + } + } + res + }) + } else { + list() + } + + if (flatten) { + as.list(unlist(results, recursive = FALSE)) + } else { + as.list(results) + } +} + +# Returns TRUE if `name` refers to an RDD in the given environment `env` +isRDD <- function(name, env) { + obj <- get(name, envir = env) + inherits(obj, "RDD") +} + +#' Compute the hashCode of an object +#' +#' Java-style function to compute the hashCode for the given object. Returns +#' an integer value. +#' +#' @details +#' This only works for integer, numeric and character types right now. +#' +#' @param key the object to be hashed +#' @return the hash code as an integer +#' @export +#' @examples +#' hashCode(1L) # 1 +#' hashCode(1.0) # 1072693248 +#' hashCode("1") # 49 +hashCode <- function(key) { + if (class(key) == "integer") { + as.integer(key[[1]]) + } else if (class(key) == "numeric") { + # Convert the double to long and then calculate the hash code + rawVec <- writeBin(key[[1]], con = raw()) + intBits <- packBits(rawToBits(rawVec), "integer") + as.integer(bitwXor(intBits[2], intBits[1])) + } else if (class(key) == "character") { + .Call("stringHashCode", key) + } else { + warning(paste("Could not hash object, returning 0", sep = "")) + as.integer(0) + } +} + +# Create a new RDD with serializedMode == "byte". +# Return itself if already in "byte" format. +serializeToBytes <- function(rdd) { + if (!inherits(rdd, "RDD")) { + stop("Argument 'rdd' is not an RDD type.") + } + if (getSerializedMode(rdd) != "byte") { + ser.rdd <- lapply(rdd, function(x) { x }) + return(ser.rdd) + } else { + return(rdd) + } +} + +# Create a new RDD with serializedMode == "string". +# Return itself if already in "string" format. +serializeToString <- function(rdd) { + if (!inherits(rdd, "RDD")) { + stop("Argument 'rdd' is not an RDD type.") + } + if (getSerializedMode(rdd) != "string") { + ser.rdd <- lapply(rdd, function(x) { toString(x) }) + # force it to create jrdd using "string" + getJRDD(ser.rdd, serializedMode = "string") + return(ser.rdd) + } else { + return(rdd) + } +} + +# Fast append to list by using an accumulator. +# http://stackoverflow.com/questions/17046336/here-we-go-again-append-an-element-to-a-list-in-r +# +# The accumulator should has three fields size, counter and data. +# This function amortizes the allocation cost by doubling +# the size of the list every time it fills up. +addItemToAccumulator <- function(acc, item) { + if(acc$counter == acc$size) { + acc$size <- acc$size * 2 + length(acc$data) <- acc$size + } + acc$counter <- acc$counter + 1 + acc$data[[acc$counter]] <- item +} + +initAccumulator <- function() { + acc <- new.env() + acc$counter <- 0 + acc$data <- list(NULL) + acc$size <- 1 + acc +} + +# Utility function to sort a list of key value pairs +# Used in unit tests +sortKeyValueList <- function(kv_list, decreasing = FALSE) { + keys <- sapply(kv_list, function(x) x[[1]]) + kv_list[order(keys, decreasing = decreasing)] +} + +# Utility function to generate compact R lists from grouped rdd +# Used in Join-family functions +# param: +# tagged_list R list generated via groupByKey with tags(1L, 2L, ...) +# cnull Boolean list where each element determines whether the corresponding list should +# be converted to list(NULL) +genCompactLists <- function(tagged_list, cnull) { + len <- length(tagged_list) + lists <- list(vector("list", len), vector("list", len)) + index <- list(1, 1) + + for (x in tagged_list) { + tag <- x[[1]] + idx <- index[[tag]] + lists[[tag]][[idx]] <- x[[2]] + index[[tag]] <- idx + 1 + } + + len <- lapply(index, function(x) x - 1) + for (i in (1:2)) { + if (cnull[[i]] && len[[i]] == 0) { + lists[[i]] <- list(NULL) + } else { + length(lists[[i]]) <- len[[i]] + } + } + + lists +} + +# Utility function to merge compact R lists +# Used in Join-family functions +# param: +# left/right Two compact lists ready for Cartesian product +mergeCompactLists <- function(left, right) { + result <- list() + length(result) <- length(left) * length(right) + index <- 1 + for (i in left) { + for (j in right) { + result[[index]] <- list(i, j) + index <- index + 1 + } + } + result +} + +# Utility function to wrapper above two operations +# Used in Join-family functions +# param (same as genCompactLists): +# tagged_list R list generated via groupByKey with tags(1L, 2L, ...) +# cnull Boolean list where each element determines whether the corresponding list should +# be converted to list(NULL) +joinTaggedList <- function(tagged_list, cnull) { + lists <- genCompactLists(tagged_list, cnull) + mergeCompactLists(lists[[1]], lists[[2]]) +} + +# Utility function to reduce a key-value list with predicate +# Used in *ByKey functions +# param +# pair key-value pair +# keys/vals env of key/value with hashes +# updateOrCreatePred predicate function +# updateFn update or merge function for existing pair, similar with `mergeVal` @combineByKey +# createFn create function for new pair, similar with `createCombiner` @combinebykey +updateOrCreatePair <- function(pair, keys, vals, updateOrCreatePred, updateFn, createFn) { + # assume hashVal bind to `$hash`, key/val with index 1/2 + hashVal <- pair$hash + key <- pair[[1]] + val <- pair[[2]] + if (updateOrCreatePred(pair)) { + assign(hashVal, do.call(updateFn, list(get(hashVal, envir = vals), val)), envir = vals) + } else { + assign(hashVal, do.call(createFn, list(val)), envir = vals) + assign(hashVal, key, envir = keys) + } +} + +# Utility function to convert key&values envs into key-val list +convertEnvsToList <- function(keys, vals) { + lapply(ls(keys), + function(name) { + list(keys[[name]], vals[[name]]) + }) +} + +# Utility function to capture the varargs into environment object +varargsToEnv <- function(...) { + pairs <- as.list(substitute(list(...)))[-1L] + env <- new.env() + for (name in names(pairs)) { + env[[name]] <- pairs[[name]] + } + env +} + +getStorageLevel <- function(newLevel = c("DISK_ONLY", + "DISK_ONLY_2", + "MEMORY_AND_DISK", + "MEMORY_AND_DISK_2", + "MEMORY_AND_DISK_SER", + "MEMORY_AND_DISK_SER_2", + "MEMORY_ONLY", + "MEMORY_ONLY_2", + "MEMORY_ONLY_SER", + "MEMORY_ONLY_SER_2", + "OFF_HEAP")) { + match.arg(newLevel) + storageLevel <- switch(newLevel, + "DISK_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY"), + "DISK_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY_2"), + "MEMORY_AND_DISK" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK"), + "MEMORY_AND_DISK_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_2"), + "MEMORY_AND_DISK_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER"), + "MEMORY_AND_DISK_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER_2"), + "MEMORY_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY"), + "MEMORY_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_2"), + "MEMORY_ONLY_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER"), + "MEMORY_ONLY_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER_2"), + "OFF_HEAP" = callJStatic("org.apache.spark.storage.StorageLevel", "OFF_HEAP")) +} + +# Utility function for functions where an argument needs to be integer but we want to allow +# the user to type (for example) `5` instead of `5L` to avoid a confusing error message. +numToInt <- function(num) { + if (as.integer(num) != num) { + warning(paste("Coercing", as.list(sys.call())[[2]], "to integer.")) + } + as.integer(num) +} + +# create a Seq in JVM +toSeq <- function(...) { + callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", list(...)) +} + +# create a Seq in JVM from a list +listToSeq <- function(l) { + callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", l) +} + +# Utility function to recursively traverse the Abstract Syntax Tree (AST) of a +# user defined function (UDF), and to examine variables in the UDF to decide +# if their values should be included in the new function environment. +# param +# node The current AST node in the traversal. +# oldEnv The original function environment. +# defVars An Accumulator of variables names defined in the function's calling environment, +# including function argument and local variable names. +# checkedFunc An environment of function objects examined during cleanClosure. It can +# be considered as a "name"-to-"list of functions" mapping. +# newEnv A new function environment to store necessary function dependencies, an output argument. +processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { + nodeLen <- length(node) + + if (nodeLen > 1 && typeof(node) == "language") { + # Recursive case: current AST node is an internal node, check for its children. + if (length(node[[1]]) > 1) { + for (i in 1:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else { # if node[[1]] is length of 1, check for some R special functions. + nodeChar <- as.character(node[[1]]) + if (nodeChar == "{" || nodeChar == "(") { # Skip start symbol. + for (i in 2:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else if (nodeChar == "<-" || nodeChar == "=" || + nodeChar == "<<-") { # Assignment Ops. + defVar <- node[[2]] + if (length(defVar) == 1 && typeof(defVar) == "symbol") { + # Add the defined variable name into defVars. + addItemToAccumulator(defVars, as.character(defVar)) + } else { + processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv) + } + for (i in 3:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else if (nodeChar == "function") { # Function definition. + # Add parameter names. + newArgs <- names(node[[2]]) + lapply(newArgs, function(arg) { addItemToAccumulator(defVars, arg) }) + for (i in 3:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else if (nodeChar == "$") { # Skip the field. + processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv) + } else if (nodeChar == "::" || nodeChar == ":::") { + processClosure(node[[3]], oldEnv, defVars, checkedFuncs, newEnv) + } else { + for (i in 1:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } + } + } else if (nodeLen == 1 && + (typeof(node) == "symbol" || typeof(node) == "language")) { + # Base case: current AST node is a leaf node and a symbol or a function call. + nodeChar <- as.character(node) + if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable. + func.env <- oldEnv + topEnv <- parent.env(.GlobalEnv) + # Search in function environment, and function's enclosing environments + # up to global environment. There is no need to look into package environments + # above the global or namespace environment that is not SparkR below the global, + # as they are assumed to be loaded on workers. + while (!identical(func.env, topEnv)) { + # Namespaces other than "SparkR" will not be searched. + if (!isNamespace(func.env) || + (getNamespaceName(func.env) == "SparkR" && + !(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals. + # Set parameter 'inherits' to FALSE since we do not need to search in + # attached package environments. + if (tryCatch(exists(nodeChar, envir = func.env, inherits = FALSE), + error = function(e) { FALSE })) { + obj <- get(nodeChar, envir = func.env, inherits = FALSE) + if (is.function(obj)) { # If the node is a function call. + funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F, + ifnotfound = list(list(NULL)))[[1]] + found <- sapply(funcList, function(func) { + ifelse(identical(func, obj), TRUE, FALSE) + }) + if (sum(found) > 0) { # If function has been examined, ignore. + break + } + # Function has not been examined, record it and recursively clean its closure. + assign(nodeChar, + if (is.null(funcList[[1]])) { + list(obj) + } else { + append(funcList, obj) + }, + envir = checkedFuncs) + obj <- cleanClosure(obj, checkedFuncs) + } + assign(nodeChar, obj, envir = newEnv) + break + } + } + + # Continue to search in enclosure. + func.env <- parent.env(func.env) + } + } + } +} + +# Utility function to get user defined function (UDF) dependencies (closure). +# More specifically, this function captures the values of free variables defined +# outside a UDF, and stores them in the function's environment. +# param +# func A function whose closure needs to be captured. +# checkedFunc An environment of function objects examined during cleanClosure. It can be +# considered as a "name"-to-"list of functions" mapping. +# return value +# a new version of func that has an correct environment (closure). +cleanClosure <- function(func, checkedFuncs = new.env()) { + if (is.function(func)) { + newEnv <- new.env(parent = .GlobalEnv) + func.body <- body(func) + oldEnv <- environment(func) + # defVars is an Accumulator of variables names defined in the function's calling + # environment. First, function's arguments are added to defVars. + defVars <- initAccumulator() + argNames <- names(as.list(args(func))) + for (i in 1:(length(argNames) - 1)) { # Remove the ending NULL in pairlist. + addItemToAccumulator(defVars, argNames[i]) + } + # Recursively examine variables in the function body. + processClosure(func.body, oldEnv, defVars, checkedFuncs, newEnv) + environment(func) <- newEnv + } + func +} diff --git a/R/pkg/R/zzz.R b/R/pkg/R/zzz.R new file mode 100644 index 0000000000000..80d796d467943 --- /dev/null +++ b/R/pkg/R/zzz.R @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.onLoad <- function(libname, pkgname) { + sparkR.onLoad(libname, pkgname) +} + diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R new file mode 100644 index 0000000000000..8fe711b622086 --- /dev/null +++ b/R/pkg/inst/profile/general.R @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.First <- function() { + home <- Sys.getenv("SPARK_HOME") + .libPaths(c(file.path(home, "R", "lib"), .libPaths())) + Sys.setenv(NOAWT=1) +} diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R new file mode 100644 index 0000000000000..7a7f2031152a0 --- /dev/null +++ b/R/pkg/inst/profile/shell.R @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.First <- function() { + home <- Sys.getenv("SPARK_HOME") + .libPaths(c(file.path(home, "R", "lib"), .libPaths())) + Sys.setenv(NOAWT=1) + + library(utils) + library(SparkR) + sc <- sparkR.init(Sys.getenv("MASTER", unset = "")) + assign("sc", sc, envir=.GlobalEnv) + sqlCtx <- sparkRSQL.init(sc) + assign("sqlCtx", sqlCtx, envir=.GlobalEnv) + cat("\n Welcome to SparkR!") + cat("\n Spark context is available as sc, SQL context is available as sqlCtx\n") +} diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R new file mode 100644 index 0000000000000..ca4218f3819f8 --- /dev/null +++ b/R/pkg/inst/tests/test_binaryFile.R @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions on binary files") + +# JavaSparkContext handle +sc <- sparkR.init() + +mockFile = c("Spark is pretty.", "Spark is awesome.") + +test_that("saveAsObjectFile()/objectFile() following textFile() works", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1, 1) + saveAsObjectFile(rdd, fileName2) + rdd <- objectFile(sc, fileName2) + expect_equal(collect(rdd), as.list(mockFile)) + + unlink(fileName1) + unlink(fileName2, recursive = TRUE) +}) + +test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + + l <- list(1, 2, 3) + rdd <- parallelize(sc, l, 1) + saveAsObjectFile(rdd, fileName) + rdd <- objectFile(sc, fileName) + expect_equal(collect(rdd), l) + + unlink(fileName, recursive = TRUE) +}) + +test_that("saveAsObjectFile()/objectFile() following RDD transformations works", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1) + + words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + counts <- reduceByKey(wordCount, "+", 2L) + + saveAsObjectFile(counts, fileName2) + counts <- objectFile(sc, fileName2) + + output <- collect(counts) + expected <- list(list("awesome.", 1), list("Spark", 2), list("pretty.", 1), + list("is", 2)) + expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) + + unlink(fileName1) + unlink(fileName2, recursive = TRUE) +}) + +test_that("saveAsObjectFile()/objectFile() works with multiple paths", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + + rdd1 <- parallelize(sc, "Spark is pretty.") + saveAsObjectFile(rdd1, fileName1) + rdd2 <- parallelize(sc, "Spark is awesome.") + saveAsObjectFile(rdd2, fileName2) + + rdd <- objectFile(sc, c(fileName1, fileName2)) + expect_true(count(rdd) == 2) + + unlink(fileName1, recursive = TRUE) + unlink(fileName2, recursive = TRUE) +}) + diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R new file mode 100644 index 0000000000000..c15553ba28517 --- /dev/null +++ b/R/pkg/inst/tests/test_binary_function.R @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("binary functions") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Data +nums <- 1:10 +rdd <- parallelize(sc, nums, 2L) + +# File content +mockFile <- c("Spark is pretty.", "Spark is awesome.") + +test_that("union on two RDDs", { + actual <- collect(unionRDD(rdd, rdd)) + expect_equal(actual, as.list(rep(nums, 2))) + + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + text.rdd <- textFile(sc, fileName) + union.rdd <- unionRDD(rdd, text.rdd) + actual <- collect(union.rdd) + expect_equal(actual, c(as.list(nums), mockFile)) + expect_true(getSerializedMode(union.rdd) == "byte") + + rdd<- map(text.rdd, function(x) {x}) + union.rdd <- unionRDD(rdd, text.rdd) + actual <- collect(union.rdd) + expect_equal(actual, as.list(c(mockFile, mockFile))) + expect_true(getSerializedMode(union.rdd) == "byte") + + unlink(fileName) +}) + +test_that("cogroup on two RDDs", { + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) + rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) + cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) + actual <- collect(cogroup.rdd) + expect_equal(actual, + list(list(1, list(list(1), list(2, 3))), list(2, list(list(4), list())))) + + rdd1 <- parallelize(sc, list(list("a", 1), list("a", 4))) + rdd2 <- parallelize(sc, list(list("b", 2), list("a", 3))) + cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) + actual <- collect(cogroup.rdd) + + expected <- list(list("b", list(list(), list(2))), list("a", list(list(1, 4), list(3)))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) +}) diff --git a/R/pkg/inst/tests/test_broadcast.R b/R/pkg/inst/tests/test_broadcast.R new file mode 100644 index 0000000000000..fee91a427d6d5 --- /dev/null +++ b/R/pkg/inst/tests/test_broadcast.R @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("broadcast variables") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Partitioned data +nums <- 1:2 +rrdd <- parallelize(sc, nums, 2L) + +test_that("using broadcast variable", { + randomMat <- matrix(nrow=10, ncol=10, data=rnorm(100)) + randomMatBr <- broadcast(sc, randomMat) + + useBroadcast <- function(x) { + sum(value(randomMatBr) * x) + } + actual <- collect(lapply(rrdd, useBroadcast)) + expected <- list(sum(randomMat) * 1, sum(randomMat) * 2) + expect_equal(actual, expected) +}) + +test_that("without using broadcast variable", { + randomMat <- matrix(nrow=10, ncol=10, data=rnorm(100)) + + useBroadcast <- function(x) { + sum(randomMat * x) + } + actual <- collect(lapply(rrdd, useBroadcast)) + expected <- list(sum(randomMat) * 1, sum(randomMat) * 2) + expect_equal(actual, expected) +}) diff --git a/R/pkg/inst/tests/test_context.R b/R/pkg/inst/tests/test_context.R new file mode 100644 index 0000000000000..e4aab37436a74 --- /dev/null +++ b/R/pkg/inst/tests/test_context.R @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("test functions in sparkR.R") + +test_that("repeatedly starting and stopping SparkR", { + for (i in 1:4) { + sc <- sparkR.init() + rdd <- parallelize(sc, 1:20, 2L) + expect_equal(count(rdd), 20) + sparkR.stop() + } +}) + +test_that("rdd GC across sparkR.stop", { + sparkR.stop() + sc <- sparkR.init() # sc should get id 0 + rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 + rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 + sparkR.stop() + + sc <- sparkR.init() # sc should get id 0 again + + # GC rdd1 before creating rdd3 and rdd2 after + rm(rdd1) + gc() + + rdd3 <- parallelize(sc, 1:20, 2L) # rdd3 should get id 1 now + rdd4 <- parallelize(sc, 1:10, 2L) # rdd4 should get id 2 now + + rm(rdd2) + gc() + + count(rdd3) + count(rdd4) +}) diff --git a/R/pkg/inst/tests/test_includePackage.R b/R/pkg/inst/tests/test_includePackage.R new file mode 100644 index 0000000000000..8152b448d0870 --- /dev/null +++ b/R/pkg/inst/tests/test_includePackage.R @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("include R packages") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Partitioned data +nums <- 1:2 +rdd <- parallelize(sc, nums, 2L) + +test_that("include inside function", { + # Only run the test if plyr is installed. + if ("plyr" %in% rownames(installed.packages())) { + suppressPackageStartupMessages(library(plyr)) + generateData <- function(x) { + suppressPackageStartupMessages(library(plyr)) + attach(airquality) + result <- transform(Ozone, logOzone = log(Ozone)) + result + } + + data <- lapplyPartition(rdd, generateData) + actual <- collect(data) + } +}) + +test_that("use include package", { + # Only run the test if plyr is installed. + if ("plyr" %in% rownames(installed.packages())) { + suppressPackageStartupMessages(library(plyr)) + generateData <- function(x) { + attach(airquality) + result <- transform(Ozone, logOzone = log(Ozone)) + result + } + + includePackage(sc, plyr) + data <- lapplyPartition(rdd, generateData) + actual <- collect(data) + } +}) diff --git a/R/pkg/inst/tests/test_parallelize_collect.R b/R/pkg/inst/tests/test_parallelize_collect.R new file mode 100644 index 0000000000000..fff028657db37 --- /dev/null +++ b/R/pkg/inst/tests/test_parallelize_collect.R @@ -0,0 +1,109 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("parallelize() and collect()") + +# Mock data +numVector <- c(-10:97) +numList <- list(sqrt(1), sqrt(2), sqrt(3), 4 ** 10) +strVector <- c("Dexter Morgan: I suppose I should be upset, even feel", + "violated, but I'm not. No, in fact, I think this is a friendly", + "message, like \"Hey, wanna play?\" and yes, I want to play. ", + "I really, really do.") +strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ", + "other times it helps me control the chaos.", + "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ", + "raising me. But they're both dead now. I didn't kill them. Honest.") + +numPairs <- list(list(1, 1), list(1, 2), list(2, 2), list(2, 3)) +strPairs <- list(list(strList, strList), list(strList, strList)) + +# JavaSparkContext handle +jsc <- sparkR.init() + +# Tests + +test_that("parallelize() on simple vectors and lists returns an RDD", { + numVectorRDD <- parallelize(jsc, numVector, 1) + numVectorRDD2 <- parallelize(jsc, numVector, 10) + numListRDD <- parallelize(jsc, numList, 1) + numListRDD2 <- parallelize(jsc, numList, 4) + strVectorRDD <- parallelize(jsc, strVector, 2) + strVectorRDD2 <- parallelize(jsc, strVector, 3) + strListRDD <- parallelize(jsc, strList, 4) + strListRDD2 <- parallelize(jsc, strList, 1) + + rdds <- c(numVectorRDD, + numVectorRDD2, + numListRDD, + numListRDD2, + strVectorRDD, + strVectorRDD2, + strListRDD, + strListRDD2) + + for (rdd in rdds) { + expect_true(inherits(rdd, "RDD")) + expect_true(.hasSlot(rdd, "jrdd") + && inherits(rdd@jrdd, "jobj") + && isInstanceOf(rdd@jrdd, "org.apache.spark.api.java.JavaRDD")) + } +}) + +test_that("collect(), following a parallelize(), gives back the original collections", { + numVectorRDD <- parallelize(jsc, numVector, 10) + expect_equal(collect(numVectorRDD), as.list(numVector)) + + numListRDD <- parallelize(jsc, numList, 1) + numListRDD2 <- parallelize(jsc, numList, 4) + expect_equal(collect(numListRDD), as.list(numList)) + expect_equal(collect(numListRDD2), as.list(numList)) + + strVectorRDD <- parallelize(jsc, strVector, 2) + strVectorRDD2 <- parallelize(jsc, strVector, 3) + expect_equal(collect(strVectorRDD), as.list(strVector)) + expect_equal(collect(strVectorRDD2), as.list(strVector)) + + strListRDD <- parallelize(jsc, strList, 4) + strListRDD2 <- parallelize(jsc, strList, 1) + expect_equal(collect(strListRDD), as.list(strList)) + expect_equal(collect(strListRDD2), as.list(strList)) +}) + +test_that("regression: collect() following a parallelize() does not drop elements", { + # 10 %/% 6 = 1, ceiling(10 / 6) = 2 + collLen <- 10 + numPart <- 6 + expected <- runif(collLen) + actual <- collect(parallelize(jsc, expected, numPart)) + expect_equal(actual, as.list(expected)) +}) + +test_that("parallelize() and collect() work for lists of pairs (pairwise data)", { + # use the pairwise logical to indicate pairwise data + numPairsRDDD1 <- parallelize(jsc, numPairs, 1) + numPairsRDDD2 <- parallelize(jsc, numPairs, 2) + numPairsRDDD3 <- parallelize(jsc, numPairs, 3) + expect_equal(collect(numPairsRDDD1), numPairs) + expect_equal(collect(numPairsRDDD2), numPairs) + expect_equal(collect(numPairsRDDD3), numPairs) + # can also leave out the parameter name, if the params are supplied in order + strPairsRDDD1 <- parallelize(jsc, strPairs, 1) + strPairsRDDD2 <- parallelize(jsc, strPairs, 2) + expect_equal(collect(strPairsRDDD1), strPairs) + expect_equal(collect(strPairsRDDD2), strPairs) +}) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R new file mode 100644 index 0000000000000..b76e4db03e715 --- /dev/null +++ b/R/pkg/inst/tests/test_rdd.R @@ -0,0 +1,645 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("basic RDD functions") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Data +nums <- 1:10 +rdd <- parallelize(sc, nums, 2L) + +intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) +intRdd <- parallelize(sc, intPairs, 2L) + +test_that("get number of partitions in RDD", { + expect_equal(numPartitions(rdd), 2) + expect_equal(numPartitions(intRdd), 2) +}) + +test_that("first on RDD", { + expect_true(first(rdd) == 1) + newrdd <- lapply(rdd, function(x) x + 1) + expect_true(first(newrdd) == 2) +}) + +test_that("count and length on RDD", { + expect_equal(count(rdd), 10) + expect_equal(length(rdd), 10) +}) + +test_that("count by values and keys", { + mods <- lapply(rdd, function(x) { x %% 3 }) + actual <- countByValue(mods) + expected <- list(list(0, 3L), list(1, 4L), list(2, 3L)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + actual <- countByKey(intRdd) + expected <- list(list(2L, 2L), list(1L, 2L)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("lapply on RDD", { + multiples <- lapply(rdd, function(x) { 2 * x }) + actual <- collect(multiples) + expect_equal(actual, as.list(nums * 2)) +}) + +test_that("lapplyPartition on RDD", { + sums <- lapplyPartition(rdd, function(part) { sum(unlist(part)) }) + actual <- collect(sums) + expect_equal(actual, list(15, 40)) +}) + +test_that("mapPartitions on RDD", { + sums <- mapPartitions(rdd, function(part) { sum(unlist(part)) }) + actual <- collect(sums) + expect_equal(actual, list(15, 40)) +}) + +test_that("flatMap() on RDDs", { + flat <- flatMap(intRdd, function(x) { list(x, x) }) + actual <- collect(flat) + expect_equal(actual, rep(intPairs, each=2)) +}) + +test_that("filterRDD on RDD", { + filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 }) + actual <- collect(filtered.rdd) + expect_equal(actual, list(2, 4, 6, 8, 10)) + + filtered.rdd <- Filter(function(x) { x[[2]] < 0 }, intRdd) + actual <- collect(filtered.rdd) + expect_equal(actual, list(list(1L, -1))) + + # Filter out all elements. + filtered.rdd <- filterRDD(rdd, function(x) { x > 10 }) + actual <- collect(filtered.rdd) + expect_equal(actual, list()) +}) + +test_that("lookup on RDD", { + vals <- lookup(intRdd, 1L) + expect_equal(vals, list(-1, 200)) + + vals <- lookup(intRdd, 3L) + expect_equal(vals, list()) +}) + +test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { + rdd2 <- rdd + for (i in 1:12) + rdd2 <- lapplyPartitionsWithIndex( + rdd2, function(split, part) { + part <- as.list(unlist(part) * split + i) + }) + rdd2 <- lapply(rdd2, function(x) x + x) + actual <- collect(rdd2) + expected <- list(24, 24, 24, 24, 24, + 168, 170, 172, 174, 176) + expect_equal(actual, expected) +}) + +test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkpoint()", { + # RDD + rdd2 <- rdd + # PipelinedRDD + rdd2 <- lapplyPartitionsWithIndex( + rdd2, + function(split, part) { + part <- as.list(unlist(part) * split) + }) + + cache(rdd2) + expect_true(rdd2@env$isCached) + rdd2 <- lapply(rdd2, function(x) x) + expect_false(rdd2@env$isCached) + + unpersist(rdd2) + expect_false(rdd2@env$isCached) + + persist(rdd2, "MEMORY_AND_DISK") + expect_true(rdd2@env$isCached) + rdd2 <- lapply(rdd2, function(x) x) + expect_false(rdd2@env$isCached) + + unpersist(rdd2) + expect_false(rdd2@env$isCached) + + tempDir <- tempfile(pattern = "checkpoint") + setCheckpointDir(sc, tempDir) + checkpoint(rdd2) + expect_true(rdd2@env$isCheckpointed) + + rdd2 <- lapply(rdd2, function(x) x) + expect_false(rdd2@env$isCached) + expect_false(rdd2@env$isCheckpointed) + + # make sure the data is collectable + collect(rdd2) + + unlink(tempDir) +}) + +test_that("reduce on RDD", { + sum <- reduce(rdd, "+") + expect_equal(sum, 55) + + # Also test with an inline function + sumInline <- reduce(rdd, function(x, y) { x + y }) + expect_equal(sumInline, 55) +}) + +test_that("lapply with dependency", { + fa <- 5 + multiples <- lapply(rdd, function(x) { fa * x }) + actual <- collect(multiples) + + expect_equal(actual, as.list(nums * 5)) +}) + +test_that("lapplyPartitionsWithIndex on RDDs", { + func <- function(splitIndex, part) { list(splitIndex, Reduce("+", part)) } + actual <- collect(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE) + expect_equal(actual, list(list(0, 15), list(1, 40))) + + pairsRDD <- parallelize(sc, list(list(1, 2), list(3, 4), list(4, 8)), 1L) + partitionByParity <- function(key) { if (key %% 2 == 1) 0 else 1 } + mkTup <- function(splitIndex, part) { list(splitIndex, part) } + actual <- collect(lapplyPartitionsWithIndex( + partitionBy(pairsRDD, 2L, partitionByParity), + mkTup), + FALSE) + expect_equal(actual, list(list(0, list(list(1, 2), list(3, 4))), + list(1, list(list(4, 8))))) +}) + +test_that("sampleRDD() on RDDs", { + expect_equal(unlist(collect(sampleRDD(rdd, FALSE, 1.0, 2014L))), nums) +}) + +test_that("takeSample() on RDDs", { + # ported from RDDSuite.scala, modified seeds + data <- parallelize(sc, 1:100, 2L) + for (seed in 4:5) { + s <- takeSample(data, FALSE, 20L, seed) + expect_equal(length(s), 20L) + expect_equal(length(unique(s)), 20L) + for (elem in s) { + expect_true(elem >= 1 && elem <= 100) + } + } + for (seed in 4:5) { + s <- takeSample(data, FALSE, 200L, seed) + expect_equal(length(s), 100L) + expect_equal(length(unique(s)), 100L) + for (elem in s) { + expect_true(elem >= 1 && elem <= 100) + } + } + for (seed in 4:5) { + s <- takeSample(data, TRUE, 20L, seed) + expect_equal(length(s), 20L) + for (elem in s) { + expect_true(elem >= 1 && elem <= 100) + } + } + for (seed in 4:5) { + s <- takeSample(data, TRUE, 100L, seed) + expect_equal(length(s), 100L) + # Chance of getting all distinct elements is astronomically low, so test we + # got < 100 + expect_true(length(unique(s)) < 100L) + } + for (seed in 4:5) { + s <- takeSample(data, TRUE, 200L, seed) + expect_equal(length(s), 200L) + # Chance of getting all distinct elements is still quite low, so test we + # got < 100 + expect_true(length(unique(s)) < 100L) + } +}) + +test_that("mapValues() on pairwise RDDs", { + multiples <- mapValues(intRdd, function(x) { x * 2 }) + actual <- collect(multiples) + expected <- lapply(intPairs, function(x) { + list(x[[1]], x[[2]] * 2) + }) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("flatMapValues() on pairwise RDDs", { + l <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) + actual <- collect(flatMapValues(l, function(x) { x })) + expect_equal(actual, list(list(1,1), list(1,2), list(2,3), list(2,4))) + + # Generate x to x+1 for every value + actual <- collect(flatMapValues(intRdd, function(x) { x:(x + 1) })) + expect_equal(actual, + list(list(1L, -1), list(1L, 0), list(2L, 100), list(2L, 101), + list(2L, 1), list(2L, 2), list(1L, 200), list(1L, 201))) +}) + +test_that("reduceByKeyLocally() on PairwiseRDDs", { + pairs <- parallelize(sc, list(list(1, 2), list(1.1, 3), list(1, 4)), 2L) + actual <- reduceByKeyLocally(pairs, "+") + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(1, 6), list(1.1, 3)))) + + pairs <- parallelize(sc, list(list("abc", 1.2), list(1.1, 0), list("abc", 1.3), + list("bb", 5)), 4L) + actual <- reduceByKeyLocally(pairs, "+") + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("abc", 2.5), list(1.1, 0), list("bb", 5)))) +}) + +test_that("distinct() on RDDs", { + nums.rep2 <- rep(1:10, 2) + rdd.rep2 <- parallelize(sc, nums.rep2, 2L) + uniques <- distinct(rdd.rep2) + actual <- sort(unlist(collect(uniques))) + expect_equal(actual, nums) +}) + +test_that("maximum() on RDDs", { + max <- maximum(rdd) + expect_equal(max, 10) +}) + +test_that("minimum() on RDDs", { + min <- minimum(rdd) + expect_equal(min, 1) +}) + +test_that("sumRDD() on RDDs", { + sum <- sumRDD(rdd) + expect_equal(sum, 55) +}) + +test_that("keyBy on RDDs", { + func <- function(x) { x*x } + keys <- keyBy(rdd, func) + actual <- collect(keys) + expect_equal(actual, lapply(nums, function(x) { list(func(x), x) })) +}) + +test_that("repartition/coalesce on RDDs", { + rdd <- parallelize(sc, 1:20, 4L) # each partition contains 5 elements + + # repartition + r1 <- repartition(rdd, 2) + expect_equal(numPartitions(r1), 2L) + count <- length(collectPartition(r1, 0L)) + expect_true(count >= 8 && count <= 12) + + r2 <- repartition(rdd, 6) + expect_equal(numPartitions(r2), 6L) + count <- length(collectPartition(r2, 0L)) + expect_true(count >=0 && count <= 4) + + # coalesce + r3 <- coalesce(rdd, 1) + expect_equal(numPartitions(r3), 1L) + count <- length(collectPartition(r3, 0L)) + expect_equal(count, 20) +}) + +test_that("sortBy() on RDDs", { + sortedRdd <- sortBy(rdd, function(x) { x * x }, ascending = FALSE) + actual <- collect(sortedRdd) + expect_equal(actual, as.list(sort(nums, decreasing = TRUE))) + + rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L) + sortedRdd2 <- sortBy(rdd2, function(x) { x * x }) + actual <- collect(sortedRdd2) + expect_equal(actual, as.list(nums)) +}) + +test_that("takeOrdered() on RDDs", { + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) + rdd <- parallelize(sc, l) + actual <- takeOrdered(rdd, 6L) + expect_equal(actual, as.list(sort(unlist(l)))[1:6]) + + l <- list("e", "d", "c", "d", "a") + rdd <- parallelize(sc, l) + actual <- takeOrdered(rdd, 3L) + expect_equal(actual, as.list(sort(unlist(l)))[1:3]) +}) + +test_that("top() on RDDs", { + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) + rdd <- parallelize(sc, l) + actual <- top(rdd, 6L) + expect_equal(actual, as.list(sort(unlist(l), decreasing = TRUE))[1:6]) + + l <- list("e", "d", "c", "d", "a") + rdd <- parallelize(sc, l) + actual <- top(rdd, 3L) + expect_equal(actual, as.list(sort(unlist(l), decreasing = TRUE))[1:3]) +}) + +test_that("fold() on RDDs", { + actual <- fold(rdd, 0, "+") + expect_equal(actual, Reduce("+", nums, 0)) + + rdd <- parallelize(sc, list()) + actual <- fold(rdd, 0, "+") + expect_equal(actual, 0) +}) + +test_that("aggregateRDD() on RDDs", { + rdd <- parallelize(sc, list(1, 2, 3, 4)) + zeroValue <- list(0, 0) + seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } + combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } + actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp) + expect_equal(actual, list(10, 4)) + + rdd <- parallelize(sc, list()) + actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp) + expect_equal(actual, list(0, 0)) +}) + +test_that("zipWithUniqueId() on RDDs", { + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) + actual <- collect(zipWithUniqueId(rdd)) + expected <- list(list("a", 0), list("b", 3), list("c", 1), + list("d", 4), list("e", 2)) + expect_equal(actual, expected) + + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) + actual <- collect(zipWithUniqueId(rdd)) + expected <- list(list("a", 0), list("b", 1), list("c", 2), + list("d", 3), list("e", 4)) + expect_equal(actual, expected) +}) + +test_that("zipWithIndex() on RDDs", { + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) + actual <- collect(zipWithIndex(rdd)) + expected <- list(list("a", 0), list("b", 1), list("c", 2), + list("d", 3), list("e", 4)) + expect_equal(actual, expected) + + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) + actual <- collect(zipWithIndex(rdd)) + expected <- list(list("a", 0), list("b", 1), list("c", 2), + list("d", 3), list("e", 4)) + expect_equal(actual, expected) +}) + +test_that("glom() on RDD", { + rdd <- parallelize(sc, as.list(1:4), 2L) + actual <- collect(glom(rdd)) + expect_equal(actual, list(list(1, 2), list(3, 4))) +}) + +test_that("keys() on RDDs", { + keys <- keys(intRdd) + actual <- collect(keys) + expect_equal(actual, lapply(intPairs, function(x) { x[[1]] })) +}) + +test_that("values() on RDDs", { + values <- values(intRdd) + actual <- collect(values) + expect_equal(actual, lapply(intPairs, function(x) { x[[2]] })) +}) + +test_that("pipeRDD() on RDDs", { + actual <- collect(pipeRDD(rdd, "more")) + expected <- as.list(as.character(1:10)) + expect_equal(actual, expected) + + trailed.rdd <- parallelize(sc, c("1", "", "2\n", "3\n\r\n")) + actual <- collect(pipeRDD(trailed.rdd, "sort")) + expected <- list("", "1", "2", "3") + expect_equal(actual, expected) + + rev.nums <- 9:0 + rev.rdd <- parallelize(sc, rev.nums, 2L) + actual <- collect(pipeRDD(rev.rdd, "sort")) + expected <- as.list(as.character(c(5:9, 0:4))) + expect_equal(actual, expected) +}) + +test_that("zipRDD() on RDDs", { + rdd1 <- parallelize(sc, 0:4, 2) + rdd2 <- parallelize(sc, 1000:1004, 2) + actual <- collect(zipRDD(rdd1, rdd2)) + expect_equal(actual, + list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004))) + + mockFile = c("Spark is pretty.", "Spark is awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName, 1) + actual <- collect(zipRDD(rdd, rdd)) + expected <- lapply(mockFile, function(x) { list(x ,x) }) + expect_equal(actual, expected) + + rdd1 <- parallelize(sc, 0:1, 1) + actual <- collect(zipRDD(rdd1, rdd)) + expected <- lapply(0:1, function(x) { list(x, mockFile[x + 1]) }) + expect_equal(actual, expected) + + rdd1 <- map(rdd, function(x) { x }) + actual <- collect(zipRDD(rdd, rdd1)) + expected <- lapply(mockFile, function(x) { list(x, x) }) + expect_equal(actual, expected) + + unlink(fileName) +}) + +test_that("join() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,1), list(2,4))) + rdd2 <- parallelize(sc, list(list(1,2), list(1,3))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(1, list(1, 2)), list(1, list(1, 3))))) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",4))) + rdd2 <- parallelize(sc, list(list("a",2), list("a",3))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("a", list(1, 2)), list("a", list(1, 3))))) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(actual, list()) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(actual, list()) +}) + +test_that("leftOuterJoin() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,1), list(2,4))) + rdd2 <- parallelize(sc, list(list(1,2), list(1,3))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",4))) + rdd2 <- parallelize(sc, list(list("a",2), list("a",3))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(4, NULL)), list("a", list(1, 2)), list("a", list(1, 3))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(1, NULL)), list(2, list(2, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(2, NULL)), list("a", list(1, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) +}) + +test_that("rightOuterJoin() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,2), list(1,3))) + rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",2), list("a",3))) + rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(3, list(NULL, 3)), list(4, list(NULL, 4))))) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("d", list(NULL, 4)), list("c", list(NULL, 3))))) +}) + +test_that("fullOuterJoin() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,2), list(1,3), list(3,3))) + rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4)), list(3, list(3, NULL))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",2), list("a",3), list("c", 1))) + rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1)), list("c", list(1, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), list(3, list(NULL, 3)), list(4, list(NULL, 4))))) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), list("d", list(NULL, 4)), list("c", list(NULL, 3))))) +}) + +test_that("sortByKey() on pairwise RDDs", { + numPairsRdd <- map(rdd, function(x) { list (x, x) }) + sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE) + actual <- collect(sortedRdd) + numPairs <- lapply(nums, function(x) { list (x, x) }) + expect_equal(actual, sortKeyValueList(numPairs, decreasing = TRUE)) + + rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L) + numPairsRdd2 <- map(rdd2, function(x) { list (x, x) }) + sortedRdd2 <- sortByKey(numPairsRdd2) + actual <- collect(sortedRdd2) + expect_equal(actual, numPairs) + + # sort by string keys + l <- list(list("a", 1), list("b", 2), list("1", 3), list("d", 4), list("2", 5)) + rdd3 <- parallelize(sc, l, 2L) + sortedRdd3 <- sortByKey(rdd3) + actual <- collect(sortedRdd3) + expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) + + # test on the boundary cases + + # boundary case 1: the RDD to be sorted has only 1 partition + rdd4 <- parallelize(sc, l, 1L) + sortedRdd4 <- sortByKey(rdd4) + actual <- collect(sortedRdd4) + expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) + + # boundary case 2: the sorted RDD has only 1 partition + rdd5 <- parallelize(sc, l, 2L) + sortedRdd5 <- sortByKey(rdd5, numPartitions = 1L) + actual <- collect(sortedRdd5) + expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) + + # boundary case 3: the RDD to be sorted has only 1 element + l2 <- list(list("a", 1)) + rdd6 <- parallelize(sc, l2, 2L) + sortedRdd6 <- sortByKey(rdd6) + actual <- collect(sortedRdd6) + expect_equal(actual, l2) + + # boundary case 4: the RDD to be sorted has 0 element + l3 <- list() + rdd7 <- parallelize(sc, l3, 2L) + sortedRdd7 <- sortByKey(rdd7) + actual <- collect(sortedRdd7) + expect_equal(actual, l3) +}) + +test_that("collectAsMap() on a pairwise RDD", { + rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(`1` = 2, `3` = 4)) + + rdd <- parallelize(sc, list(list("a", 1), list("b", 2))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(a = 1, b = 2)) + + rdd <- parallelize(sc, list(list(1.1, 2.2), list(1.2, 2.4))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(`1.1` = 2.2, `1.2` = 2.4)) + + rdd <- parallelize(sc, list(list(1, "a"), list(2, "b"))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(`1` = "a", `2` = "b")) +}) diff --git a/R/pkg/inst/tests/test_shuffle.R b/R/pkg/inst/tests/test_shuffle.R new file mode 100644 index 0000000000000..d1da8232aea81 --- /dev/null +++ b/R/pkg/inst/tests/test_shuffle.R @@ -0,0 +1,209 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("partitionBy, groupByKey, reduceByKey etc.") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Data +intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) +intRdd <- parallelize(sc, intPairs, 2L) + +doublePairs <- list(list(1.5, -1), list(2.5, 100), list(2.5, 1), list(1.5, 200)) +doubleRdd <- parallelize(sc, doublePairs, 2L) + +numPairs <- list(list(1L, 100), list(2L, 200), list(4L, -1), list(3L, 1), + list(3L, 0)) +numPairsRdd <- parallelize(sc, numPairs, length(numPairs)) + +strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge and ", + "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ") +strListRDD <- parallelize(sc, strList, 4) + +test_that("groupByKey for integers", { + grouped <- groupByKey(intRdd, 2L) + + actual <- collect(grouped) + + expected <- list(list(2L, list(100, 1)), list(1L, list(-1, 200))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("groupByKey for doubles", { + grouped <- groupByKey(doubleRdd, 2L) + + actual <- collect(grouped) + + expected <- list(list(1.5, list(-1, 200)), list(2.5, list(100, 1))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("reduceByKey for ints", { + reduced <- reduceByKey(intRdd, "+", 2L) + + actual <- collect(reduced) + + expected <- list(list(2L, 101), list(1L, 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("reduceByKey for doubles", { + reduced <- reduceByKey(doubleRdd, "+", 2L) + actual <- collect(reduced) + + expected <- list(list(1.5, 199), list(2.5, 101)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("combineByKey for ints", { + reduced <- combineByKey(intRdd, function(x) { x }, "+", "+", 2L) + + actual <- collect(reduced) + + expected <- list(list(2L, 101), list(1L, 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("combineByKey for doubles", { + reduced <- combineByKey(doubleRdd, function(x) { x }, "+", "+", 2L) + actual <- collect(reduced) + + expected <- list(list(1.5, 199), list(2.5, 101)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("aggregateByKey", { + # test aggregateByKey for int keys + rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) + + zeroValue <- list(0, 0) + seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } + combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } + aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) + + actual <- collect(aggregatedRDD) + + expected <- list(list(1, list(3, 2)), list(2, list(7, 2))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test aggregateByKey for string keys + rdd <- parallelize(sc, list(list("a", 1), list("a", 2), list("b", 3), list("b", 4))) + + zeroValue <- list(0, 0) + seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } + combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } + aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) + + actual <- collect(aggregatedRDD) + + expected <- list(list("a", list(3, 2)), list("b", list(7, 2))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("foldByKey", { + # test foldByKey for int keys + folded <- foldByKey(intRdd, 0, "+", 2L) + + actual <- collect(folded) + + expected <- list(list(2L, 101), list(1L, 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test foldByKey for double keys + folded <- foldByKey(doubleRdd, 0, "+", 2L) + + actual <- collect(folded) + + expected <- list(list(1.5, 199), list(2.5, 101)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test foldByKey for string keys + stringKeyPairs <- list(list("a", -1), list("b", 100), list("b", 1), list("a", 200)) + + stringKeyRDD <- parallelize(sc, stringKeyPairs) + folded <- foldByKey(stringKeyRDD, 0, "+", 2L) + + actual <- collect(folded) + + expected <- list(list("b", 101), list("a", 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test foldByKey for empty pair RDD + rdd <- parallelize(sc, list()) + folded <- foldByKey(rdd, 0, "+", 2L) + actual <- collect(folded) + expected <- list() + expect_equal(actual, expected) + + # test foldByKey for RDD with only 1 pair + rdd <- parallelize(sc, list(list(1, 1))) + folded <- foldByKey(rdd, 0, "+", 2L) + actual <- collect(folded) + expected <- list(list(1, 1)) + expect_equal(actual, expected) +}) + +test_that("partitionBy() partitions data correctly", { + # Partition by magnitude + partitionByMagnitude <- function(key) { if (key >= 3) 1 else 0 } + + resultRDD <- partitionBy(numPairsRdd, 2L, partitionByMagnitude) + + expected_first <- list(list(1, 100), list(2, 200)) # key < 3 + expected_second <- list(list(4, -1), list(3, 1), list(3, 0)) # key >= 3 + actual_first <- collectPartition(resultRDD, 0L) + actual_second <- collectPartition(resultRDD, 1L) + + expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first)) + expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) +}) + +test_that("partitionBy works with dependencies", { + kOne <- 1 + partitionByParity <- function(key) { if (key %% 2 == kOne) 7 else 4 } + + # Partition by parity + resultRDD <- partitionBy(numPairsRdd, numPartitions = 2L, partitionByParity) + + # keys even; 100 %% 2 == 0 + expected_first <- list(list(2, 200), list(4, -1)) + # keys odd; 3 %% 2 == 1 + expected_second <- list(list(1, 100), list(3, 1), list(3, 0)) + actual_first <- collectPartition(resultRDD, 0L) + actual_second <- collectPartition(resultRDD, 1L) + + expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first)) + expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) +}) + +test_that("test partitionBy with string keys", { + words <- flatMap(strListRDD, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + resultRDD <- partitionBy(wordCount, 2L) + expected_first <- list(list("Dexter", 1), list("Dexter", 1)) + expected_second <- list(list("and", 1), list("and", 1)) + + actual_first <- Filter(function(item) { item[[1]] == "Dexter" }, + collectPartition(resultRDD, 0L)) + actual_second <- Filter(function(item) { item[[1]] == "and" }, + collectPartition(resultRDD, 1L)) + + expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first)) + expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) +}) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R new file mode 100644 index 0000000000000..cf5cf6d1692af --- /dev/null +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -0,0 +1,695 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("SparkSQL functions") + +# Tests for SparkSQL functions in SparkR + +sc <- sparkR.init() + +sqlCtx <- sparkRSQL.init(sc) + +mockLines <- c("{\"name\":\"Michael\"}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}") +jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +parquetPath <- tempfile(pattern="sparkr-test", fileext=".parquet") +writeLines(mockLines, jsonPath) + +test_that("infer types", { + expect_equal(infer_type(1L), "integer") + expect_equal(infer_type(1.0), "double") + expect_equal(infer_type("abc"), "string") + expect_equal(infer_type(TRUE), "boolean") + expect_equal(infer_type(as.Date("2015-03-11")), "date") + expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp") + expect_equal(infer_type(c(1L, 2L)), + list(type = 'array', elementType = "integer", containsNull = TRUE)) + expect_equal(infer_type(list(1L, 2L)), + list(type = 'array', elementType = "integer", containsNull = TRUE)) + expect_equal(infer_type(list(a = 1L, b = "2")), + list(type = "struct", + fields = list(list(name = "a", type = "integer", nullable = TRUE), + list(name = "b", type = "string", nullable = TRUE)))) + e <- new.env() + assign("a", 1L, envir = e) + expect_equal(infer_type(e), + list(type = "map", keyType = "string", valueType = "integer", + valueContainsNull = TRUE)) +}) + +test_that("create DataFrame from RDD", { + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) + df <- createDataFrame(sqlCtx, rdd, list("a", "b")) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- createDataFrame(sqlCtx, rdd) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("_1", "_2")) + + fields <- list(list(name = "a", type = "integer", nullable = TRUE), + list(name = "b", type = "string", nullable = TRUE)) + schema <- list(type = "struct", fields = fields) + df <- createDataFrame(sqlCtx, rdd, schema) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) + df <- createDataFrame(sqlCtx, rdd) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) +}) + +test_that("toDF", { + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) + df <- toDF(rdd, list("a", "b")) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- toDF(rdd) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("_1", "_2")) + + fields <- list(list(name = "a", type = "integer", nullable = TRUE), + list(name = "b", type = "string", nullable = TRUE)) + schema <- list(type = "struct", fields = fields) + df <- toDF(rdd, schema) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) + df <- toDF(rdd) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) +}) + +test_that("create DataFrame from list or data.frame", { + l <- list(list(1, 2), list(3, 4)) + df <- createDataFrame(sqlCtx, l, c("a", "b")) + expect_equal(columns(df), c("a", "b")) + + l <- list(list(a=1, b=2), list(a=3, b=4)) + df <- createDataFrame(sqlCtx, l) + expect_equal(columns(df), c("a", "b")) + + a <- 1:3 + b <- c("a", "b", "c") + ldf <- data.frame(a, b) + df <- createDataFrame(sqlCtx, ldf) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + expect_equal(count(df), 3) + ldf2 <- collect(df) + expect_equal(ldf$a, ldf2$a) +}) + +test_that("create DataFrame with different data types", { + l <- list(a = 1L, b = 2, c = TRUE, d = "ss", e = as.Date("2012-12-13"), + f = as.POSIXct("2015-03-15 12:13:14.056")) + df <- createDataFrame(sqlCtx, list(l)) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "double"), c("c", "boolean"), + c("d", "string"), c("e", "date"), c("f", "timestamp"))) + expect_equal(count(df), 1) + expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) +}) + +# TODO: enable this test after fix serialization for nested object +#test_that("create DataFrame with nested array and struct", { +# e <- new.env() +# assign("n", 3L, envir = e) +# l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) +# df <- createDataFrame(sqlCtx, list(l), c("a", "b", "c", "d")) +# expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"), +# c("c", "map"), c("d", "struct"))) +# expect_equal(count(df), 1) +# ldf <- collect(df) +# expect_equal(ldf[1,], l[[1]]) +#}) + +test_that("jsonFile() on a local file returns a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) +}) + +test_that("jsonRDD() on a RDD with json string", { + rdd <- parallelize(sc, mockLines) + expect_true(count(rdd) == 3) + df <- jsonRDD(sqlCtx, rdd) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) + + rdd2 <- flatMap(rdd, function(x) c(x, x)) + df <- jsonRDD(sqlCtx, rdd2) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 6) +}) + +test_that("test cache, uncache and clearCache", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + cacheTable(sqlCtx, "table1") + uncacheTable(sqlCtx, "table1") + clearCache(sqlCtx) + dropTempTable(sqlCtx, "table1") +}) + +test_that("test tableNames and tables", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + expect_true(length(tableNames(sqlCtx)) == 1) + df <- tables(sqlCtx) + expect_true(count(df) == 1) + dropTempTable(sqlCtx, "table1") +}) + +test_that("registerTempTable() results in a queryable table and sql() results in a new DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + newdf <- sql(sqlCtx, "SELECT * FROM table1 where name = 'Michael'") + expect_true(inherits(newdf, "DataFrame")) + expect_true(count(newdf) == 1) + dropTempTable(sqlCtx, "table1") +}) + +test_that("insertInto() on a registered table", { + df <- loadDF(sqlCtx, jsonPath, "json") + saveDF(df, parquetPath, "parquet", "overwrite") + dfParquet <- loadDF(sqlCtx, parquetPath, "parquet") + + lines <- c("{\"name\":\"Bob\", \"age\":24}", + "{\"name\":\"James\", \"age\":35}") + jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".tmp") + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + writeLines(lines, jsonPath2) + df2 <- loadDF(sqlCtx, jsonPath2, "json") + saveDF(df2, parquetPath2, "parquet", "overwrite") + dfParquet2 <- loadDF(sqlCtx, parquetPath2, "parquet") + + registerTempTable(dfParquet, "table1") + insertInto(dfParquet2, "table1") + expect_true(count(sql(sqlCtx, "select * from table1")) == 5) + expect_true(first(sql(sqlCtx, "select * from table1 order by age"))$name == "Michael") + dropTempTable(sqlCtx, "table1") + + registerTempTable(dfParquet, "table1") + insertInto(dfParquet2, "table1", overwrite = TRUE) + expect_true(count(sql(sqlCtx, "select * from table1")) == 2) + expect_true(first(sql(sqlCtx, "select * from table1 order by age"))$name == "Bob") + dropTempTable(sqlCtx, "table1") +}) + +test_that("table() returns a new DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + tabledf <- table(sqlCtx, "table1") + expect_true(inherits(tabledf, "DataFrame")) + expect_true(count(tabledf) == 3) + dropTempTable(sqlCtx, "table1") +}) + +test_that("toRDD() returns an RRDD", { + df <- jsonFile(sqlCtx, jsonPath) + testRDD <- toRDD(df) + expect_true(inherits(testRDD, "RDD")) + expect_true(count(testRDD) == 3) +}) + +test_that("union on two RDDs created from DataFrames returns an RRDD", { + df <- jsonFile(sqlCtx, jsonPath) + RDD1 <- toRDD(df) + RDD2 <- toRDD(df) + unioned <- unionRDD(RDD1, RDD2) + expect_true(inherits(unioned, "RDD")) + expect_true(SparkR:::getSerializedMode(unioned) == "byte") + expect_true(collect(unioned)[[2]]$name == "Andy") +}) + +test_that("union on mixed serialization types correctly returns a byte RRDD", { + # Byte RDD + nums <- 1:10 + rdd <- parallelize(sc, nums, 2L) + + # String RDD + textLines <- c("Michael", + "Andy, 30", + "Justin, 19") + textPath <- tempfile(pattern="sparkr-textLines", fileext=".tmp") + writeLines(textLines, textPath) + textRDD <- textFile(sc, textPath) + + df <- jsonFile(sqlCtx, jsonPath) + dfRDD <- toRDD(df) + + unionByte <- unionRDD(rdd, dfRDD) + expect_true(inherits(unionByte, "RDD")) + expect_true(SparkR:::getSerializedMode(unionByte) == "byte") + expect_true(collect(unionByte)[[1]] == 1) + expect_true(collect(unionByte)[[12]]$name == "Andy") + + unionString <- unionRDD(textRDD, dfRDD) + expect_true(inherits(unionString, "RDD")) + expect_true(SparkR:::getSerializedMode(unionString) == "byte") + expect_true(collect(unionString)[[1]] == "Michael") + expect_true(collect(unionString)[[5]]$name == "Andy") +}) + +test_that("objectFile() works with row serialization", { + objectPath <- tempfile(pattern="spark-test", fileext=".tmp") + df <- jsonFile(sqlCtx, jsonPath) + dfRDD <- toRDD(df) + saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) + objectIn <- objectFile(sc, objectPath) + + expect_true(inherits(objectIn, "RDD")) + expect_equal(SparkR:::getSerializedMode(objectIn), "byte") + expect_equal(collect(objectIn)[[2]]$age, 30) +}) + +test_that("lapply() on a DataFrame returns an RDD with the correct columns", { + df <- jsonFile(sqlCtx, jsonPath) + testRDD <- lapply(df, function(row) { + row$newCol <- row$age + 5 + row + }) + expect_true(inherits(testRDD, "RDD")) + collected <- collect(testRDD) + expect_true(collected[[1]]$name == "Michael") + expect_true(collected[[2]]$newCol == "35") +}) + +test_that("collect() returns a data.frame", { + df <- jsonFile(sqlCtx, jsonPath) + rdf <- collect(df) + expect_true(is.data.frame(rdf)) + expect_true(names(rdf)[1] == "age") + expect_true(nrow(rdf) == 3) + expect_true(ncol(rdf) == 2) +}) + +test_that("limit() returns DataFrame with the correct number of rows", { + df <- jsonFile(sqlCtx, jsonPath) + dfLimited <- limit(df, 2) + expect_true(inherits(dfLimited, "DataFrame")) + expect_true(count(dfLimited) == 2) +}) + +test_that("collect() and take() on a DataFrame return the same number of rows and columns", { + df <- jsonFile(sqlCtx, jsonPath) + expect_true(nrow(collect(df)) == nrow(take(df, 10))) + expect_true(ncol(collect(df)) == ncol(take(df, 10))) +}) + +test_that("multiple pipeline transformations starting with a DataFrame result in an RDD with the correct values", { + df <- jsonFile(sqlCtx, jsonPath) + first <- lapply(df, function(row) { + row$age <- row$age + 5 + row + }) + second <- lapply(first, function(row) { + row$testCol <- if (row$age == 35 && !is.na(row$age)) TRUE else FALSE + row + }) + expect_true(inherits(second, "RDD")) + expect_true(count(second) == 3) + expect_true(collect(second)[[2]]$age == 35) + expect_true(collect(second)[[2]]$testCol) + expect_false(collect(second)[[3]]$testCol) +}) + +test_that("cache(), persist(), and unpersist() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + expect_false(df@env$isCached) + cache(df) + expect_true(df@env$isCached) + + unpersist(df) + expect_false(df@env$isCached) + + persist(df, "MEMORY_AND_DISK") + expect_true(df@env$isCached) + + unpersist(df) + expect_false(df@env$isCached) + + # make sure the data is collectable + expect_true(is.data.frame(collect(df))) +}) + +test_that("schema(), dtypes(), columns(), names() return the correct values/format", { + df <- jsonFile(sqlCtx, jsonPath) + testSchema <- schema(df) + expect_true(length(testSchema$fields()) == 2) + expect_true(testSchema$fields()[[1]]$dataType.toString() == "LongType") + expect_true(testSchema$fields()[[2]]$dataType.simpleString() == "string") + expect_true(testSchema$fields()[[1]]$name() == "age") + + testTypes <- dtypes(df) + expect_true(length(testTypes[[1]]) == 2) + expect_true(testTypes[[1]][1] == "age") + + testCols <- columns(df) + expect_true(length(testCols) == 2) + expect_true(testCols[2] == "name") + + testNames <- names(df) + expect_true(length(testNames) == 2) + expect_true(testNames[2] == "name") +}) + +test_that("head() and first() return the correct data", { + df <- jsonFile(sqlCtx, jsonPath) + testHead <- head(df) + expect_true(nrow(testHead) == 3) + expect_true(ncol(testHead) == 2) + + testHead2 <- head(df, 2) + expect_true(nrow(testHead2) == 2) + expect_true(ncol(testHead2) == 2) + + testFirst <- first(df) + expect_true(nrow(testFirst) == 1) +}) + +test_that("distinct() on DataFrames", { + lines <- c("{\"name\":\"Michael\"}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}", + "{\"name\":\"Justin\", \"age\":19}") + jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPathWithDup) + + df <- jsonFile(sqlCtx, jsonPathWithDup) + uniques <- distinct(df) + expect_true(inherits(uniques, "DataFrame")) + expect_true(count(uniques) == 3) +}) + +test_that("sampleDF on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + sampled <- sampleDF(df, FALSE, 1.0) + expect_equal(nrow(collect(sampled)), count(df)) + expect_true(inherits(sampled, "DataFrame")) + sampled2 <- sampleDF(df, FALSE, 0.1) + expect_true(count(sampled2) < 3) +}) + +test_that("select operators", { + df <- select(jsonFile(sqlCtx, jsonPath), "name", "age") + expect_true(inherits(df$name, "Column")) + expect_true(inherits(df[[2]], "Column")) + expect_true(inherits(df[["age"]], "Column")) + + expect_true(inherits(df[,1], "DataFrame")) + expect_equal(columns(df[,1]), c("name")) + expect_equal(columns(df[,"age"]), c("age")) + df2 <- df[,c("age", "name")] + expect_true(inherits(df2, "DataFrame")) + expect_equal(columns(df2), c("age", "name")) + + df$age2 <- df$age + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == df$age)), 2) + df$age2 <- df$age * 2 + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == df$age * 2)), 2) +}) + +test_that("select with column", { + df <- jsonFile(sqlCtx, jsonPath) + df1 <- select(df, "name") + expect_true(columns(df1) == c("name")) + expect_true(count(df1) == 3) + + df2 <- select(df, df$age) + expect_true(columns(df2) == c("age")) + expect_true(count(df2) == 3) +}) + +test_that("selectExpr() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + selected <- selectExpr(df, "age * 2") + expect_true(names(selected) == "(age * 2)") + expect_equal(collect(selected), collect(select(df, df$age * 2L))) + + selected2 <- selectExpr(df, "name as newName", "abs(age) as age") + expect_equal(names(selected2), c("newName", "age")) + expect_true(count(selected2) == 3) +}) + +test_that("column calculation", { + df <- jsonFile(sqlCtx, jsonPath) + d <- collect(select(df, alias(df$age + 1, "age2"))) + expect_true(names(d) == c("age2")) + df2 <- select(df, lower(df$name), abs(df$age)) + expect_true(inherits(df2, "DataFrame")) + expect_true(count(df2) == 3) +}) + +test_that("load() from json file", { + df <- loadDF(sqlCtx, jsonPath, "json") + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) +}) + +test_that("save() as parquet file", { + df <- loadDF(sqlCtx, jsonPath, "json") + saveDF(df, parquetPath, "parquet", mode="overwrite") + df2 <- loadDF(sqlCtx, parquetPath, "parquet") + expect_true(inherits(df2, "DataFrame")) + expect_true(count(df2) == 3) +}) + +test_that("test HiveContext", { + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + df <- createExternalTable(hiveCtx, "json", jsonPath, "json") + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) + df2 <- sql(hiveCtx, "select * from json") + expect_true(inherits(df2, "DataFrame")) + expect_true(count(df2) == 3) + + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + saveAsTable(df, "json", "json", "append", path = jsonPath2) + df3 <- sql(hiveCtx, "select * from json") + expect_true(inherits(df3, "DataFrame")) + expect_true(count(df3) == 6) +}) + +test_that("column operators", { + c <- SparkR:::col("a") + c2 <- (- c + 1 - 2) * 3 / 4.0 + c3 <- (c + c2 - c2) * c2 %% c2 + c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3) +}) + +test_that("column functions", { + c <- SparkR:::col("a") + c2 <- min(c) + max(c) + sum(c) + avg(c) + count(c) + abs(c) + sqrt(c) + c3 <- lower(c) + upper(c) + first(c) + last(c) + c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string") +}) + +test_that("string operators", { + df <- jsonFile(sqlCtx, jsonPath) + expect_equal(count(where(df, like(df$name, "A%"))), 1) + expect_equal(count(where(df, startsWith(df$name, "A"))), 1) + expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") + expect_equal(collect(select(df, cast(df$age, "string")))[[2, 1]], "30") +}) + +test_that("group by", { + df <- jsonFile(sqlCtx, jsonPath) + df1 <- agg(df, name = "max", age = "sum") + expect_true(1 == count(df1)) + df1 <- agg(df, age2 = max(df$age)) + expect_true(1 == count(df1)) + expect_equal(columns(df1), c("age2")) + + gd <- groupBy(df, "name") + expect_true(inherits(gd, "GroupedData")) + df2 <- count(gd) + expect_true(inherits(df2, "DataFrame")) + expect_true(3 == count(df2)) + + df3 <- agg(gd, age = "sum") + expect_true(inherits(df3, "DataFrame")) + expect_true(3 == count(df3)) + + df3 <- agg(gd, age = sum(df$age)) + expect_true(inherits(df3, "DataFrame")) + expect_true(3 == count(df3)) + expect_equal(columns(df3), c("name", "age")) + + df4 <- sum(gd, "age") + expect_true(inherits(df4, "DataFrame")) + expect_true(3 == count(df4)) + expect_true(3 == count(mean(gd, "age"))) + expect_true(3 == count(max(gd, "age"))) +}) + +test_that("sortDF() and orderBy() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + sorted <- sortDF(df, df$age) + expect_true(collect(sorted)[1,2] == "Michael") + + sorted2 <- sortDF(df, "name") + expect_true(collect(sorted2)[2,"age"] == 19) + + sorted3 <- orderBy(df, asc(df$age)) + expect_true(is.na(first(sorted3)$age)) + expect_true(collect(sorted3)[2, "age"] == 19) + + sorted4 <- orderBy(df, desc(df$name)) + expect_true(first(sorted4)$name == "Michael") + expect_true(collect(sorted4)[3,"name"] == "Andy") +}) + +test_that("filter() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + filtered <- filter(df, "age > 20") + expect_true(count(filtered) == 1) + expect_true(collect(filtered)$name == "Andy") + filtered2 <- where(df, df$name != "Michael") + expect_true(count(filtered2) == 2) + expect_true(collect(filtered2)$age[2] == 19) +}) + +test_that("join() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + + mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", + "{\"name\":\"Andy\", \"test\": \"no\"}", + "{\"name\":\"Justin\", \"test\": \"yes\"}", + "{\"name\":\"Bob\", \"test\": \"yes\"}") + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(mockLines2, jsonPath2) + df2 <- jsonFile(sqlCtx, jsonPath2) + + joined <- join(df, df2) + expect_equal(names(joined), c("age", "name", "name", "test")) + expect_true(count(joined) == 12) + + joined2 <- join(df, df2, df$name == df2$name) + expect_equal(names(joined2), c("age", "name", "name", "test")) + expect_true(count(joined2) == 3) + + joined3 <- join(df, df2, df$name == df2$name, "right_outer") + expect_equal(names(joined3), c("age", "name", "name", "test")) + expect_true(count(joined3) == 4) + expect_true(is.na(collect(orderBy(joined3, joined3$age))$age[2])) + + joined4 <- select(join(df, df2, df$name == df2$name, "outer"), + alias(df$age + 5, "newAge"), df$name, df2$test) + expect_equal(names(joined4), c("newAge", "name", "test")) + expect_true(count(joined4) == 4) + expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24) +}) + +test_that("toJSON() returns an RDD of the correct values", { + df <- jsonFile(sqlCtx, jsonPath) + testRDD <- toJSON(df) + expect_true(inherits(testRDD, "RDD")) + expect_true(SparkR:::getSerializedMode(testRDD) == "string") + expect_equal(collect(testRDD)[[1]], mockLines[1]) +}) + +test_that("showDF()", { + df <- jsonFile(sqlCtx, jsonPath) + expect_output(showDF(df), "age name \nnull Michael\n30 Andy \n19 Justin ") +}) + +test_that("isLocal()", { + df <- jsonFile(sqlCtx, jsonPath) + expect_false(isLocal(df)) +}) + +test_that("unionAll(), subtract(), and intersect() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + + lines <- c("{\"name\":\"Bob\", \"age\":24}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"James\", \"age\":35}") + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPath2) + df2 <- loadDF(sqlCtx, jsonPath2, "json") + + unioned <- sortDF(unionAll(df, df2), df$age) + expect_true(inherits(unioned, "DataFrame")) + expect_true(count(unioned) == 6) + expect_true(first(unioned)$name == "Michael") + + subtracted <- sortDF(subtract(df, df2), desc(df$age)) + expect_true(inherits(unioned, "DataFrame")) + expect_true(count(subtracted) == 2) + expect_true(first(subtracted)$name == "Justin") + + intersected <- sortDF(intersect(df, df2), df$age) + expect_true(inherits(unioned, "DataFrame")) + expect_true(count(intersected) == 1) + expect_true(first(intersected)$name == "Andy") +}) + +test_that("withColumn() and withColumnRenamed()", { + df <- jsonFile(sqlCtx, jsonPath) + newDF <- withColumn(df, "newAge", df$age + 2) + expect_true(length(columns(newDF)) == 3) + expect_true(columns(newDF)[3] == "newAge") + expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32) + + newDF2 <- withColumnRenamed(df, "age", "newerAge") + expect_true(length(columns(newDF2)) == 2) + expect_true(columns(newDF2)[1] == "newerAge") +}) + +test_that("saveDF() on DataFrame and works with parquetFile", { + df <- jsonFile(sqlCtx, jsonPath) + saveDF(df, parquetPath, "parquet", mode="overwrite") + parquetDF <- parquetFile(sqlCtx, parquetPath) + expect_true(inherits(parquetDF, "DataFrame")) + expect_equal(count(df), count(parquetDF)) +}) + +test_that("parquetFile works with multiple input paths", { + df <- jsonFile(sqlCtx, jsonPath) + saveDF(df, parquetPath, "parquet", mode="overwrite") + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + saveDF(df, parquetPath2, "parquet", mode="overwrite") + parquetDF <- parquetFile(sqlCtx, parquetPath, parquetPath2) + expect_true(inherits(parquetDF, "DataFrame")) + expect_true(count(parquetDF) == count(df)*2) +}) + +unlink(parquetPath) +unlink(jsonPath) diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/test_take.R new file mode 100644 index 0000000000000..7f4c7c315d787 --- /dev/null +++ b/R/pkg/inst/tests/test_take.R @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("tests RDD function take()") + +# Mock data +numVector <- c(-10:97) +numList <- list(sqrt(1), sqrt(2), sqrt(3), 4 ** 10) +strVector <- c("Dexter Morgan: I suppose I should be upset, even feel", + "violated, but I'm not. No, in fact, I think this is a friendly", + "message, like \"Hey, wanna play?\" and yes, I want to play. ", + "I really, really do.") +strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ", + "other times it helps me control the chaos.", + "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ", + "raising me. But they're both dead now. I didn't kill them. Honest.") + +# JavaSparkContext handle +jsc <- sparkR.init() + +test_that("take() gives back the original elements in correct count and order", { + numVectorRDD <- parallelize(jsc, numVector, 10) + # case: number of elements to take is less than the size of the first partition + expect_equal(take(numVectorRDD, 1), as.list(head(numVector, n = 1))) + # case: number of elements to take is the same as the size of the first partition + expect_equal(take(numVectorRDD, 11), as.list(head(numVector, n = 11))) + # case: number of elements to take is greater than all elements + expect_equal(take(numVectorRDD, length(numVector)), as.list(numVector)) + expect_equal(take(numVectorRDD, length(numVector) + 1), as.list(numVector)) + + numListRDD <- parallelize(jsc, numList, 1) + numListRDD2 <- parallelize(jsc, numList, 4) + expect_equal(take(numListRDD, 3), take(numListRDD2, 3)) + expect_equal(take(numListRDD, 5), take(numListRDD2, 5)) + expect_equal(take(numListRDD, 1), as.list(head(numList, n = 1))) + expect_equal(take(numListRDD2, 999), numList) + + strVectorRDD <- parallelize(jsc, strVector, 2) + strVectorRDD2 <- parallelize(jsc, strVector, 3) + expect_equal(take(strVectorRDD, 4), as.list(strVector)) + expect_equal(take(strVectorRDD2, 2), as.list(head(strVector, n = 2))) + + strListRDD <- parallelize(jsc, strList, 4) + strListRDD2 <- parallelize(jsc, strList, 1) + expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3))) + expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1))) + + expect_true(length(take(strListRDD, 0)) == 0) + expect_true(length(take(strVectorRDD, 0)) == 0) + expect_true(length(take(numListRDD, 0)) == 0) + expect_true(length(take(numVectorRDD, 0)) == 0) +}) + diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R new file mode 100644 index 0000000000000..6b87b4b3e0b08 --- /dev/null +++ b/R/pkg/inst/tests/test_textFile.R @@ -0,0 +1,162 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("the textFile() function") + +# JavaSparkContext handle +sc <- sparkR.init() + +mockFile = c("Spark is pretty.", "Spark is awesome.") + +test_that("textFile() on a local file returns an RDD", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + expect_true(inherits(rdd, "RDD")) + expect_true(count(rdd) > 0) + expect_true(count(rdd) == 2) + + unlink(fileName) +}) + +test_that("textFile() followed by a collect() returns the same content", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + expect_equal(collect(rdd), as.list(mockFile)) + + unlink(fileName) +}) + +test_that("textFile() word count works as expected", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + + words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + counts <- reduceByKey(wordCount, "+", 2L) + output <- collect(counts) + expected <- list(list("pretty.", 1), list("is", 2), list("awesome.", 1), + list("Spark", 2)) + expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) + + unlink(fileName) +}) + +test_that("several transformations on RDD created by textFile()", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) # RDD + for (i in 1:10) { + # PipelinedRDD initially created from RDD + rdd <- lapply(rdd, function(x) paste(x, x)) + } + collect(rdd) + + unlink(fileName) +}) + +test_that("textFile() followed by a saveAsTextFile() returns the same content", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1, 1L) + saveAsTextFile(rdd, fileName2) + rdd <- textFile(sc, fileName2) + expect_equal(collect(rdd), as.list(mockFile)) + + unlink(fileName1) + unlink(fileName2) +}) + +test_that("saveAsTextFile() on a parallelized list works as expected", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + l <- list(1, 2, 3) + rdd <- parallelize(sc, l, 1L) + saveAsTextFile(rdd, fileName) + rdd <- textFile(sc, fileName) + expect_equal(collect(rdd), lapply(l, function(x) {toString(x)})) + + unlink(fileName) +}) + +test_that("textFile() and saveAsTextFile() word count works as expected", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1) + + words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + counts <- reduceByKey(wordCount, "+", 2L) + + saveAsTextFile(counts, fileName2) + rdd <- textFile(sc, fileName2) + + output <- collect(rdd) + expected <- list(list("awesome.", 1), list("Spark", 2), + list("pretty.", 1), list("is", 2)) + expectedStr <- lapply(expected, function(x) { toString(x) }) + expect_equal(sortKeyValueList(output), sortKeyValueList(expectedStr)) + + unlink(fileName1) + unlink(fileName2) +}) + +test_that("textFile() on multiple paths", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines("Spark is pretty.", fileName1) + writeLines("Spark is awesome.", fileName2) + + rdd <- textFile(sc, c(fileName1, fileName2)) + expect_true(count(rdd) == 2) + + unlink(fileName1) + unlink(fileName2) +}) + +test_that("Pipelined operations on RDDs created using textFile", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + + lengths <- lapply(rdd, function(x) { length(x) }) + expect_equal(collect(lengths), list(1, 1)) + + lengthsPipelined <- lapply(lengths, function(x) { x + 10 }) + expect_equal(collect(lengthsPipelined), list(11, 11)) + + lengths30 <- lapply(lengthsPipelined, function(x) { x + 20 }) + expect_equal(collect(lengths30), list(31, 31)) + + lengths20 <- lapply(lengths, function(x) { x + 20 }) + expect_equal(collect(lengths20), list(21, 21)) + + unlink(fileName) +}) + diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R new file mode 100644 index 0000000000000..9c5bb427932b4 --- /dev/null +++ b/R/pkg/inst/tests/test_utils.R @@ -0,0 +1,137 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions in utils.R") + +# JavaSparkContext handle +sc <- sparkR.init() + +test_that("convertJListToRList() gives back (deserializes) the original JLists + of strings and integers", { + # It's hard to manually create a Java List using rJava, since it does not + # support generics well. Instead, we rely on collect() returning a + # JList. + nums <- as.list(1:10) + rdd <- parallelize(sc, nums, 1L) + jList <- callJMethod(rdd@jrdd, "collect") + rList <- convertJListToRList(jList, flatten = TRUE) + expect_equal(rList, nums) + + strs <- as.list("hello", "spark") + rdd <- parallelize(sc, strs, 2L) + jList <- callJMethod(rdd@jrdd, "collect") + rList <- convertJListToRList(jList, flatten = TRUE) + expect_equal(rList, strs) +}) + +test_that("serializeToBytes on RDD", { + # File content + mockFile <- c("Spark is pretty.", "Spark is awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + text.rdd <- textFile(sc, fileName) + expect_true(getSerializedMode(text.rdd) == "string") + ser.rdd <- serializeToBytes(text.rdd) + expect_equal(collect(ser.rdd), as.list(mockFile)) + expect_true(getSerializedMode(ser.rdd) == "byte") + + unlink(fileName) +}) + +test_that("cleanClosure on R functions", { + y <- c(1, 2, 3) + g <- function(x) { x + 1 } + f <- function(x) { g(x) + y } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 2) # y, g + actual <- get("y", envir = env, inherits = FALSE) + expect_equal(actual, y) + actual <- get("g", envir = env, inherits = FALSE) + expect_equal(actual, g) + + # Test for nested enclosures and package variables. + env2 <- new.env() + funcEnv <- new.env(parent = env2) + f <- function(x) { log(g(x) + y) } + environment(f) <- funcEnv # enclosing relationship: f -> funcEnv -> env2 -> .GlobalEnv + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 2) # "min" should not be included + actual <- get("y", envir = env, inherits = FALSE) + expect_equal(actual, y) + actual <- get("g", envir = env, inherits = FALSE) + expect_equal(actual, g) + + base <- c(1, 2, 3) + l <- list(field = matrix(1)) + field <- matrix(2) + defUse <- 3 + g <- function(x) { x + y } + f <- function(x) { + defUse <- base::as.integer(x) + 1 # Test for access operators `::`. + lapply(x, g) + 1 # Test for capturing function call "g"'s closure as a argument of lapply. + l$field[1,1] <- 3 # Test for access operators `$`. + res <- defUse + l$field[1,] # Test for def-use chain of "defUse", and "" symbol. + f(res) # Test for recursive calls. + } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 3) # Only "g", "l" and "f". No "base", "field" or "defUse". + expect_true("g" %in% ls(env)) + expect_true("l" %in% ls(env)) + expect_true("f" %in% ls(env)) + expect_equal(get("l", envir = env, inherits = FALSE), l) + # "y" should be in the environemnt of g. + newG <- get("g", envir = env, inherits = FALSE) + env <- environment(newG) + expect_equal(length(ls(env)), 1) + actual <- get("y", envir = env, inherits = FALSE) + expect_equal(actual, y) + + # Test for function (and variable) definitions. + f <- function(x) { + g <- function(y) { y * 2 } + g(x) + } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 0) # "y" and "g" should not be included. + + # Test for overriding variables in base namespace (Issue: SparkR-196). + nums <- as.list(1:10) + rdd <- parallelize(sc, nums, 2L) + t = 4 # Override base::t in .GlobalEnv. + f <- function(x) { x > t } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(ls(env), "t") + expect_equal(get("t", envir = env, inherits = FALSE), t) + actual <- collect(lapply(rdd, f)) + expected <- as.list(c(rep(FALSE, 4), rep(TRUE, 6))) + expect_equal(actual, expected) + + # Test for broadcast variables. + a <- matrix(nrow=10, ncol=10, data=rnorm(100)) + aBroadcast <- broadcast(sc, a) + normMultiply <- function(x) { norm(aBroadcast$value) * x } + newnormMultiply <- SparkR:::cleanClosure(normMultiply) + env <- environment(newnormMultiply) + expect_equal(ls(env), "aBroadcast") + expect_equal(get("aBroadcast", envir = env, inherits = FALSE), aBroadcast) +}) diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R new file mode 100644 index 0000000000000..3584b418a71a9 --- /dev/null +++ b/R/pkg/inst/worker/daemon.R @@ -0,0 +1,52 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Worker daemon + +rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +script <- paste(rLibDir, "SparkR/worker/worker.R", sep = "/") + +# preload SparkR package, speedup worker +.libPaths(c(rLibDir, .libPaths())) +suppressPackageStartupMessages(library(SparkR)) + +port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) +inputCon <- socketConnection(port = port, open = "rb", blocking = TRUE, timeout = 3600) + +while (TRUE) { + ready <- socketSelect(list(inputCon)) + if (ready) { + port <- SparkR:::readInt(inputCon) + # There is a small chance that it could be interrupted by signal, retry one time + if (length(port) == 0) { + port <- SparkR:::readInt(inputCon) + if (length(port) == 0) { + cat("quitting daemon\n") + quit(save = "no") + } + } + p <- parallel:::mcfork() + if (inherits(p, "masterProcess")) { + close(inputCon) + Sys.setenv(SPARKR_WORKER_PORT = port) + source(script) + # Set SIGUSR1 so that child can exit + tools::pskill(Sys.getpid(), tools::SIGUSR1) + parallel:::mcexit(0L) + } + } +} diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R new file mode 100644 index 0000000000000..c6542928e8ddd --- /dev/null +++ b/R/pkg/inst/worker/worker.R @@ -0,0 +1,128 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Worker class + +rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +# Set libPaths to include SparkR package as loadNamespace needs this +# TODO: Figure out if we can avoid this by not loading any objects that require +# SparkR namespace +.libPaths(c(rLibDir, .libPaths())) +suppressPackageStartupMessages(library(SparkR)) + +port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) +inputCon <- socketConnection(port = port, blocking = TRUE, open = "rb") +outputCon <- socketConnection(port = port, blocking = TRUE, open = "wb") + +# read the index of the current partition inside the RDD +partition <- SparkR:::readInt(inputCon) + +deserializer <- SparkR:::readString(inputCon) +serializer <- SparkR:::readString(inputCon) + +# Include packages as required +packageNames <- unserialize(SparkR:::readRaw(inputCon)) +for (pkg in packageNames) { + suppressPackageStartupMessages(require(as.character(pkg), character.only=TRUE)) +} + +# read function dependencies +funcLen <- SparkR:::readInt(inputCon) +computeFunc <- unserialize(SparkR:::readRawLen(inputCon, funcLen)) +env <- environment(computeFunc) +parent.env(env) <- .GlobalEnv # Attach under global environment. + +# Read and set broadcast variables +numBroadcastVars <- SparkR:::readInt(inputCon) +if (numBroadcastVars > 0) { + for (bcast in seq(1:numBroadcastVars)) { + bcastId <- SparkR:::readInt(inputCon) + value <- unserialize(SparkR:::readRaw(inputCon)) + setBroadcastValue(bcastId, value) + } +} + +# If -1: read as normal RDD; if >= 0, treat as pairwise RDD and treat the int +# as number of partitions to create. +numPartitions <- SparkR:::readInt(inputCon) + +isEmpty <- SparkR:::readInt(inputCon) + +if (isEmpty != 0) { + + if (numPartitions == -1) { + if (deserializer == "byte") { + # Now read as many characters as described in funcLen + data <- SparkR:::readDeserialize(inputCon) + } else if (deserializer == "string") { + data <- as.list(readLines(inputCon)) + } else if (deserializer == "row") { + data <- SparkR:::readDeserializeRows(inputCon) + } + output <- computeFunc(partition, data) + if (serializer == "byte") { + SparkR:::writeRawSerialize(outputCon, output) + } else if (serializer == "row") { + SparkR:::writeRowSerialize(outputCon, output) + } else { + SparkR:::writeStrings(outputCon, output) + } + } else { + if (deserializer == "byte") { + # Now read as many characters as described in funcLen + data <- SparkR:::readDeserialize(inputCon) + } else if (deserializer == "string") { + data <- readLines(inputCon) + } else if (deserializer == "row") { + data <- SparkR:::readDeserializeRows(inputCon) + } + + res <- new.env() + + # Step 1: hash the data to an environment + hashTupleToEnvir <- function(tuple) { + # NOTE: execFunction is the hash function here + hashVal <- computeFunc(tuple[[1]]) + bucket <- as.character(hashVal %% numPartitions) + acc <- res[[bucket]] + # Create a new accumulator + if (is.null(acc)) { + acc <- SparkR:::initAccumulator() + } + SparkR:::addItemToAccumulator(acc, tuple) + res[[bucket]] <- acc + } + invisible(lapply(data, hashTupleToEnvir)) + + # Step 2: write out all of the environment as key-value pairs. + for (name in ls(res)) { + SparkR:::writeInt(outputCon, 2L) + SparkR:::writeInt(outputCon, as.integer(name)) + # Truncate the accumulator list to the number of elements we have + length(res[[name]]$data) <- res[[name]]$counter + SparkR:::writeRawSerialize(outputCon, res[[name]]$data) + } + } +} + +# End of output +if (serializer %in% c("byte", "row")) { + SparkR:::writeInt(outputCon, 0L) +} + +close(outputCon) +close(inputCon) diff --git a/R/pkg/src/Makefile b/R/pkg/src/Makefile new file mode 100644 index 0000000000000..a55a56fe80e10 --- /dev/null +++ b/R/pkg/src/Makefile @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +all: sharelib + +sharelib: string_hash_code.c + R CMD SHLIB -o SparkR.so string_hash_code.c + +clean: + rm -f *.o + rm -f *.so + +.PHONY: all clean diff --git a/R/pkg/src/Makefile.win b/R/pkg/src/Makefile.win new file mode 100644 index 0000000000000..aa486d8228371 --- /dev/null +++ b/R/pkg/src/Makefile.win @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +all: sharelib + +sharelib: string_hash_code.c + R CMD SHLIB -o SparkR.dll string_hash_code.c + +clean: + rm -f *.o + rm -f *.dll + +.PHONY: all clean diff --git a/R/pkg/src/string_hash_code.c b/R/pkg/src/string_hash_code.c new file mode 100644 index 0000000000000..e3274b9a0c547 --- /dev/null +++ b/R/pkg/src/string_hash_code.c @@ -0,0 +1,49 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +/* + * A C function for R extension which implements the Java String hash algorithm. + * Refer to http://en.wikipedia.org/wiki/Java_hashCode%28%29#The_java.lang.String_hash_function + * + */ + +#include +#include + +/* for compatibility with R before 3.1 */ +#ifndef IS_SCALAR +#define IS_SCALAR(x, type) (TYPEOF(x) == (type) && XLENGTH(x) == 1) +#endif + +SEXP stringHashCode(SEXP string) { + const char* str; + R_xlen_t len, i; + int hashCode = 0; + + if (!IS_SCALAR(string, STRSXP)) { + error("invalid input"); + } + + str = CHAR(asChar(string)); + len = XLENGTH(asChar(string)); + + for (i = 0; i < len; i++) { + hashCode = (hashCode << 5) - hashCode + *str++; + } + + return ScalarInteger(hashCode); +} diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R new file mode 100644 index 0000000000000..4f8a1ed2d83ef --- /dev/null +++ b/R/pkg/tests/run-all.R @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) +library(SparkR) + +test_package("SparkR") diff --git a/R/run-tests.sh b/R/run-tests.sh new file mode 100755 index 0000000000000..e82ad0ba2cd06 --- /dev/null +++ b/R/run-tests.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +FWDIR="$(cd `dirname $0`; pwd)" + +FAILED=0 +LOGFILE=$FWDIR/unit-tests.out +rm -f $LOGFILE + +SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +FAILED=$((PIPESTATUS[0]||$FAILED)) + +if [[ $FAILED != 0 ]]; then + cat $LOGFILE + echo -en "\033[31m" # Red + echo "Had test failures; see logs." + echo -en "\033[0m" # No color + exit -1 +else + echo -en "\033[32m" # Green + echo "Tests passed." + echo -en "\033[0m" # No color +fi diff --git a/bagel/src/test/resources/log4j.properties b/bagel/src/test/resources/log4j.properties index 853ef0ed2986f..edbecdae92096 100644 --- a/bagel/src/test/resources/log4j.properties +++ b/bagel/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index 2d7070c25d328..95779e9ddbb18 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -20,6 +20,7 @@ # This script loads spark-env.sh if it exists, and ensures it is only loaded once. # spark-env.sh is loaded from SPARK_CONF_DIR if set, or within the current directory's # conf/ subdirectory. +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" if [ -z "$SPARK_ENV_LOADED" ]; then export SPARK_ENV_LOADED=1 @@ -41,8 +42,8 @@ fi if [ -z "$SPARK_SCALA_VERSION" ]; then - ASSEMBLY_DIR2="$SPARK_HOME/assembly/target/scala-2.11" - ASSEMBLY_DIR1="$SPARK_HOME/assembly/target/scala-2.10" + ASSEMBLY_DIR2="$FWDIR/assembly/target/scala-2.11" + ASSEMBLY_DIR1="$FWDIR/assembly/target/scala-2.10" if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then echo -e "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." 1>&2 diff --git a/bin/sparkR b/bin/sparkR new file mode 100755 index 0000000000000..8c918e2b09aef --- /dev/null +++ b/bin/sparkR @@ -0,0 +1,39 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Figure out where Spark is installed +export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + +source "$SPARK_HOME"/bin/load-spark-env.sh + +function usage() { + if [ -n "$1" ]; then + echo $1 + fi + echo "Usage: ./bin/sparkR [options]" 1>&2 + "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit $2 +} +export -f usage + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + usage +fi + +exec "$SPARK_HOME"/bin/spark-submit sparkr-shell-main "$@" diff --git a/bin/sparkR.cmd b/bin/sparkR.cmd new file mode 100644 index 0000000000000..d7b60183ca8e0 --- /dev/null +++ b/bin/sparkR.cmd @@ -0,0 +1,23 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem This is the entry point for running SparkR. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. + +cmd /V /E /C %~dp0sparkR2.cmd %* diff --git a/bin/sparkR2.cmd b/bin/sparkR2.cmd new file mode 100644 index 0000000000000..e47f22c7300bb --- /dev/null +++ b/bin/sparkR2.cmd @@ -0,0 +1,26 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem Figure out where the Spark framework is installed +set SPARK_HOME=%~dp0.. + +call %SPARK_HOME%\bin\load-spark-env.cmd + + +call %SPARK_HOME%\bin\spark-submit2.cmd sparkr-shell-main %* diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index 89eec7d4b7f61..3a2a88219818f 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -6,7 +6,7 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n # Settings to quiet third party logs that are too verbose -log4j.logger.org.eclipse.jetty=WARN -log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/core/pom.xml b/core/pom.xml index 6cd1965ec37c2..e80829b7a7f3d 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -442,4 +442,55 @@ + + + Windows + + + Windows + + + + \ + .bat + + + + unix + + + unix + + + + / + .sh + + + + sparkr + + + + org.codehaus.mojo + exec-maven-plugin + 1.3.2 + + + sparkr-pkg + compile + + exec + + + + + ..${path.separator}R${path.separator}install-dev${script.extension} + + + + + + + diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index 89eec7d4b7f61..3a2a88219818f 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -6,7 +6,7 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n # Settings to quiet third party logs that are too verbose -log4j.logger.org.eclipse.jetty=WARN -log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 9385f557c4614..4e7bf51fc0622 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -80,16 +80,16 @@ private[spark] class ExecutorAllocationManager( Integer.MAX_VALUE) // How long there must be backlogged tasks for before an addition is triggered (seconds) - private val schedulerBacklogTimeout = conf.getLong( - "spark.dynamicAllocation.schedulerBacklogTimeout", 5) + private val schedulerBacklogTimeoutS = conf.getTimeAsSeconds( + "spark.dynamicAllocation.schedulerBacklogTimeout", "5s") - // Same as above, but used only after `schedulerBacklogTimeout` is exceeded - private val sustainedSchedulerBacklogTimeout = conf.getLong( - "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", schedulerBacklogTimeout) + // Same as above, but used only after `schedulerBacklogTimeoutS` is exceeded + private val sustainedSchedulerBacklogTimeoutS = conf.getTimeAsSeconds( + "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", s"${schedulerBacklogTimeoutS}s") // How long an executor must be idle for before it is removed (seconds) - private val executorIdleTimeout = conf.getLong( - "spark.dynamicAllocation.executorIdleTimeout", 600) + private val executorIdleTimeoutS = conf.getTimeAsSeconds( + "spark.dynamicAllocation.executorIdleTimeout", "600s") // During testing, the methods to actually kill and add executors are mocked out private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) @@ -150,14 +150,14 @@ private[spark] class ExecutorAllocationManager( throw new SparkException(s"spark.dynamicAllocation.minExecutors ($minNumExecutors) must " + s"be less than or equal to spark.dynamicAllocation.maxExecutors ($maxNumExecutors)!") } - if (schedulerBacklogTimeout <= 0) { + if (schedulerBacklogTimeoutS <= 0) { throw new SparkException("spark.dynamicAllocation.schedulerBacklogTimeout must be > 0!") } - if (sustainedSchedulerBacklogTimeout <= 0) { + if (sustainedSchedulerBacklogTimeoutS <= 0) { throw new SparkException( "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout must be > 0!") } - if (executorIdleTimeout <= 0) { + if (executorIdleTimeoutS <= 0) { throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be > 0!") } // Require external shuffle service for dynamic allocation @@ -262,8 +262,8 @@ private[spark] class ExecutorAllocationManager( } else if (addTime != NOT_SET && now >= addTime) { val delta = addExecutors(maxNeeded) logDebug(s"Starting timer to add more executors (to " + - s"expire in $sustainedSchedulerBacklogTimeout seconds)") - addTime += sustainedSchedulerBacklogTimeout * 1000 + s"expire in $sustainedSchedulerBacklogTimeoutS seconds)") + addTime += sustainedSchedulerBacklogTimeoutS * 1000 delta } else { 0 @@ -351,7 +351,7 @@ private[spark] class ExecutorAllocationManager( val removeRequestAcknowledged = testing || client.killExecutor(executorId) if (removeRequestAcknowledged) { logInfo(s"Removing executor $executorId because it has been idle for " + - s"$executorIdleTimeout seconds (new desired total will be ${numExistingExecutors - 1})") + s"$executorIdleTimeoutS seconds (new desired total will be ${numExistingExecutors - 1})") executorsPendingToRemove.add(executorId) true } else { @@ -407,8 +407,8 @@ private[spark] class ExecutorAllocationManager( private def onSchedulerBacklogged(): Unit = synchronized { if (addTime == NOT_SET) { logDebug(s"Starting timer to add executors because pending tasks " + - s"are building up (to expire in $schedulerBacklogTimeout seconds)") - addTime = clock.getTimeMillis + schedulerBacklogTimeout * 1000 + s"are building up (to expire in $schedulerBacklogTimeoutS seconds)") + addTime = clock.getTimeMillis + schedulerBacklogTimeoutS * 1000 } } @@ -431,8 +431,8 @@ private[spark] class ExecutorAllocationManager( if (executorIds.contains(executorId)) { if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) { logDebug(s"Starting idle timer for $executorId because there are no more tasks " + - s"scheduled to run on the executor (to expire in $executorIdleTimeout seconds)") - removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeout * 1000 + s"scheduled to run on the executor (to expire in $executorIdleTimeoutS seconds)") + removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeoutS * 1000 } } else { logWarning(s"Attempted to mark unknown executor $executorId idle") diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 5871b8c869f03..e3bd16f1cbf24 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -62,14 +62,17 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) // "spark.network.timeout" uses "seconds", while `spark.storage.blockManagerSlaveTimeoutMs` uses // "milliseconds" - private val executorTimeoutMs = sc.conf.getOption("spark.network.timeout").map(_.toLong * 1000). - getOrElse(sc.conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", 120000)) - + private val slaveTimeoutMs = + sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", "120s") + private val executorTimeoutMs = + sc.conf.getTimeAsSeconds("spark.network.timeout", s"${slaveTimeoutMs}ms") * 1000 + // "spark.network.timeoutInterval" uses "seconds", while // "spark.storage.blockManagerTimeoutIntervalMs" uses "milliseconds" - private val checkTimeoutIntervalMs = - sc.conf.getOption("spark.network.timeoutInterval").map(_.toLong * 1000). - getOrElse(sc.conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000)) + private val timeoutIntervalMs = + sc.conf.getTimeAsMs("spark.storage.blockManagerTimeoutIntervalMs", "60s") + private val checkTimeoutIntervalMs = + sc.conf.getTimeAsSeconds("spark.network.timeoutInterval", s"${timeoutIntervalMs}ms") * 1000 private var timeoutCheckingTask: ScheduledFuture[_] = null diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index 09a9ccc226721..8de3a6c04df34 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -160,7 +160,7 @@ private[spark] class HttpServer( throw new ServerStateException("Server is not started") } else { val scheme = if (securityManager.fileServerSSLOptions.enabled) "https" else "http" - s"$scheme://${Utils.localIpAddress}:$port" + s"$scheme://${Utils.localHostNameForURI()}:$port" } } } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 0c123c96b8d7b..390e631647bd6 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -174,6 +174,42 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { getOption(key).getOrElse(defaultValue) } + /** + * Get a time parameter as seconds; throws a NoSuchElementException if it's not set. If no + * suffix is provided then seconds are assumed. + * @throws NoSuchElementException + */ + def getTimeAsSeconds(key: String): Long = { + Utils.timeStringAsSeconds(get(key)) + } + + /** + * Get a time parameter as seconds, falling back to a default if not set. If no + * suffix is provided then seconds are assumed. + * + */ + def getTimeAsSeconds(key: String, defaultValue: String): Long = { + Utils.timeStringAsSeconds(get(key, defaultValue)) + } + + /** + * Get a time parameter as milliseconds; throws a NoSuchElementException if it's not set. If no + * suffix is provided then milliseconds are assumed. + * @throws NoSuchElementException + */ + def getTimeAsMs(key: String): Long = { + Utils.timeStringAsMs(get(key)) + } + + /** + * Get a time parameter as milliseconds, falling back to a default if not set. If no + * suffix is provided then milliseconds are assumed. + */ + def getTimeAsMs(key: String, defaultValue: String): Long = { + Utils.timeStringAsMs(get(key, defaultValue)) + } + + /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { Option(settings.get(key)) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala new file mode 100644 index 0000000000000..3a2c94bd9d875 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{DataOutputStream, File, FileOutputStream, IOException} +import java.net.{InetSocketAddress, ServerSocket} +import java.util.concurrent.TimeUnit + +import io.netty.bootstrap.ServerBootstrap +import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup} +import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.socket.SocketChannel +import io.netty.channel.socket.nio.NioServerSocketChannel +import io.netty.handler.codec.LengthFieldBasedFrameDecoder +import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder} + +import org.apache.spark.Logging + +/** + * Netty-based backend server that is used to communicate between R and Java. + */ +private[spark] class RBackend { + + private[this] var channelFuture: ChannelFuture = null + private[this] var bootstrap: ServerBootstrap = null + private[this] var bossGroup: EventLoopGroup = null + + def init(): Int = { + bossGroup = new NioEventLoopGroup(2) + val workerGroup = bossGroup + val handler = new RBackendHandler(this) + + bootstrap = new ServerBootstrap() + .group(bossGroup, workerGroup) + .channel(classOf[NioServerSocketChannel]) + + bootstrap.childHandler(new ChannelInitializer[SocketChannel]() { + def initChannel(ch: SocketChannel): Unit = { + ch.pipeline() + .addLast("encoder", new ByteArrayEncoder()) + .addLast("frameDecoder", + // maxFrameLength = 2G + // lengthFieldOffset = 0 + // lengthFieldLength = 4 + // lengthAdjustment = 0 + // initialBytesToStrip = 4, i.e. strip out the length field itself + new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) + .addLast("decoder", new ByteArrayDecoder()) + .addLast("handler", handler) + } + }) + + channelFuture = bootstrap.bind(new InetSocketAddress(0)) + channelFuture.syncUninterruptibly() + channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort() + } + + def run(): Unit = { + channelFuture.channel.closeFuture().syncUninterruptibly() + } + + def close(): Unit = { + if (channelFuture != null) { + // close is a local operation and should finish within milliseconds; timeout just to be safe + channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS) + channelFuture = null + } + if (bootstrap != null && bootstrap.group() != null) { + bootstrap.group().shutdownGracefully() + } + if (bootstrap != null && bootstrap.childGroup() != null) { + bootstrap.childGroup().shutdownGracefully() + } + bootstrap = null + } + +} + +private[spark] object RBackend extends Logging { + def main(args: Array[String]): Unit = { + if (args.length < 1) { + System.err.println("Usage: RBackend ") + System.exit(-1) + } + val sparkRBackend = new RBackend() + try { + // bind to random port + val boundPort = sparkRBackend.init() + val serverSocket = new ServerSocket(0, 1) + val listenPort = serverSocket.getLocalPort() + + // tell the R process via temporary file + val path = args(0) + val f = new File(path + ".tmp") + val dos = new DataOutputStream(new FileOutputStream(f)) + dos.writeInt(boundPort) + dos.writeInt(listenPort) + dos.close() + f.renameTo(new File(path)) + + // wait for the end of stdin, then exit + new Thread("wait for socket to close") { + setDaemon(true) + override def run(): Unit = { + // any un-catched exception will also shutdown JVM + val buf = new Array[Byte](1024) + // shutdown JVM if R does not connect back in 10 seconds + serverSocket.setSoTimeout(10000) + try { + val inSocket = serverSocket.accept() + serverSocket.close() + // wait for the end of socket, closed if R process die + inSocket.getInputStream().read(buf) + } finally { + sparkRBackend.close() + System.exit(0) + } + } + }.start() + + sparkRBackend.run() + } catch { + case e: IOException => + logError("Server shutting down: failed with exception ", e) + sparkRBackend.close() + System.exit(1) + } + System.exit(0) + } +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala new file mode 100644 index 0000000000000..0075d963711f1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +import scala.collection.mutable.HashMap + +import io.netty.channel.ChannelHandler.Sharable +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.Logging +import org.apache.spark.api.r.SerDe._ + +/** + * Handler for RBackend + * TODO: This is marked as sharable to get a handle to RBackend. Is it safe to re-use + * this across connections ? + */ +@Sharable +private[r] class RBackendHandler(server: RBackend) + extends SimpleChannelInboundHandler[Array[Byte]] with Logging { + + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { + val bis = new ByteArrayInputStream(msg) + val dis = new DataInputStream(bis) + + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + // First bit is isStatic + val isStatic = readBoolean(dis) + val objId = readString(dis) + val methodName = readString(dis) + val numArgs = readInt(dis) + + if (objId == "SparkRHandler") { + methodName match { + case "stopBackend" => + writeInt(dos, 0) + writeType(dos, "void") + server.close() + case "rm" => + try { + val t = readObjectType(dis) + assert(t == 'c') + val objToRemove = readString(dis) + JVMObjectTracker.remove(objToRemove) + writeInt(dos, 0) + writeObject(dos, null) + } catch { + case e: Exception => + logError(s"Removing $objId failed", e) + writeInt(dos, -1) + } + case _ => dos.writeInt(-1) + } + } else { + handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) + } + + val reply = bos.toByteArray + ctx.write(reply) + } + + override def channelReadComplete(ctx: ChannelHandlerContext): Unit = { + ctx.flush() + } + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + // Close the connection when an exception is raised. + cause.printStackTrace() + ctx.close() + } + + def handleMethodCall( + isStatic: Boolean, + objId: String, + methodName: String, + numArgs: Int, + dis: DataInputStream, + dos: DataOutputStream): Unit = { + var obj: Object = null + try { + val cls = if (isStatic) { + Class.forName(objId) + } else { + JVMObjectTracker.get(objId) match { + case None => throw new IllegalArgumentException("Object not found " + objId) + case Some(o) => + obj = o + o.getClass + } + } + + val args = readArgs(numArgs, dis) + + val methods = cls.getMethods + val selectedMethods = methods.filter(m => m.getName == methodName) + if (selectedMethods.length > 0) { + val methods = selectedMethods.filter { x => + matchMethod(numArgs, args, x.getParameterTypes) + } + if (methods.isEmpty) { + logWarning(s"cannot find matching method ${cls}.$methodName. " + + s"Candidates are:") + selectedMethods.foreach { method => + logWarning(s"$methodName(${method.getParameterTypes.mkString(",")})") + } + throw new Exception(s"No matched method found for $cls.$methodName") + } + val ret = methods.head.invoke(obj, args:_*) + + // Write status bit + writeInt(dos, 0) + writeObject(dos, ret.asInstanceOf[AnyRef]) + } else if (methodName == "") { + // methodName should be "" for constructor + val ctor = cls.getConstructors.filter { x => + matchMethod(numArgs, args, x.getParameterTypes) + }.head + + val obj = ctor.newInstance(args:_*) + + writeInt(dos, 0) + writeObject(dos, obj.asInstanceOf[AnyRef]) + } else { + throw new IllegalArgumentException("invalid method " + methodName + " for object " + objId) + } + } catch { + case e: Exception => + logError(s"$methodName on $objId failed", e) + writeInt(dos, -1) + } + } + + // Read a number of arguments from the data input stream + def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { + (0 until numArgs).map { arg => + readObject(dis) + }.toArray + } + + // Checks if the arguments passed in args matches the parameter types. + // NOTE: Currently we do exact match. We may add type conversions later. + def matchMethod( + numArgs: Int, + args: Array[java.lang.Object], + parameterTypes: Array[Class[_]]): Boolean = { + if (parameterTypes.length != numArgs) { + return false + } + + for (i <- 0 to numArgs - 1) { + val parameterType = parameterTypes(i) + var parameterWrapperType = parameterType + + // Convert native parameters to Object types as args is Array[Object] here + if (parameterType.isPrimitive) { + parameterWrapperType = parameterType match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => parameterType + } + } + if (!parameterWrapperType.isInstance(args(i))) { + return false + } + } + true + } +} + +/** + * Helper singleton that tracks JVM objects returned to R. + * This is useful for referencing these objects in RPC calls. + */ +private[r] object JVMObjectTracker { + + // TODO: This map should be thread-safe if we want to support multiple + // connections at the same time + private[this] val objMap = new HashMap[String, Object] + + // TODO: We support only one connection now, so an integer is fine. + // Investigate using use atomic integer in the future. + private[this] var objCounter: Int = 0 + + def getObject(id: String): Object = { + objMap(id) + } + + def get(id: String): Option[Object] = { + objMap.get(id) + } + + def put(obj: Object): String = { + val objId = objCounter.toString + objCounter = objCounter + 1 + objMap.put(objId, obj) + objId + } + + def remove(id: String): Option[Object] = { + objMap.remove(id) + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala new file mode 100644 index 0000000000000..5fa4d483b8342 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -0,0 +1,450 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io._ +import java.net.ServerSocket +import java.util.{Map => JMap} + +import scala.collection.JavaConversions._ +import scala.io.Source +import scala.reflect.ClassTag +import scala.util.Try + +import org.apache.spark._ +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils + +private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( + parent: RDD[T], + numPartitions: Int, + func: Array[Byte], + deserializer: String, + serializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Broadcast[Object]]) + extends RDD[U](parent) with Logging { + override def getPartitions: Array[Partition] = parent.partitions + + override def compute(partition: Partition, context: TaskContext): Iterator[U] = { + + // The parent may be also an RRDD, so we should launch it first. + val parentIterator = firstParent[T].iterator(partition, context) + + // we expect two connections + val serverSocket = new ServerSocket(0, 2) + val listenPort = serverSocket.getLocalPort() + + // The stdout/stderr is shared by multiple tasks, because we use one daemon + // to launch child process as worker. + val errThread = RRDD.createRWorker(rLibDir, listenPort) + + // We use two sockets to separate input and output, then it's easy to manage + // the lifecycle of them to avoid deadlock. + // TODO: optimize it to use one socket + + // the socket used to send out the input of task + serverSocket.setSoTimeout(10000) + val inSocket = serverSocket.accept() + startStdinThread(inSocket.getOutputStream(), parentIterator, partition.index) + + // the socket used to receive the output of task + val outSocket = serverSocket.accept() + val inputStream = new BufferedInputStream(outSocket.getInputStream) + val dataStream = openDataStream(inputStream) + serverSocket.close() + + try { + + return new Iterator[U] { + def next(): U = { + val obj = _nextObj + if (hasNext) { + _nextObj = read() + } + obj + } + + var _nextObj = read() + + def hasNext(): Boolean = { + val hasMore = (_nextObj != null) + if (!hasMore) { + dataStream.close() + } + hasMore + } + } + } catch { + case e: Exception => + throw new SparkException("R computation failed with\n " + errThread.getLines()) + } + } + + /** + * Start a thread to write RDD data to the R process. + */ + private def startStdinThread[T]( + output: OutputStream, + iter: Iterator[T], + partition: Int): Unit = { + + val env = SparkEnv.get + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val stream = new BufferedOutputStream(output, bufferSize) + + new Thread("writer for R") { + override def run(): Unit = { + try { + SparkEnv.set(env) + val dataOut = new DataOutputStream(stream) + dataOut.writeInt(partition) + + SerDe.writeString(dataOut, deserializer) + SerDe.writeString(dataOut, serializer) + + dataOut.writeInt(packageNames.length) + dataOut.write(packageNames) + + dataOut.writeInt(func.length) + dataOut.write(func) + + dataOut.writeInt(broadcastVars.length) + broadcastVars.foreach { broadcast => + // TODO(shivaram): Read a Long in R to avoid this cast + dataOut.writeInt(broadcast.id.toInt) + // TODO: Pass a byte array from R to avoid this cast ? + val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]] + dataOut.writeInt(broadcastByteArr.length) + dataOut.write(broadcastByteArr) + } + + dataOut.writeInt(numPartitions) + + if (!iter.hasNext) { + dataOut.writeInt(0) + } else { + dataOut.writeInt(1) + } + + val printOut = new PrintStream(stream) + + def writeElem(elem: Any): Unit = { + if (deserializer == SerializationFormats.BYTE) { + val elemArr = elem.asInstanceOf[Array[Byte]] + dataOut.writeInt(elemArr.length) + dataOut.write(elemArr) + } else if (deserializer == SerializationFormats.ROW) { + dataOut.write(elem.asInstanceOf[Array[Byte]]) + } else if (deserializer == SerializationFormats.STRING) { + printOut.println(elem) + } + } + + for (elem <- iter) { + elem match { + case (key, value) => + writeElem(key) + writeElem(value) + case _ => + writeElem(elem) + } + } + stream.flush() + } catch { + // TODO: We should propogate this error to the task thread + case e: Exception => + logError("R Writer thread got an exception", e) + } finally { + Try(output.close()) + } + } + }.start() + } + + protected def openDataStream(input: InputStream): Closeable + + protected def read(): U +} + +/** + * Form an RDD[(Int, Array[Byte])] from key-value pairs returned from R. + * This is used by SparkR's shuffle operations. + */ +private class PairwiseRRDD[T: ClassTag]( + parent: RDD[T], + numPartitions: Int, + hashFunc: Array[Byte], + deserializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Object]) + extends BaseRRDD[T, (Int, Array[Byte])]( + parent, numPartitions, hashFunc, deserializer, + SerializationFormats.BYTE, packageNames, rLibDir, + broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + private var dataStream: DataInputStream = _ + + override protected def openDataStream(input: InputStream): Closeable = { + dataStream = new DataInputStream(input) + dataStream + } + + override protected def read(): (Int, Array[Byte]) = { + try { + val length = dataStream.readInt() + + length match { + case length if length == 2 => + val hashedKey = dataStream.readInt() + val contentPairsLength = dataStream.readInt() + val contentPairs = new Array[Byte](contentPairsLength) + dataStream.readFully(contentPairs) + (hashedKey, contentPairs) + case _ => null // End of input + } + } catch { + case eof: EOFException => { + throw new SparkException("R worker exited unexpectedly (crashed)", eof) + } + } + } + + lazy val asJavaPairRDD : JavaPairRDD[Int, Array[Byte]] = JavaPairRDD.fromRDD(this) +} + +/** + * An RDD that stores serialized R objects as Array[Byte]. + */ +private class RRDD[T: ClassTag]( + parent: RDD[T], + func: Array[Byte], + deserializer: String, + serializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Object]) + extends BaseRRDD[T, Array[Byte]]( + parent, -1, func, deserializer, serializer, packageNames, rLibDir, + broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + private var dataStream: DataInputStream = _ + + override protected def openDataStream(input: InputStream): Closeable = { + dataStream = new DataInputStream(input) + dataStream + } + + override protected def read(): Array[Byte] = { + try { + val length = dataStream.readInt() + + length match { + case length if length > 0 => + val obj = new Array[Byte](length) + dataStream.readFully(obj, 0, length) + obj + case _ => null + } + } catch { + case eof: EOFException => { + throw new SparkException("R worker exited unexpectedly (crashed)", eof) + } + } + } + + lazy val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) +} + +/** + * An RDD that stores R objects as Array[String]. + */ +private class StringRRDD[T: ClassTag]( + parent: RDD[T], + func: Array[Byte], + deserializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Object]) + extends BaseRRDD[T, String]( + parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, rLibDir, + broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + private var dataStream: BufferedReader = _ + + override protected def openDataStream(input: InputStream): Closeable = { + dataStream = new BufferedReader(new InputStreamReader(input)) + dataStream + } + + override protected def read(): String = { + try { + dataStream.readLine() + } catch { + case e: IOException => { + throw new SparkException("R worker exited unexpectedly (crashed)", e) + } + } + } + + lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this) +} + +private[r] class BufferedStreamThread( + in: InputStream, + name: String, + errBufferSize: Int) extends Thread(name) with Logging { + val lines = new Array[String](errBufferSize) + var lineIdx = 0 + override def run() { + for (line <- Source.fromInputStream(in).getLines) { + synchronized { + lines(lineIdx) = line + lineIdx = (lineIdx + 1) % errBufferSize + } + logInfo(line) + } + } + + def getLines(): String = synchronized { + (0 until errBufferSize).filter { x => + lines((x + lineIdx) % errBufferSize) != null + }.map { x => + lines((x + lineIdx) % errBufferSize) + }.mkString("\n") + } +} + +private[r] object RRDD { + // Because forking processes from Java is expensive, we prefer to launch + // a single R daemon (daemon.R) and tell it to fork new workers for our tasks. + // This daemon currently only works on UNIX-based systems now, so we should + // also fall back to launching workers (worker.R) directly. + private[this] var errThread: BufferedStreamThread = _ + private[this] var daemonChannel: DataOutputStream = _ + + def createSparkContext( + master: String, + appName: String, + sparkHome: String, + jars: Array[String], + sparkEnvirMap: JMap[Object, Object], + sparkExecutorEnvMap: JMap[Object, Object]): JavaSparkContext = { + + val sparkConf = new SparkConf().setAppName(appName) + .setSparkHome(sparkHome) + .setJars(jars) + + // Override `master` if we have a user-specified value + if (master != "") { + sparkConf.setMaster(master) + } else { + // If conf has no master set it to "local" to maintain + // backwards compatibility + sparkConf.setIfMissing("spark.master", "local") + } + + for ((name, value) <- sparkEnvirMap) { + sparkConf.set(name.asInstanceOf[String], value.asInstanceOf[String]) + } + for ((name, value) <- sparkExecutorEnvMap) { + sparkConf.setExecutorEnv(name.asInstanceOf[String], value.asInstanceOf[String]) + } + + new JavaSparkContext(sparkConf) + } + + /** + * Start a thread to print the process's stderr to ours + */ + private def startStdoutThread(proc: Process): BufferedStreamThread = { + val BUFFER_SIZE = 100 + val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE) + thread.setDaemon(true) + thread.start() + thread + } + + private def createRProcess(rLibDir: String, port: Int, script: String): BufferedStreamThread = { + val rCommand = "Rscript" + val rOptions = "--vanilla" + val rExecScript = rLibDir + "/SparkR/worker/" + script + val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript)) + // Unset the R_TESTS environment variable for workers. + // This is set by R CMD check as startup.Rs + // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R) + // and confuses worker script which tries to load a non-existent file + pb.environment().put("R_TESTS", "") + pb.environment().put("SPARKR_RLIBDIR", rLibDir) + pb.environment().put("SPARKR_WORKER_PORT", port.toString) + pb.redirectErrorStream(true) // redirect stderr into stdout + val proc = pb.start() + val errThread = startStdoutThread(proc) + errThread + } + + /** + * ProcessBuilder used to launch worker R processes. + */ + def createRWorker(rLibDir: String, port: Int): BufferedStreamThread = { + val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true) + if (!Utils.isWindows && useDaemon) { + synchronized { + if (daemonChannel == null) { + // we expect one connections + val serverSocket = new ServerSocket(0, 1) + val daemonPort = serverSocket.getLocalPort + errThread = createRProcess(rLibDir, daemonPort, "daemon.R") + // the socket used to send out the input of task + serverSocket.setSoTimeout(10000) + val sock = serverSocket.accept() + daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + serverSocket.close() + } + try { + daemonChannel.writeInt(port) + daemonChannel.flush() + } catch { + case e: IOException => + // daemon process died + daemonChannel.close() + daemonChannel = null + errThread = null + // fail the current task, retry by scheduler + throw e + } + errThread + } + } else { + createRProcess(rLibDir, port, "worker.R") + } + } + + /** + * Create an RRDD given a sequence of byte arrays. Used to create RRDD when `parallelize` is + * called from R. + */ + def createRDDFromArray(jsc: JavaSparkContext, arr: Array[Array[Byte]]): JavaRDD[Array[Byte]] = { + JavaRDD.fromRDD(jsc.sc.parallelize(arr, arr.length)) + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala new file mode 100644 index 0000000000000..ccb2a371f4e48 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -0,0 +1,340 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{DataInputStream, DataOutputStream} +import java.sql.{Date, Time} + +import scala.collection.JavaConversions._ + +/** + * Utility functions to serialize, deserialize objects to / from R + */ +private[spark] object SerDe { + + // Type mapping from R to Java + // + // NULL -> void + // integer -> Int + // character -> String + // logical -> Boolean + // double, numeric -> Double + // raw -> Array[Byte] + // Date -> Date + // POSIXlt/POSIXct -> Time + // + // list[T] -> Array[T], where T is one of above mentioned types + // environment -> Map[String, T], where T is a native type + // jobj -> Object, where jobj is an object created in the backend + + def readObjectType(dis: DataInputStream): Char = { + dis.readByte().toChar + } + + def readObject(dis: DataInputStream): Object = { + val dataType = readObjectType(dis) + readTypedObject(dis, dataType) + } + + def readTypedObject( + dis: DataInputStream, + dataType: Char): Object = { + dataType match { + case 'n' => null + case 'i' => new java.lang.Integer(readInt(dis)) + case 'd' => new java.lang.Double(readDouble(dis)) + case 'b' => new java.lang.Boolean(readBoolean(dis)) + case 'c' => readString(dis) + case 'e' => readMap(dis) + case 'r' => readBytes(dis) + case 'l' => readList(dis) + case 'D' => readDate(dis) + case 't' => readTime(dis) + case 'j' => JVMObjectTracker.getObject(readString(dis)) + case _ => throw new IllegalArgumentException(s"Invalid type $dataType") + } + } + + def readBytes(in: DataInputStream): Array[Byte] = { + val len = readInt(in) + val out = new Array[Byte](len) + val bytesRead = in.readFully(out) + out + } + + def readInt(in: DataInputStream): Int = { + in.readInt() + } + + def readDouble(in: DataInputStream): Double = { + in.readDouble() + } + + def readString(in: DataInputStream): String = { + val len = in.readInt() + val asciiBytes = new Array[Byte](len) + in.readFully(asciiBytes) + assert(asciiBytes(len - 1) == 0) + val str = new String(asciiBytes.dropRight(1).map(_.toChar)) + str + } + + def readBoolean(in: DataInputStream): Boolean = { + val intVal = in.readInt() + if (intVal == 0) false else true + } + + def readDate(in: DataInputStream): Date = { + Date.valueOf(readString(in)) + } + + def readTime(in: DataInputStream): Time = { + val t = in.readDouble() + new Time((t * 1000L).toLong) + } + + def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { + val len = readInt(in) + (0 until len).map(_ => readBytes(in)).toArray + } + + def readIntArr(in: DataInputStream): Array[Int] = { + val len = readInt(in) + (0 until len).map(_ => readInt(in)).toArray + } + + def readDoubleArr(in: DataInputStream): Array[Double] = { + val len = readInt(in) + (0 until len).map(_ => readDouble(in)).toArray + } + + def readBooleanArr(in: DataInputStream): Array[Boolean] = { + val len = readInt(in) + (0 until len).map(_ => readBoolean(in)).toArray + } + + def readStringArr(in: DataInputStream): Array[String] = { + val len = readInt(in) + (0 until len).map(_ => readString(in)).toArray + } + + def readList(dis: DataInputStream): Array[_] = { + val arrType = readObjectType(dis) + arrType match { + case 'i' => readIntArr(dis) + case 'c' => readStringArr(dis) + case 'd' => readDoubleArr(dis) + case 'b' => readBooleanArr(dis) + case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x)) + case 'r' => readBytesArr(dis) + case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") + } + } + + def readMap(in: DataInputStream): java.util.Map[Object, Object] = { + val len = readInt(in) + if (len > 0) { + val keysType = readObjectType(in) + val keysLen = readInt(in) + val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType)) + + val valuesType = readObjectType(in) + val valuesLen = readInt(in) + val values = (0 until valuesLen).map(_ => readTypedObject(in, valuesType)) + mapAsJavaMap(keys.zip(values).toMap) + } else { + new java.util.HashMap[Object, Object]() + } + } + + // Methods to write out data from Java to R + // + // Type mapping from Java to R + // + // void -> NULL + // Int -> integer + // String -> character + // Boolean -> logical + // Double -> double + // Long -> double + // Array[Byte] -> raw + // Date -> Date + // Time -> POSIXct + // + // Array[T] -> list() + // Object -> jobj + + def writeType(dos: DataOutputStream, typeStr: String): Unit = { + typeStr match { + case "void" => dos.writeByte('n') + case "character" => dos.writeByte('c') + case "double" => dos.writeByte('d') + case "integer" => dos.writeByte('i') + case "logical" => dos.writeByte('b') + case "date" => dos.writeByte('D') + case "time" => dos.writeByte('t') + case "raw" => dos.writeByte('r') + case "list" => dos.writeByte('l') + case "jobj" => dos.writeByte('j') + case _ => throw new IllegalArgumentException(s"Invalid type $typeStr") + } + } + + def writeObject(dos: DataOutputStream, value: Object): Unit = { + if (value == null) { + writeType(dos, "void") + } else { + value.getClass.getName match { + case "java.lang.String" => + writeType(dos, "character") + writeString(dos, value.asInstanceOf[String]) + case "long" | "java.lang.Long" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Long].toDouble) + case "double" | "java.lang.Double" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Double]) + case "int" | "java.lang.Integer" => + writeType(dos, "integer") + writeInt(dos, value.asInstanceOf[Int]) + case "boolean" | "java.lang.Boolean" => + writeType(dos, "logical") + writeBoolean(dos, value.asInstanceOf[Boolean]) + case "java.sql.Date" => + writeType(dos, "date") + writeDate(dos, value.asInstanceOf[Date]) + case "java.sql.Time" => + writeType(dos, "time") + writeTime(dos, value.asInstanceOf[Time]) + case "[B" => + writeType(dos, "raw") + writeBytes(dos, value.asInstanceOf[Array[Byte]]) + // TODO: Types not handled right now include + // byte, char, short, float + + // Handle arrays + case "[Ljava.lang.String;" => + writeType(dos, "list") + writeStringArr(dos, value.asInstanceOf[Array[String]]) + case "[I" => + writeType(dos, "list") + writeIntArr(dos, value.asInstanceOf[Array[Int]]) + case "[J" => + writeType(dos, "list") + writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble)) + case "[D" => + writeType(dos, "list") + writeDoubleArr(dos, value.asInstanceOf[Array[Double]]) + case "[Z" => + writeType(dos, "list") + writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]]) + case "[[B" => + writeType(dos, "list") + writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]]) + case otherName => + // Handle array of objects + if (otherName.startsWith("[L")) { + val objArr = value.asInstanceOf[Array[Object]] + writeType(dos, "list") + writeType(dos, "jobj") + dos.writeInt(objArr.length) + objArr.foreach(o => writeJObj(dos, o)) + } else { + writeType(dos, "jobj") + writeJObj(dos, value) + } + } + } + } + + def writeInt(out: DataOutputStream, value: Int): Unit = { + out.writeInt(value) + } + + def writeDouble(out: DataOutputStream, value: Double): Unit = { + out.writeDouble(value) + } + + def writeBoolean(out: DataOutputStream, value: Boolean): Unit = { + val intValue = if (value) 1 else 0 + out.writeInt(intValue) + } + + def writeDate(out: DataOutputStream, value: Date): Unit = { + writeString(out, value.toString) + } + + def writeTime(out: DataOutputStream, value: Time): Unit = { + out.writeDouble(value.getTime.toDouble / 1000.0) + } + + + // NOTE: Only works for ASCII right now + def writeString(out: DataOutputStream, value: String): Unit = { + val len = value.length + out.writeInt(len + 1) // For the \0 + out.writeBytes(value) + out.writeByte(0) + } + + def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { + out.writeInt(value.length) + out.write(value) + } + + def writeJObj(out: DataOutputStream, value: Object): Unit = { + val objId = JVMObjectTracker.put(value) + writeString(out, objId) + } + + def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { + writeType(out, "integer") + out.writeInt(value.length) + value.foreach(v => out.writeInt(v)) + } + + def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = { + writeType(out, "double") + out.writeInt(value.length) + value.foreach(v => out.writeDouble(v)) + } + + def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = { + writeType(out, "logical") + out.writeInt(value.length) + value.foreach(v => writeBoolean(out, v)) + } + + def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = { + writeType(out, "character") + out.writeInt(value.length) + value.foreach(v => writeString(out, v)) + } + + def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { + writeType(out, "raw") + out.writeInt(value.length) + value.foreach(v => writeBytes(out, v)) + } +} + +private[r] object SerializationFormats { + val BYTE = "byte" + val STRING = "string" + val ROW = "row" +} diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 2c5da826f05df..8b627f5804bb0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -52,7 +52,7 @@ class LocalSparkCluster( /* Start the Master */ val (rpcEnv, _, _) = Master.startRpcEnvAndEndpoint(localHostname, 0, 0, _conf) masterRpcEnvs += rpcEnv - val masterUrl = "spark://" + localHostname + ":" + rpcEnv.address.port + val masterUrl = "spark://" + Utils.localHostNameForURI() + ":" + rpcEnv.address.port val masters = Array(masterUrl) /* Start the Workers */ diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala new file mode 100644 index 0000000000000..e99779f299785 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.io._ +import java.util.concurrent.{Semaphore, TimeUnit} + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.fs.Path + +import org.apache.spark.api.r.RBackend +import org.apache.spark.util.RedirectThread + +/** + * Main class used to launch SparkR applications using spark-submit. It executes R as a + * subprocess and then has it connect back to the JVM to access system properties etc. + */ +object RRunner { + def main(args: Array[String]): Unit = { + val rFile = PythonRunner.formatPath(args(0)) + + val otherArgs = args.slice(1, args.length) + + // Time to wait for SparkR backend to initialize in seconds + val backendTimeout = sys.env.getOrElse("SPARKR_BACKEND_TIMEOUT", "120").toInt + val rCommand = "Rscript" + + // Check if the file path exists. + // If not, change directory to current working directory for YARN cluster mode + val rF = new File(rFile) + val rFileNormalized = if (!rF.exists()) { + new Path(rFile).getName + } else { + rFile + } + + // Launch a SparkR backend server for the R process to connect to; this will let it see our + // Java system properties etc. + val sparkRBackend = new RBackend() + @volatile var sparkRBackendPort = 0 + val initialized = new Semaphore(0) + val sparkRBackendThread = new Thread("SparkR backend") { + override def run() { + sparkRBackendPort = sparkRBackend.init() + initialized.release() + sparkRBackend.run() + } + } + + sparkRBackendThread.start() + // Wait for RBackend initialization to finish + if (initialized.tryAcquire(backendTimeout, TimeUnit.SECONDS)) { + // Launch R + val returnCode = try { + val builder = new ProcessBuilder(Seq(rCommand, rFileNormalized) ++ otherArgs) + val env = builder.environment() + env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString) + val sparkHome = System.getenv("SPARK_HOME") + env.put("R_PROFILE_USER", + Seq(sparkHome, "R", "lib", "SparkR", "profile", "general.R").mkString(File.separator)) + builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize + val process = builder.start() + + new RedirectThread(process.getInputStream, System.out, "redirect R output").start() + + process.waitFor() + } finally { + sparkRBackend.close() + } + System.exit(returnCode) + } else { + System.err.println("SparkR backend did not initialize in " + backendTimeout + " seconds") + System.exit(-1) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index c2568eb4b60ac..cfaebf9ea5050 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -24,11 +24,10 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.fs.FileSystem.Statistics import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} -import org.apache.hadoop.security.Credentials -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.mapreduce.JobContext +import org.apache.hadoop.security.{Credentials, UserGroupInformation} -import org.apache.spark.{Logging, SparkContext, SparkConf, SparkException} +import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.Utils @@ -201,6 +200,37 @@ class SparkHadoopUtil extends Logging { val baseStatus = fs.getFileStatus(basePath) if (baseStatus.isDir) recurse(basePath) else Array(baseStatus) } + + private val HADOOP_CONF_PATTERN = "(\\$\\{hadoopconf-[^\\}\\$\\s]+\\})".r.unanchored + + /** + * Substitute variables by looking them up in Hadoop configs. Only variables that match the + * ${hadoopconf- .. } pattern are substituted. + */ + def substituteHadoopVariables(text: String, hadoopConf: Configuration): String = { + text match { + case HADOOP_CONF_PATTERN(matched) => { + logDebug(text + " matched " + HADOOP_CONF_PATTERN) + val key = matched.substring(13, matched.length() - 1) // remove ${hadoopconf- .. } + val eval = Option[String](hadoopConf.get(key)) + .map { value => + logDebug("Substituted " + matched + " with " + value) + text.replace(matched, value) + } + if (eval.isEmpty) { + // The variable was not found in Hadoop configs, so return text as is. + text + } else { + // Continue to substitute more variables. + substituteHadoopVariables(eval.get, hadoopConf) + } + } + case _ => { + logDebug(text + " didn't match " + HADOOP_CONF_PATTERN) + text + } + } + } } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 660307d19eab4..60bc243ebf40a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -77,6 +77,7 @@ object SparkSubmit { // Special primary resource names that represent shells rather than application jars. private val SPARK_SHELL = "spark-shell" private val PYSPARK_SHELL = "pyspark-shell" + private val SPARKR_SHELL = "sparkr-shell" private val CLASS_NOT_FOUND_EXIT_STATUS = 101 @@ -284,6 +285,13 @@ object SparkSubmit { } } + // Require all R files to be local + if (args.isR && !isYarnCluster) { + if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { + printErrorAndExit(s"Only local R files are supported: $args.primaryResource") + } + } + // The following modes are not supported or applicable (clusterManager, deployMode) match { case (MESOS, CLUSTER) => @@ -291,6 +299,9 @@ object SparkSubmit { case (STANDALONE, CLUSTER) if args.isPython => printErrorAndExit("Cluster deploy mode is currently not supported for python " + "applications on standalone clusters.") + case (STANDALONE, CLUSTER) if args.isR => + printErrorAndExit("Cluster deploy mode is currently not supported for R " + + "applications on standalone clusters.") case (_, CLUSTER) if isShell(args.primaryResource) => printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.") case (_, CLUSTER) if isSqlShell(args.mainClass) => @@ -317,11 +328,32 @@ object SparkSubmit { } } - // In yarn-cluster mode for a python app, add primary resource and pyFiles to files - // that can be distributed with the job - if (args.isPython && isYarnCluster) { - args.files = mergeFileLists(args.files, args.primaryResource) - args.files = mergeFileLists(args.files, args.pyFiles) + // If we're running a R app, set the main class to our specific R runner + if (args.isR && deployMode == CLIENT) { + if (args.primaryResource == SPARKR_SHELL) { + args.mainClass = "org.apache.spark.api.r.RBackend" + } else { + // If a R file is provided, add it to the child arguments and list of files to deploy. + // Usage: RRunner
[app arguments] + args.mainClass = "org.apache.spark.deploy.RRunner" + args.childArgs = ArrayBuffer(args.primaryResource) ++ args.childArgs + args.files = mergeFileLists(args.files, args.primaryResource) + } + } + + if (isYarnCluster) { + // In yarn-cluster mode for a python app, add primary resource and pyFiles to files + // that can be distributed with the job + if (args.isPython) { + args.files = mergeFileLists(args.files, args.primaryResource) + args.files = mergeFileLists(args.files, args.pyFiles) + } + + // In yarn-cluster mode for a R app, add primary resource to files + // that can be distributed with the job + if (args.isR) { + args.files = mergeFileLists(args.files, args.primaryResource) + } } // Special flag to avoid deprecation warnings at the client @@ -405,8 +437,8 @@ object SparkSubmit { // Add the application jar automatically so the user doesn't have to call sc.addJar // For YARN cluster mode, the jar is already distributed on each node as "app.jar" - // For python files, the primary resource is already distributed as a regular file - if (!isYarnCluster && !args.isPython) { + // For python and R files, the primary resource is already distributed as a regular file + if (!isYarnCluster && !args.isPython && !args.isR) { var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty) if (isUserJar(args.primaryResource)) { jars = jars ++ Seq(args.primaryResource) @@ -447,6 +479,10 @@ object SparkSubmit { childArgs += ("--py-files", pyFilesNames) } childArgs += ("--class", "org.apache.spark.deploy.PythonRunner") + } else if (args.isR) { + val mainFile = new Path(args.primaryResource).getName + childArgs += ("--primary-r-file", mainFile) + childArgs += ("--class", "org.apache.spark.deploy.RRunner") } else { if (args.primaryResource != SPARK_INTERNAL) { childArgs += ("--jar", args.primaryResource) @@ -591,15 +627,15 @@ object SparkSubmit { /** * Return whether the given primary resource represents a user jar. */ - private def isUserJar(primaryResource: String): Boolean = { - !isShell(primaryResource) && !isPython(primaryResource) && !isInternal(primaryResource) + private[deploy] def isUserJar(res: String): Boolean = { + !isShell(res) && !isPython(res) && !isInternal(res) && !isR(res) } /** * Return whether the given primary resource represents a shell. */ - private[deploy] def isShell(primaryResource: String): Boolean = { - primaryResource == SPARK_SHELL || primaryResource == PYSPARK_SHELL + private[deploy] def isShell(res: String): Boolean = { + (res == SPARK_SHELL || res == PYSPARK_SHELL || res == SPARKR_SHELL) } /** @@ -619,12 +655,19 @@ object SparkSubmit { /** * Return whether the given primary resource requires running python. */ - private[deploy] def isPython(primaryResource: String): Boolean = { - primaryResource.endsWith(".py") || primaryResource == PYSPARK_SHELL + private[deploy] def isPython(res: String): Boolean = { + res != null && res.endsWith(".py") || res == PYSPARK_SHELL + } + + /** + * Return whether the given primary resource requires running R. + */ + private[deploy] def isR(res: String): Boolean = { + res != null && res.endsWith(".R") || res == SPARKR_SHELL } - private[deploy] def isInternal(primaryResource: String): Boolean = { - primaryResource == SPARK_INTERNAL + private[deploy] def isInternal(res: String): Boolean = { + res == SPARK_INTERNAL } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 6eb73c43470a5..03ecf3fd99ec5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -59,6 +59,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var verbose: Boolean = false var isPython: Boolean = false var pyFiles: String = null + var isR: Boolean = false var action: SparkSubmitAction = null val sparkProperties: HashMap[String, String] = new HashMap[String, String]() var proxyUser: String = null @@ -158,7 +159,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S .getOrElse(sparkProperties.get("spark.executor.instances").orNull) // Try to set main class from JAR if no --class argument is given - if (mainClass == null && !isPython && primaryResource != null) { + if (mainClass == null && !isPython && !isR && primaryResource != null) { val uri = new URI(primaryResource) val uriScheme = uri.getScheme() @@ -211,9 +212,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S printUsageAndExit(-1) } if (primaryResource == null) { - SparkSubmit.printErrorAndExit("Must specify a primary resource (JAR or Python file)") + SparkSubmit.printErrorAndExit("Must specify a primary resource (JAR or Python or R file)") } - if (mainClass == null && !isPython) { + if (mainClass == null && SparkSubmit.isUserJar(primaryResource)) { SparkSubmit.printErrorAndExit("No main class set in JAR; please specify one with --class") } if (pyFiles != null && !isPython) { @@ -414,6 +415,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S opt } isPython = SparkSubmit.isPython(opt) + isR = SparkSubmit.isR(opt) false } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index 9a8d5d5561238..1c79089303e3d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -47,7 +47,7 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) val conf = new SparkConf - val rpcEnv = RpcEnv.create("spark", Utils.localIpAddress, 0, conf, new SecurityManager(conf)) + val rpcEnv = RpcEnv.create("spark", Utils.localHostName(), 0, conf, new SecurityManager(conf)) val desc = new ApplicationDescription("TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), Seq(), Seq()), "ignored") val listener = new TestListener diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 588016bd236b0..f5b59cfa077d8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -50,7 +50,7 @@ private[deploy] class ExecutorRunner( val workerUrl: String, conf: SparkConf, val appLocalDirs: Seq[String], - var state: ExecutorState.Value) + @volatile var state: ExecutorState.Value) extends Logging { private val fullId = appId + "/" + execId diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 14f99a464b6e9..516f619529c48 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -436,14 +436,14 @@ private[spark] class Executor( * This thread stops running when the executor is stopped. */ private def startDriverHeartbeater(): Unit = { - val interval = conf.getInt("spark.executor.heartbeatInterval", 10000) + val intervalMs = conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s") val thread = new Thread() { override def run() { // Sleep a random interval so the heartbeats don't end up in sync - Thread.sleep(interval + (math.random * interval).asInstanceOf[Int]) + Thread.sleep(intervalMs + (math.random * intervalMs).asInstanceOf[Int]) while (!isStopped) { reportHeartBeat() - Thread.sleep(interval) + Thread.sleep(intervalMs) } } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 04eb2bf9ba4ab..6b898bd4bfc1b 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -181,7 +181,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, buffer.get(bytes) bytes.foreach(x => print(x + " ")) buffer.position(curPosition) - print(" (" + bytes.size + ")") + print(" (" + bytes.length + ")") } def printBuffer(buffer: ByteBuffer, position: Int, length: Int) { diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 741fe3e1ea750..8e3c30fc3d781 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -82,7 +82,8 @@ private[nio] class ConnectionManager( new HashedWheelTimer(Utils.namedThreadFactory("AckTimeoutMonitor")) private val ackTimeout = - conf.getInt("spark.core.connection.ack.wait.timeout", conf.getInt("spark.network.timeout", 120)) + conf.getTimeAsSeconds("spark.core.connection.ack.wait.timeout", + conf.get("spark.network.timeout", "120s")) // Get the thread counts from the Spark Configuration. // diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 646df283ac069..3406a7e97e368 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -45,7 +45,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi } result }, - Range(0, self.partitions.size), + Range(0, self.partitions.length), (index: Int, data: Long) => totalCount.addAndGet(data), totalCount.get()) } @@ -54,8 +54,8 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi * Returns a future for retrieving all elements of this RDD. */ def collectAsync(): FutureAction[Seq[T]] = { - val results = new Array[Array[T]](self.partitions.size) - self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size), + val results = new Array[Array[T]](self.partitions.length) + self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.length), (index, data) => results(index) = data, results.flatten.toSeq) } @@ -111,7 +111,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi */ def foreachAsync(f: T => Unit): FutureAction[Unit] = { val cleanF = self.context.clean(f) - self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.size), + self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.length), (index, data) => Unit, Unit) } @@ -119,7 +119,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi * Applies a function f to each partition of this RDD. */ def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = { - self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size), + self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.length), (index, data) => Unit, Unit) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index fffa1911f5bc2..71578d1210fde 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -36,7 +36,7 @@ class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds override def getPartitions: Array[Partition] = { assertValid() - (0 until blockIds.size).map(i => { + (0 until blockIds.length).map(i => { new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition] }).toArray } diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala index 9059eb13bb5d8..c1d6971787572 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -53,11 +53,11 @@ class CartesianRDD[T: ClassTag, U: ClassTag]( extends RDD[Pair[T, U]](sc, Nil) with Serializable { - val numPartitionsInRdd2 = rdd2.partitions.size + val numPartitionsInRdd2 = rdd2.partitions.length override def getPartitions: Array[Partition] = { // create the cross product split - val array = new Array[Partition](rdd1.partitions.size * rdd2.partitions.size) + val array = new Array[Partition](rdd1.partitions.length * rdd2.partitions.length) for (s1 <- rdd1.partitions; s2 <- rdd2.partitions) { val idx = s1.index * numPartitionsInRdd2 + s2.index array(idx) = new CartesianPartition(idx, rdd1, rdd2, s1.index, s2.index) diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 760c0fa3ac96a..0d130dd4c7a60 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -49,7 +49,7 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) if (fs.exists(cpath)) { val dirContents = fs.listStatus(cpath).map(_.getPath) val partitionFiles = dirContents.filter(_.getName.startsWith("part-")).map(_.toString).sorted - val numPart = partitionFiles.size + val numPart = partitionFiles.length if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) { throw new SparkException("Invalid checkpoint directory: " + checkpointPath) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 07398a6fa62f6..7021a339e879b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -99,7 +99,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: override def getPartitions: Array[Partition] = { val array = new Array[Partition](part.numPartitions) - for (i <- 0 until array.size) { + for (i <- 0 until array.length) { // Each CoGroupPartition will have a dependency per contributing RDD array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) => // Assume each RDD contributed a single dependency, and get it @@ -120,7 +120,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: val sparkConf = SparkEnv.get.conf val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true) val split = s.asInstanceOf[CoGroupPartition] - val numRdds = split.deps.size + val numRdds = split.deps.length // A list of (rdd iterator, dependency number) pairs val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)] diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 5117ccfabfcc2..0c1b02c07d09f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -166,7 +166,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: // determines the tradeoff between load-balancing the partitions sizes and their locality // e.g. balanceSlack=0.10 means that it allows up to 10% imbalance in favor of locality - val slack = (balanceSlack * prev.partitions.size).toInt + val slack = (balanceSlack * prev.partitions.length).toInt var noLocality = true // if true if no preferredLocations exists for parent RDD diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 71e6e300fec5f..29ca3e9c4bd04 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -70,7 +70,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { @Experimental def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new MeanEvaluator(self.partitions.size, confidence) + val evaluator = new MeanEvaluator(self.partitions.length, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) } @@ -81,7 +81,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { @Experimental def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new SumEvaluator(self.partitions.size, confidence) + val evaluator = new SumEvaluator(self.partitions.length, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) } diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index 6fdfdb734d1b8..6afe50161dacd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -56,7 +56,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, * order of the keys). */ // TODO: this currently doesn't work on P other than Tuple2! - def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size) + def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.length) : RDD[(K, V)] = { val part = new RangePartitioner(numPartitions, self, ascending) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index bf1303d39592d..05351ba4ff76b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -823,7 +823,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * RDD will be <= us. */ def subtractByKey[W: ClassTag](other: RDD[(K, W)]): RDD[(K, V)] = - subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size))) + subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.length))) /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ def subtractByKey[W: ClassTag](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] = diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index ddbfd5624e741..d80d94a588346 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -316,7 +316,7 @@ abstract class RDD[T: ClassTag]( /** * Return a new RDD containing the distinct elements in this RDD. */ - def distinct(): RDD[T] = distinct(partitions.size) + def distinct(): RDD[T] = distinct(partitions.length) /** * Return a new RDD that has exactly numPartitions partitions. @@ -488,7 +488,7 @@ abstract class RDD[T: ClassTag]( def sortBy[K]( f: (T) => K, ascending: Boolean = true, - numPartitions: Int = this.partitions.size) + numPartitions: Int = this.partitions.length) (implicit ord: Ordering[K], ctag: ClassTag[K]): RDD[T] = this.keyBy[K](f) .sortByKey(ascending, numPartitions) @@ -852,7 +852,7 @@ abstract class RDD[T: ClassTag]( * RDD will be <= us. */ def subtract(other: RDD[T]): RDD[T] = - subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size))) + subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.length))) /** * Return an RDD with the elements from `this` that are not in `other`. @@ -986,14 +986,14 @@ abstract class RDD[T: ClassTag]( combOp: (U, U) => U, depth: Int = 2): U = { require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") - if (partitions.size == 0) { + if (partitions.length == 0) { return Utils.clone(zeroValue, context.env.closureSerializer.newInstance()) } val cleanSeqOp = context.clean(seqOp) val cleanCombOp = context.clean(combOp) val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it))) - var numPartitions = partiallyAggregated.partitions.size + var numPartitions = partiallyAggregated.partitions.length val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation. while (numPartitions > scale + numPartitions / scale) { @@ -1026,7 +1026,7 @@ abstract class RDD[T: ClassTag]( } result } - val evaluator = new CountEvaluator(partitions.size, confidence) + val evaluator = new CountEvaluator(partitions.length, confidence) sc.runApproximateJob(this, countElements, evaluator, timeout) } @@ -1061,7 +1061,7 @@ abstract class RDD[T: ClassTag]( } map } - val evaluator = new GroupedCountEvaluator[T](partitions.size, confidence) + val evaluator = new GroupedCountEvaluator[T](partitions.length, confidence) sc.runApproximateJob(this, countPartition, evaluator, timeout) } @@ -1140,7 +1140,7 @@ abstract class RDD[T: ClassTag]( * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. */ def zipWithUniqueId(): RDD[(T, Long)] = { - val n = this.partitions.size.toLong + val n = this.partitions.length.toLong this.mapPartitionsWithIndex { case (k, iter) => iter.zipWithIndex.map { case (item, i) => (item, i * n + k) @@ -1243,7 +1243,7 @@ abstract class RDD[T: ClassTag]( queue ++= util.collection.Utils.takeOrdered(items, num)(ord) Iterator.single(queue) } - if (mapRDDs.partitions.size == 0) { + if (mapRDDs.partitions.length == 0) { Array.empty } else { mapRDDs.reduce { (queue1, queue2) => @@ -1489,7 +1489,7 @@ abstract class RDD[T: ClassTag]( } // The first RDD in the dependency stack has no parents, so no need for a +- def firstDebugString(rdd: RDD[_]): Seq[String] = { - val partitionStr = "(" + rdd.partitions.size + ")" + val partitionStr = "(" + rdd.partitions.length + ")" val leftOffset = (partitionStr.length - 1) / 2 val nextPrefix = (" " * leftOffset) + "|" + (" " * (partitionStr.length - leftOffset)) @@ -1499,7 +1499,7 @@ abstract class RDD[T: ClassTag]( } ++ debugChildren(rdd, nextPrefix) } def shuffleDebugString(rdd: RDD[_], prefix: String = "", isLastChild: Boolean): Seq[String] = { - val partitionStr = "(" + rdd.partitions.size + ")" + val partitionStr = "(" + rdd.partitions.length + ")" val leftOffset = (partitionStr.length - 1) / 2 val thisPrefix = prefix.replaceAll("\\|\\s+$", "") val nextPrefix = ( diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index f67e5f1857979..6afd63d537d75 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -94,10 +94,10 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) new SerializableWritable(rdd.context.hadoopConfiguration)) rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _) val newRDD = new CheckpointRDD[T](rdd.context, path.toString) - if (newRDD.partitions.size != rdd.partitions.size) { + if (newRDD.partitions.length != rdd.partitions.length) { throw new SparkException( - "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.size + ") has different " + - "number of partitions than original RDD " + rdd + "(" + rdd.partitions.size + ")") + "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.length + ") has different " + + "number of partitions than original RDD " + rdd + "(" + rdd.partitions.length + ")") } // Change the dependencies and partitions of the RDD diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index c27f435eb9c5a..e9d745588ee9a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -76,7 +76,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( override def getPartitions: Array[Partition] = { val array = new Array[Partition](part.numPartitions) - for (i <- 0 until array.size) { + for (i <- 0 until array.length) { // Each CoGroupPartition will depend on rdd1 and rdd2 array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) => dependencies(j) match { diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 4239e7e22af89..3986645350a82 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -63,7 +63,7 @@ class UnionRDD[T: ClassTag]( extends RDD[T](sc, Nil) { // Nil since we implement getDependencies override def getPartitions: Array[Partition] = { - val array = new Array[Partition](rdds.map(_.partitions.size).sum) + val array = new Array[Partition](rdds.map(_.partitions.length).sum) var pos = 0 for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) { array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index) @@ -76,8 +76,8 @@ class UnionRDD[T: ClassTag]( val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for (rdd <- rdds) { - deps += new RangeDependency(rdd, 0, pos, rdd.partitions.size) - pos += rdd.partitions.size + deps += new RangeDependency(rdd, 0, pos, rdd.partitions.length) + pos += rdd.partitions.length } deps } diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index d0be304762e1f..a96b6c3d23454 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -52,8 +52,8 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag]( if (preservesPartitioning) firstParent[Any].partitioner else None override def getPartitions: Array[Partition] = { - val numParts = rdds.head.partitions.size - if (!rdds.forall(rdd => rdd.partitions.size == numParts)) { + val numParts = rdds.head.partitions.length + if (!rdds.forall(rdd => rdd.partitions.length == numParts)) { throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") } Array.tabulate[Partition](numParts) { i => diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index 8c43a559409f2..523aaf2b860b5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -41,7 +41,7 @@ class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, L /** The start index of each partition. */ @transient private val startIndices: Array[Long] = { - val n = prev.partitions.size + val n = prev.partitions.length if (n == 0) { Array[Long]() } else if (n == 1) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 917cce1f9686c..508fe7b3303ca 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} import scala.concurrent.duration._ +import scala.language.existentials import scala.language.postfixOps import scala.util.control.NonFatal @@ -49,6 +50,10 @@ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task * a small number of times before cancelling the whole stage. * + * Here's a checklist to use when making or reviewing changes to this class: + * + * - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to + * include the new structure. This will help to catch memory leaks. */ private[spark] class DAGScheduler( @@ -110,6 +115,8 @@ class DAGScheduler( // stray messages to detect. private val failedEpoch = new HashMap[String, Long] + private [scheduler] val outputCommitCoordinator = env.outputCommitCoordinator + // A closure serializer that we reuse. // This is only safe because DAGScheduler runs in a single thread. private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() @@ -127,8 +134,6 @@ class DAGScheduler( private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) - private val outputCommitCoordinator = env.outputCommitCoordinator - // Called by TaskScheduler to report task's starting. def taskStarted(task: Task[_], taskInfo: TaskInfo) { eventProcessLoop.post(BeginEvent(task, taskInfo)) @@ -640,13 +645,13 @@ class DAGScheduler( val split = rdd.partitions(job.partitions(0)) val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0, attemptNumber = 0, runningLocally = true) - TaskContextHelper.setTaskContext(taskContext) + TaskContext.setTaskContext(taskContext) try { val result = job.func(taskContext, rdd.iterator(split, taskContext)) job.listener.taskSucceeded(0, result) } finally { taskContext.markTaskCompleted() - TaskContextHelper.unset() + TaskContext.unset() } } catch { case e: Exception => @@ -709,9 +714,10 @@ class DAGScheduler( // cancelling the stages because if the DAG scheduler is stopped, the entire application // is in the process of getting stopped. val stageFailedMessage = "Stage cancelled because SparkContext was shut down" - runningStages.foreach { stage => - stage.latestInfo.stageFailed(stageFailedMessage) - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + // The `toArray` here is necessary so that we don't iterate over `runningStages` while + // mutating it. + runningStages.toArray.foreach { stage => + markStageAsFinished(stage, Some(stageFailedMessage)) } listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error))) } @@ -886,10 +892,9 @@ class DAGScheduler( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { - // Because we posted SparkListenerStageSubmitted earlier, we should post - // SparkListenerStageCompleted here in case there are no tasks to run. - outputCommitCoordinator.stageEnd(stage.id) - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + // Because we posted SparkListenerStageSubmitted earlier, we should mark + // the stage as completed here in case there are no tasks to run + markStageAsFinished(stage, None) val debugString = stage match { case stage: ShuffleMapStage => @@ -901,7 +906,6 @@ class DAGScheduler( s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})" } logDebug(debugString) - runningStages -= stage } } @@ -967,22 +971,6 @@ class DAGScheduler( } val stage = stageIdToStage(task.stageId) - - def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = { - val serviceTime = stage.latestInfo.submissionTime match { - case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) - case _ => "Unknown" - } - if (errorMessage.isEmpty) { - logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) - stage.latestInfo.completionTime = Some(clock.getTimeMillis()) - } else { - stage.latestInfo.stageFailed(errorMessage.get) - logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime)) - } - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) - runningStages -= stage - } event.reason match { case Success => listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, @@ -1098,7 +1086,6 @@ class DAGScheduler( logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + s"due to a fetch failure from $mapStage (${mapStage.name})") markStageAsFinished(failedStage, Some(failureMessage)) - runningStages -= failedStage } if (disallowStageRetryForTest) { @@ -1214,6 +1201,26 @@ class DAGScheduler( submitWaitingStages() } + /** + * Marks a stage as finished and removes it from the list of running stages. + */ + private def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = { + val serviceTime = stage.latestInfo.submissionTime match { + case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) + case _ => "Unknown" + } + if (errorMessage.isEmpty) { + logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) + stage.latestInfo.completionTime = Some(clock.getTimeMillis()) + } else { + stage.latestInfo.stageFailed(errorMessage.get) + logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime)) + } + outputCommitCoordinator.stageEnd(stage.id) + listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + runningStages -= stage + } + /** * Aborts all jobs depending on a particular Stage. This is called in response to a task set * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. @@ -1263,8 +1270,7 @@ class DAGScheduler( if (runningStages.contains(stage)) { try { // cancelTasks will fail if a SchedulerBackend does not implement killTask taskScheduler.cancelTasks(stageId, shouldInterruptThread) - stage.latestInfo.stageFailed(failureReason) - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + markStageAsFinished(stage, Some(failureReason)) } catch { case e: UnsupportedOperationException => logInfo(s"Could not cancel tasks for stage $stageId", e) diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 9e29fd13821dc..7c184b1dcb308 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -59,6 +59,13 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map() private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptId]] + /** + * Returns whether the OutputCommitCoordinator's internal data structures are all empty. + */ + def isEmpty: Boolean = { + authorizedCommittersByStage.isEmpty + } + /** * Called by tasks to ask whether they can commit their output to HDFS. * diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index fd0d484b45460..6c7d00069acb2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -33,7 +33,7 @@ import org.apache.spark.shuffle.ShuffleWriter * See [[org.apache.spark.scheduler.Task]] for more information. * * @param stageId id of the stage this task belongs to - * @param taskBinary broadcast version of of the RDD and the ShuffleDependency. Once deserialized, + * @param taskBinary broadcast version of the RDD and the ShuffleDependency. Once deserialized, * the type should be (RDD[_], ShuffleDependency[_, _, _]). * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 4d9f940813b8e..8b592867ee31d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap -import org.apache.spark.{TaskContextHelper, TaskContextImpl, TaskContext} +import org.apache.spark.{TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.ByteBufferInputStream @@ -54,7 +54,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex final def run(taskAttemptId: Long, attemptNumber: Int): T = { context = new TaskContextImpl(stageId = stageId, partitionId = partitionId, taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false) - TaskContextHelper.setTaskContext(context) + TaskContext.setTaskContext(context) context.taskMetrics.setHostname(Utils.localHostName()) taskThread = Thread.currentThread() if (_killed) { @@ -64,7 +64,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex runTask(context) } finally { context.markTaskCompleted() - TaskContextHelper.unset() + TaskContext.unset() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 076b36e86c0ce..2362cc7240039 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -62,10 +62,10 @@ private[spark] class TaskSchedulerImpl( val conf = sc.conf // How often to check for speculative tasks - val SPECULATION_INTERVAL = conf.getLong("spark.speculation.interval", 100) + val SPECULATION_INTERVAL_MS = conf.getTimeAsMs("spark.speculation.interval", "100ms") // Threshold above which we warn user initial TaskSet may be starved - val STARVATION_TIMEOUT = conf.getLong("spark.starvation.timeout", 15000) + val STARVATION_TIMEOUT_MS = conf.getTimeAsMs("spark.starvation.timeout", "15s") // CPUs to request per task val CPUS_PER_TASK = conf.getInt("spark.task.cpus", 1) @@ -143,8 +143,8 @@ private[spark] class TaskSchedulerImpl( if (!isLocal && conf.getBoolean("spark.speculation", false)) { logInfo("Starting speculative execution thread") import sc.env.actorSystem.dispatcher - sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds, - SPECULATION_INTERVAL milliseconds) { + sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL_MS milliseconds, + SPECULATION_INTERVAL_MS milliseconds) { Utils.tryOrStopSparkContext(sc) { checkSpeculatableTasks() } } } @@ -173,7 +173,7 @@ private[spark] class TaskSchedulerImpl( this.cancel() } } - }, STARVATION_TIMEOUT, STARVATION_TIMEOUT) + }, STARVATION_TIMEOUT_MS, STARVATION_TIMEOUT_MS) } hasReceivedTask = true } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index d509881c74fef..7dc325283d961 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -848,15 +848,18 @@ private[spark] class TaskSetManager( } private def getLocalityWait(level: TaskLocality.TaskLocality): Long = { - val defaultWait = conf.get("spark.locality.wait", "3000") - level match { - case TaskLocality.PROCESS_LOCAL => - conf.get("spark.locality.wait.process", defaultWait).toLong - case TaskLocality.NODE_LOCAL => - conf.get("spark.locality.wait.node", defaultWait).toLong - case TaskLocality.RACK_LOCAL => - conf.get("spark.locality.wait.rack", defaultWait).toLong - case _ => 0L + val defaultWait = conf.get("spark.locality.wait", "3s") + val localityWaitKey = level match { + case TaskLocality.PROCESS_LOCAL => "spark.locality.wait.process" + case TaskLocality.NODE_LOCAL => "spark.locality.wait.node" + case TaskLocality.RACK_LOCAL => "spark.locality.wait.rack" + case _ => null + } + + if (localityWaitKey != null) { + conf.getTimeAsMs(localityWaitKey, defaultWait) + } else { + 0L } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 4c49da87af9dc..63987dfb32695 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -52,8 +52,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp math.min(1, conf.getDouble("spark.scheduler.minRegisteredResourcesRatio", 0)) // Submit tasks after maxRegisteredWaitingTime milliseconds // if minRegisteredRatio has not yet been reached - val maxRegisteredWaitingTime = - conf.getInt("spark.scheduler.maxRegisteredResourcesWaitingTime", 30000) + val maxRegisteredWaitingTimeMs = + conf.getTimeAsMs("spark.scheduler.maxRegisteredResourcesWaitingTime", "30s") val createTime = System.currentTimeMillis() private val executorDataMap = new HashMap[String, ExecutorData] @@ -77,12 +77,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override def onStart() { // Periodically revive offers to allow delay scheduling to work - val reviveInterval = conf.getLong("spark.scheduler.revive.interval", 1000) + val reviveIntervalMs = conf.getTimeAsMs("spark.scheduler.revive.interval", "1s") + reviveThread.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { Option(self).foreach(_.send(ReviveOffers)) } - }, 0, reviveInterval, TimeUnit.MILLISECONDS) + }, 0, reviveIntervalMs, TimeUnit.MILLISECONDS) } override def receive: PartialFunction[Any, Unit] = { @@ -301,9 +302,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp s"reached minRegisteredResourcesRatio: $minRegisteredRatio") return true } - if ((System.currentTimeMillis() - createTime) >= maxRegisteredWaitingTime) { + if ((System.currentTimeMillis() - createTime) >= maxRegisteredWaitingTimeMs) { logInfo("SchedulerBackend is ready for scheduling beginning after waiting " + - s"maxRegisteredResourcesWaitingTime: $maxRegisteredWaitingTime(ms)") + s"maxRegisteredResourcesWaitingTime: $maxRegisteredWaitingTimeMs(ms)") return true } false diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 0186eb30a1905..034525b56f59c 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -52,6 +52,6 @@ class RDDInfo( private[spark] object RDDInfo { def fromRdd(rdd: RDD[_]): RDDInfo = { val rddName = Option(rdd.name).getOrElse(rdd.id.toString) - new RDDInfo(rdd.id, rddName, rdd.partitions.size, rdd.getStorageLevel) + new RDDInfo(rdd.id, rddName, rdd.partitions.length, rdd.getStorageLevel) } } diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala index 67f572e79314d..77c0bc8b5360a 100644 --- a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala +++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala @@ -65,7 +65,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { val stageIds = sc.statusTracker.getActiveStageIds() val stages = stageIds.map(sc.statusTracker.getStageInfo).flatten.filter(_.numTasks() > 1) .filter(now - _.submissionTime() > FIRST_DELAY).sortBy(_.stageId()) - if (stages.size > 0) { + if (stages.length > 0) { show(now, stages.take(3)) // display at most 3 stages in same time } } @@ -81,7 +81,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { val total = s.numTasks() val header = s"[Stage ${s.stageId()}:" val tailer = s"(${s.numCompletedTasks()} + ${s.numActiveTasks()}) / $total]" - val w = width - header.size - tailer.size + val w = width - header.length - tailer.length val bar = if (w > 0) { val percent = w * s.numCompletedTasks() / total (0 until w).map { i => diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index ea548f23120d9..f9860d1a5ce76 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -48,7 +48,7 @@ private[spark] abstract class WebUI( protected val handlers = ArrayBuffer[ServletContextHandler]() protected val pageToHandlers = new HashMap[WebUIPage, ArrayBuffer[ServletContextHandler]] protected var serverInfo: Option[ServerInfo] = None - protected val localHostName = Utils.localHostName() + protected val localHostName = Utils.localHostNameForURI() protected val publicHostName = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(localHostName) private val className = Utils.getFormattedClassName(this) diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 6c2c5261306e7..8e8cc7cc6389e 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -20,7 +20,6 @@ package org.apache.spark.util import scala.collection.JavaConversions.mapAsJavaMap import scala.concurrent.Await import scala.concurrent.duration.{Duration, FiniteDuration} -import scala.util.Try import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem} import akka.pattern.ask @@ -66,7 +65,8 @@ private[spark] object AkkaUtils extends Logging { val akkaThreads = conf.getInt("spark.akka.threads", 4) val akkaBatchSize = conf.getInt("spark.akka.batchSize", 15) - val akkaTimeout = conf.getInt("spark.akka.timeout", conf.getInt("spark.network.timeout", 120)) + val akkaTimeoutS = conf.getTimeAsSeconds("spark.akka.timeout", + conf.get("spark.network.timeout", "120s")) val akkaFrameSize = maxFrameSizeBytes(conf) val akkaLogLifecycleEvents = conf.getBoolean("spark.akka.logLifecycleEvents", false) val lifecycleEvents = if (akkaLogLifecycleEvents) "on" else "off" @@ -78,8 +78,8 @@ private[spark] object AkkaUtils extends Logging { val logAkkaConfig = if (conf.getBoolean("spark.akka.logAkkaConfig", false)) "on" else "off" - val akkaHeartBeatPauses = conf.getInt("spark.akka.heartbeat.pauses", 6000) - val akkaHeartBeatInterval = conf.getInt("spark.akka.heartbeat.interval", 1000) + val akkaHeartBeatPausesS = conf.getTimeAsSeconds("spark.akka.heartbeat.pauses", "6000s") + val akkaHeartBeatIntervalS = conf.getTimeAsSeconds("spark.akka.heartbeat.interval", "1000s") val secretKey = securityManager.getSecretKey() val isAuthOn = securityManager.isAuthenticationEnabled() @@ -102,14 +102,14 @@ private[spark] object AkkaUtils extends Logging { |akka.jvm-exit-on-fatal-error = off |akka.remote.require-cookie = "$requireCookie" |akka.remote.secure-cookie = "$secureCookie" - |akka.remote.transport-failure-detector.heartbeat-interval = $akkaHeartBeatInterval s - |akka.remote.transport-failure-detector.acceptable-heartbeat-pause = $akkaHeartBeatPauses s + |akka.remote.transport-failure-detector.heartbeat-interval = $akkaHeartBeatIntervalS s + |akka.remote.transport-failure-detector.acceptable-heartbeat-pause = $akkaHeartBeatPausesS s |akka.actor.provider = "akka.remote.RemoteActorRefProvider" |akka.remote.netty.tcp.transport-class = "akka.remote.transport.netty.NettyTransport" |akka.remote.netty.tcp.hostname = "$host" |akka.remote.netty.tcp.port = $port |akka.remote.netty.tcp.tcp-nodelay = on - |akka.remote.netty.tcp.connection-timeout = $akkaTimeout s + |akka.remote.netty.tcp.connection-timeout = $akkaTimeoutS s |akka.remote.netty.tcp.maximum-frame-size = ${akkaFrameSize}B |akka.remote.netty.tcp.execution-pool-size = $akkaThreads |akka.actor.default-dispatcher.throughput = $akkaBatchSize diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 375ed430bde45..2bbfc988a99a8 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -76,7 +76,7 @@ private[spark] object MetadataCleanerType extends Enumeration { // initialization of StreamingContext. It's okay for users trying to configure stuff themselves. private[spark] object MetadataCleaner { def getDelaySeconds(conf: SparkConf): Int = { - conf.getInt("spark.cleaner.ttl", -1) + conf.getTimeAsSeconds("spark.cleaner.ttl", "-1").toInt } def getDelaySeconds( diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 77aa49d95220a..8f38ac7f3f8ab 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -34,6 +34,7 @@ import scala.util.Try import scala.util.control.{ControlThrowable, NonFatal} import com.google.common.io.{ByteStreams, Files} +import com.google.common.net.InetAddresses import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration @@ -46,6 +47,7 @@ import tachyon.client.{TachyonFS, TachyonFile} import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} /** CallSite represents a place in user code. It can have a short and a long form. */ @@ -611,9 +613,10 @@ private[spark] object Utils extends Logging { } Utils.setupSecureURLConnection(uc, securityMgr) - val timeout = conf.getInt("spark.files.fetchTimeout", 60) * 1000 - uc.setConnectTimeout(timeout) - uc.setReadTimeout(timeout) + val timeoutMs = + conf.getTimeAsSeconds("spark.files.fetchTimeout", "60s").toInt * 1000 + uc.setConnectTimeout(timeoutMs) + uc.setReadTimeout(timeoutMs) uc.connect() val in = uc.getInputStream() downloadFile(url, in, targetFile, fileOverwrite) @@ -789,13 +792,12 @@ private[spark] object Utils extends Logging { * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4). * Note, this is typically not used from within core spark. */ - lazy val localIpAddress: String = findLocalIpAddress() - lazy val localIpAddressHostname: String = getAddressHostName(localIpAddress) + private lazy val localIpAddress: InetAddress = findLocalInetAddress() - private def findLocalIpAddress(): String = { + private def findLocalInetAddress(): InetAddress = { val defaultIpOverride = System.getenv("SPARK_LOCAL_IP") if (defaultIpOverride != null) { - defaultIpOverride + InetAddress.getByName(defaultIpOverride) } else { val address = InetAddress.getLocalHost if (address.isLoopbackAddress) { @@ -806,15 +808,20 @@ private[spark] object Utils extends Logging { // It's more proper to pick ip address following system output order. val activeNetworkIFs = NetworkInterface.getNetworkInterfaces.toList val reOrderedNetworkIFs = if (isWindows) activeNetworkIFs else activeNetworkIFs.reverse + for (ni <- reOrderedNetworkIFs) { - for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && - !addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) { + val addresses = ni.getInetAddresses.toList + .filterNot(addr => addr.isLinkLocalAddress || addr.isLoopbackAddress) + if (addresses.nonEmpty) { + val addr = addresses.find(_.isInstanceOf[Inet4Address]).getOrElse(addresses.head) + // because of Inet6Address.toHostName may add interface at the end if it knows about it + val strippedAddress = InetAddress.getByAddress(addr.getAddress) // We've found an address that looks reasonable! logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + - " a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress + - " instead (on interface " + ni.getName + ")") + " a loopback address: " + address.getHostAddress + "; using " + + strippedAddress.getHostAddress + " instead (on interface " + ni.getName + ")") logWarning("Set SPARK_LOCAL_IP if you need to bind to another address") - return addr.getHostAddress + return strippedAddress } } logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + @@ -822,7 +829,7 @@ private[spark] object Utils extends Logging { " external IP address!") logWarning("Set SPARK_LOCAL_IP if you need to bind to another address") } - address.getHostAddress + address } } @@ -842,11 +849,14 @@ private[spark] object Utils extends Logging { * Get the local machine's hostname. */ def localHostName(): String = { - customHostname.getOrElse(localIpAddressHostname) + customHostname.getOrElse(localIpAddress.getHostAddress) } - def getAddressHostName(address: String): String = { - InetAddress.getByName(address).getHostName + /** + * Get the local machine's URI. + */ + def localHostNameForURI(): String = { + customHostname.getOrElse(InetAddresses.toUriString(localIpAddress)) } def checkHost(host: String, message: String = "") { @@ -1028,6 +1038,22 @@ private[spark] object Utils extends Logging { ) } + /** + * Convert a time parameter such as (50s, 100ms, or 250us) to microseconds for internal use. If + * no suffix is provided, the passed number is assumed to be in ms. + */ + def timeStringAsMs(str: String): Long = { + JavaUtils.timeStringAsMs(str) + } + + /** + * Convert a time parameter such as (50s, 100ms, or 250us) to microseconds for internal use. If + * no suffix is provided, the passed number is assumed to be in seconds. + */ + def timeStringAsSeconds(str: String): Long = { + JavaUtils.timeStringAsSec(str) + } + /** * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. */ diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index f79e8e0491ea1..41cb8cfe2afa3 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -39,7 +39,7 @@ class BitSet(numBits: Int) extends Serializable { val wordIndex = bitIndex >> 6 // divide by 64 var i = 0 while(i < wordIndex) { words(i) = -1; i += 1 } - if(wordIndex < words.size) { + if(wordIndex < words.length) { // Set the remaining bits (note that the mask could still be zero) val mask = ~(-1L << (bitIndex & 0x3f)) words(wordIndex) |= mask diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties index 287c8e3563503..eb3b1999eb996 100644 --- a/core/src/test/resources/log4j.properties +++ b/core/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index bd0f8bdefa171..75399461f2a5f 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -27,19 +27,20 @@ import org.scalatest.Matchers class AccumulatorSuite extends FunSuite with Matchers with LocalSparkContext { - implicit def setAccum[A] = new AccumulableParam[mutable.Set[A], A] { - def addInPlace(t1: mutable.Set[A], t2: mutable.Set[A]) : mutable.Set[A] = { - t1 ++= t2 - t1 - } - def addAccumulator(t1: mutable.Set[A], t2: A) : mutable.Set[A] = { - t1 += t2 - t1 - } - def zero(t: mutable.Set[A]) : mutable.Set[A] = { - new mutable.HashSet[A]() + implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] = + new AccumulableParam[mutable.Set[A], A] { + def addInPlace(t1: mutable.Set[A], t2: mutable.Set[A]) : mutable.Set[A] = { + t1 ++= t2 + t1 + } + def addAccumulator(t1: mutable.Set[A], t2: A) : mutable.Set[A] = { + t1 += t2 + t1 + } + def zero(t: mutable.Set[A]) : mutable.Set[A] = { + new mutable.HashSet[A]() + } } - } test ("basic accumulation"){ sc = new SparkContext("local", "test") @@ -49,11 +50,10 @@ class AccumulatorSuite extends FunSuite with Matchers with LocalSparkContext { d.foreach{x => acc += x} acc.value should be (210) - - val longAcc = sc.accumulator(0l) + val longAcc = sc.accumulator(0L) val maxInt = Integer.MAX_VALUE.toLong d.foreach{x => longAcc += maxInt + x} - longAcc.value should be (210l + maxInt * 20) + longAcc.value should be (210L + maxInt * 20) } test ("value not assignable from tasks") { diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 4b25c200a695a..70529d9216591 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -45,16 +45,17 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf rdd = new RDD[Int](sc, Nil) { override def getPartitions: Array[Partition] = Array(split) override val getDependencies = List[Dependency[_]]() - override def compute(split: Partition, context: TaskContext) = Array(1, 2, 3, 4).iterator + override def compute(split: Partition, context: TaskContext): Iterator[Int] = + Array(1, 2, 3, 4).iterator } rdd2 = new RDD[Int](sc, List(new OneToOneDependency(rdd))) { override def getPartitions: Array[Partition] = firstParent[Int].partitions - override def compute(split: Partition, context: TaskContext) = + override def compute(split: Partition, context: TaskContext): Iterator[Int] = firstParent[Int].iterator(split, context) }.cache() rdd3 = new RDD[Int](sc, List(new OneToOneDependency(rdd2))) { override def getPartitions: Array[Partition] = firstParent[Int].partitions - override def compute(split: Partition, context: TaskContext) = + override def compute(split: Partition, context: TaskContext): Iterator[Int] = firstParent[Int].iterator(split, context) }.cache() } diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 32abc65385267..e1faddeabec79 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -75,7 +75,8 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) assert(parCollection.partitions.length === numPartitions) - assert(parCollection.partitions.toList === parCollection.checkpointData.get.getPartitions.toList) + assert(parCollection.partitions.toList === + parCollection.checkpointData.get.getPartitions.toList) assert(parCollection.collect() === result) } @@ -102,13 +103,13 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { } test("UnionRDD") { - def otherRDD = sc.makeRDD(1 to 10, 1) + def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1) testRDD(_.union(otherRDD)) testRDDPartitions(_.union(otherRDD)) } test("CartesianRDD") { - def otherRDD = sc.makeRDD(1 to 10, 1) + def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1) testRDD(new CartesianRDD(sc, _, otherRDD)) testRDDPartitions(new CartesianRDD(sc, _, otherRDD)) @@ -223,7 +224,8 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { val partitionAfterCheckpoint = serializeDeserialize( unionRDD.partitions.head.asInstanceOf[PartitionerAwareUnionRDDPartition]) assert( - partitionBeforeCheckpoint.parents.head.getClass != partitionAfterCheckpoint.parents.head.getClass, + partitionBeforeCheckpoint.parents.head.getClass != + partitionAfterCheckpoint.parents.head.getClass, "PartitionerAwareUnionRDDPartition.parents not updated after parent RDD is checkpointed" ) } @@ -358,7 +360,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { * Generate an pair RDD (with partitioner) such that both the RDD and its partitions * have large size. */ - def generateFatPairRDD() = { + def generateFatPairRDD(): RDD[(Int, Int)] = { new FatPairRDD(sc.makeRDD(1 to 100, 4), partitioner).mapValues(x => x) } @@ -445,7 +447,8 @@ class FatPairRDD(parent: RDD[Int], _partitioner: Partitioner) extends RDD[(Int, object CheckpointSuite { // This is a custom cogroup function that does not use mapValues like // the PairRDDFunctions.cogroup() - def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = { + def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) + : RDD[(K, Array[Iterable[V]])] = { new CoGroupedRDD[K]( Seq(first.asInstanceOf[RDD[(K, _)]], second.asInstanceOf[RDD[(K, _)]]), part diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index cdfaacee7da40..1de169d964d23 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -64,7 +64,7 @@ abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[Ha } } - //------ Helper functions ------ + // ------ Helper functions ------ protected def newRDD() = sc.makeRDD(1 to 10) protected def newPairRDD() = newRDD().map(_ -> 1) @@ -370,7 +370,7 @@ class CleanerTester( val cleanerListener = new CleanerListener { def rddCleaned(rddId: Int): Unit = { toBeCleanedRDDIds -= rddId - logInfo("RDD "+ rddId + " cleaned") + logInfo("RDD " + rddId + " cleaned") } def shuffleCleaned(shuffleId: Int): Unit = { diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 3ded1e4af8742..6b3049b28cd5e 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -684,10 +684,11 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit .set("spark.dynamicAllocation.enabled", "true") .set("spark.dynamicAllocation.minExecutors", minExecutors.toString) .set("spark.dynamicAllocation.maxExecutors", maxExecutors.toString) - .set("spark.dynamicAllocation.schedulerBacklogTimeout", schedulerBacklogTimeout.toString) + .set("spark.dynamicAllocation.schedulerBacklogTimeout", + s"${schedulerBacklogTimeout.toString}s") .set("spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", - sustainedSchedulerBacklogTimeout.toString) - .set("spark.dynamicAllocation.executorIdleTimeout", executorIdleTimeout.toString) + s"${sustainedSchedulerBacklogTimeout.toString}s") + .set("spark.dynamicAllocation.executorIdleTimeout", s"${executorIdleTimeout.toString}s") .set("spark.dynamicAllocation.testing", "true") val sc = new SparkContext(conf) contexts += sc diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 7acd27c735727..c8f08eed47c76 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -222,7 +222,7 @@ class FileSuite extends FunSuite with LocalSparkContext { val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x))) nums.saveAsSequenceFile(outputDir) val output = - sc.newAPIHadoopFile[IntWritable, Text, SequenceFileInputFormat[IntWritable, Text]](outputDir) + sc.newAPIHadoopFile[IntWritable, Text, SequenceFileInputFormat[IntWritable, Text]](outputDir) assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) } @@ -451,7 +451,8 @@ class FileSuite extends FunSuite with LocalSparkContext { test ("prevent user from overwriting the empty directory (new Hadoop API)") { sc = new SparkContext("local", "test") - val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + val randomRDD = sc.parallelize( + Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) intercept[FileAlreadyExistsException] { randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempDir.getPath) } @@ -459,8 +460,10 @@ class FileSuite extends FunSuite with LocalSparkContext { test ("prevent user from overwriting the non-empty directory (new Hadoop API)") { sc = new SparkContext("local", "test") - val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) - randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempDir.getPath + "/output") + val randomRDD = sc.parallelize( + Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]]( + tempDir.getPath + "/output") assert(new File(tempDir.getPath + "/output/part-r-00000").exists() === true) intercept[FileAlreadyExistsException] { randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempDir.getPath) @@ -471,16 +474,20 @@ class FileSuite extends FunSuite with LocalSparkContext { val sf = new SparkConf() sf.setAppName("test").setMaster("local").set("spark.hadoop.validateOutputSpecs", "false") sc = new SparkContext(sf) - val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) - randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempDir.getPath + "/output") + val randomRDD = sc.parallelize( + Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]]( + tempDir.getPath + "/output") assert(new File(tempDir.getPath + "/output/part-r-00000").exists() === true) - randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempDir.getPath + "/output") + randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]]( + tempDir.getPath + "/output") assert(new File(tempDir.getPath + "/output/part-r-00000").exists() === true) } test ("save Hadoop Dataset through old Hadoop API") { sc = new SparkContext("local", "test") - val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + val randomRDD = sc.parallelize( + Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) val job = new JobConf() job.setOutputKeyClass(classOf[String]) job.setOutputValueClass(classOf[String]) @@ -492,7 +499,8 @@ class FileSuite extends FunSuite with LocalSparkContext { test ("save Hadoop Dataset through new Hadoop API") { sc = new SparkContext("local", "test") - val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + val randomRDD = sc.parallelize( + Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) val job = new Job(sc.hadoopConfiguration) job.setOutputKeyClass(classOf[String]) job.setOutputValueClass(classOf[String]) diff --git a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala index d895230ecf330..51348c039b5c9 100644 --- a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala @@ -51,7 +51,7 @@ private object ImplicitOrderingSuite { override def compare(o: OrderedClass): Int = ??? } - def basicMapExpectations(rdd: RDD[Int]) = { + def basicMapExpectations(rdd: RDD[Int]): List[(Boolean, String)] = { List((rdd.map(x => (x, x)).keyOrdering.isDefined, "rdd.map(x => (x, x)).keyOrdering.isDefined"), (rdd.map(x => (1, x)).keyOrdering.isDefined, @@ -68,7 +68,7 @@ private object ImplicitOrderingSuite { "rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined")) } - def otherRDDMethodExpectations(rdd: RDD[Int]) = { + def otherRDDMethodExpectations(rdd: RDD[Int]): List[(Boolean, String)] = { List((rdd.groupBy(x => x).keyOrdering.isDefined, "rdd.groupBy(x => x).keyOrdering.isDefined"), (rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty, @@ -82,4 +82,4 @@ private object ImplicitOrderingSuite { (rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined, "rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined")) } -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 21487bc24d58a..4d3e09793faff 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -188,7 +188,7 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter val rdd = sc.parallelize(1 to 10, 2).map { i => JobCancellationSuite.twoJobsSharingStageSemaphore.acquire() (i, i) - }.reduceByKey(_+_) + }.reduceByKey(_ + _) val f1 = rdd.collectAsync() val f2 = rdd.countAsync() diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 53e367a61715b..8bf2e55defd02 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -37,7 +37,7 @@ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self super.afterEach() } - def resetSparkContext() = { + def resetSparkContext(): Unit = { LocalSparkContext.stop(sc) sc = null } @@ -54,7 +54,7 @@ object LocalSparkContext { } /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */ - def withSpark[T](sc: SparkContext)(f: SparkContext => T) = { + def withSpark[T](sc: SparkContext)(f: SparkContext => T): T = { try { f(sc) } finally { diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index b7532314ada01..47e3bf6e1ac41 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -92,7 +92,7 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet test("RangePartitioner for keys that are not Comparable (but with Ordering)") { // Row does not extend Comparable, but has an implicit Ordering defined. implicit object RowOrdering extends Ordering[Row] { - override def compare(x: Row, y: Row) = x.value - y.value + override def compare(x: Row, y: Row): Int = x.value - y.value } val rdd = sc.parallelize(1 to 4500).map(x => (Row(x), Row(x))) @@ -212,20 +212,24 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet val arrPairs: RDD[(Array[Int], Int)] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x)) - assert(intercept[SparkException]{ arrs.distinct() }.getMessage.contains("array")) + def verify(testFun: => Unit): Unit = { + intercept[SparkException](testFun).getMessage.contains("array") + } + + verify(arrs.distinct()) // We can't catch all usages of arrays, since they might occur inside other collections: // assert(fails { arrPairs.distinct() }) - assert(intercept[SparkException]{ arrPairs.partitionBy(new HashPartitioner(2)) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.join(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.leftOuterJoin(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.rightOuterJoin(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.fullOuterJoin(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.groupByKey() }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.countByKey() }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.countByKeyApprox(1) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.cogroup(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array")) + verify(arrPairs.partitionBy(new HashPartitioner(2))) + verify(arrPairs.join(arrPairs)) + verify(arrPairs.leftOuterJoin(arrPairs)) + verify(arrPairs.rightOuterJoin(arrPairs)) + verify(arrPairs.fullOuterJoin(arrPairs)) + verify(arrPairs.groupByKey()) + verify(arrPairs.countByKey()) + verify(arrPairs.countByKeyApprox(1)) + verify(arrPairs.cogroup(arrPairs)) + verify(arrPairs.reduceByKeyLocally(_ + _)) + verify(arrPairs.reduceByKey(_ + _)) } test("zero-length partitions should be correctly handled") { diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 444a33371bd71..93f46ef11c0e2 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -36,7 +36,8 @@ class SSLOptionsSuite extends FunSuite with BeforeAndAfterAll { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") + conf.set("spark.ssl.enabledAlgorithms", + "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") conf.set("spark.ssl.protocol", "SSLv3") val opts = SSLOptions.parse(conf, "spark.ssl") @@ -52,7 +53,8 @@ class SSLOptionsSuite extends FunSuite with BeforeAndAfterAll { assert(opts.keyStorePassword === Some("password")) assert(opts.keyPassword === Some("password")) assert(opts.protocol === Some("SSLv3")) - assert(opts.enabledAlgorithms === Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) + assert(opts.enabledAlgorithms === + Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) } test("test resolving property with defaults specified ") { @@ -66,7 +68,8 @@ class SSLOptionsSuite extends FunSuite with BeforeAndAfterAll { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") + conf.set("spark.ssl.enabledAlgorithms", + "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") conf.set("spark.ssl.protocol", "SSLv3") val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) @@ -83,7 +86,8 @@ class SSLOptionsSuite extends FunSuite with BeforeAndAfterAll { assert(opts.keyStorePassword === Some("password")) assert(opts.keyPassword === Some("password")) assert(opts.protocol === Some("SSLv3")) - assert(opts.enabledAlgorithms === Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) + assert(opts.enabledAlgorithms === + Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) } test("test whether defaults can be overridden ") { @@ -99,7 +103,8 @@ class SSLOptionsSuite extends FunSuite with BeforeAndAfterAll { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") + conf.set("spark.ssl.enabledAlgorithms", + "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") conf.set("spark.ui.ssl.enabledAlgorithms", "ABC, DEF") conf.set("spark.ssl.protocol", "SSLv3") diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala index ace8123a8961f..308b9ea17708d 100644 --- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala @@ -21,10 +21,11 @@ import java.io.File object SSLSampleConfigs { val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath - val untrustedKeyStorePath = new File(this.getClass.getResource("/untrusted-keystore").toURI).getAbsolutePath + val untrustedKeyStorePath = new File( + this.getClass.getResource("/untrusted-keystore").toURI).getAbsolutePath val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath - def sparkSSLConfig() = { + def sparkSSLConfig(): SparkConf = { val conf = new SparkConf(loadDefaults = false) conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) @@ -38,7 +39,7 @@ object SSLSampleConfigs { conf } - def sparkSSLConfigUntrusted() = { + def sparkSSLConfigUntrusted(): SparkConf = { val conf = new SparkConf(loadDefaults = false) conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", untrustedKeyStorePath) diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index f57921b768310..d7180516029d5 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -142,7 +142,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("shuffle on mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) val results = new ShuffledRDD[Int, Int, Int](pairs, @@ -155,7 +155,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex // This is not in SortingSuite because of the local cluster setup. // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data = Array(p(1, 11), p(3, 33), p(100, 100), p(2, 22)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) val results = new OrderedRDDFunctions[Int, Int, MutablePair[Int, Int]](pairs) @@ -169,7 +169,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("cogroup using mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"), p(3, "3")) val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2) @@ -196,7 +196,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("subtract mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1), p(3, 33)) val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22")) val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2) @@ -242,14 +242,14 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex shuffleSpillCompress <- Set(true, false); shuffleCompress <- Set(true, false) ) { - val conf = new SparkConf() + val myConf = conf.clone() .setAppName("test") .setMaster("local") .set("spark.shuffle.spill.compress", shuffleSpillCompress.toString) .set("spark.shuffle.compress", shuffleCompress.toString) .set("spark.shuffle.memoryFraction", "0.001") resetSparkContext() - sc = new SparkContext(conf) + sc = new SparkContext(myConf) try { sc.parallelize(0 until 100000).map(i => (i / 4, i)).groupByKey().collect() } catch { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index c7301a30d8b11..94be1c6d6397c 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -114,11 +114,13 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { if (length1 != gotten1.length()) { throw new SparkException( - s"file has different length $length1 than added file ${gotten1.length()} : " + absolutePath1) + s"file has different length $length1 than added file ${gotten1.length()} : " + + absolutePath1) } if (length2 != gotten2.length()) { throw new SparkException( - s"file has different length $length2 than added file ${gotten2.length()} : " + absolutePath2) + s"file has different length $length2 than added file ${gotten2.length()} : " + + absolutePath2) } if (absolutePath1 == gotten1.getAbsolutePath) { diff --git a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala index 41d6ea29d5b06..084eb237d70d1 100644 --- a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala @@ -82,7 +82,8 @@ class StatusTrackerSuite extends FunSuite with Matchers with LocalSparkContext { secondJobFuture.jobIds.head } eventually(timeout(10 seconds)) { - sc.statusTracker.getJobIdsForGroup("my-job-group").toSet should be (Set(firstJobId, secondJobId)) + sc.statusTracker.getJobIdsForGroup("my-job-group").toSet should be ( + Set(firstJobId, secondJobId)) } } -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index af3272692d7a1..c8fdfa693912e 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -33,7 +33,7 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { val broadcast = rdd.context.broadcast(list) val bid = broadcast.id - def doSomething() = { + def doSomething(): Set[(Int, Boolean)] = { rdd.map { x => val bm = SparkEnv.get.blockManager // Check if broadcast block was fetched diff --git a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala index 518073dcbb64e..745f9eeee7536 100644 --- a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala @@ -46,5 +46,4 @@ class ClientSuite extends FunSuite with Matchers { // Invalid syntax. ClientArguments.isValidJarUrl("hdfs:") should be (false) } - } diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index 68b5776fc6515..2071701b313db 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -100,13 +100,13 @@ class JsonProtocolSuite extends FunSuite { appInfo } - def createDriverCommand() = new Command( + def createDriverCommand(): Command = new Command( "org.apache.spark.FakeClass", Seq("some arg --and-some options -g foo"), Map(("K1", "V1"), ("K2", "V2")), Seq("cp1", "cp2"), Seq("lp1", "lp2"), Seq("-Dfoo") ) - def createDriverDesc() = new DriverDescription("hdfs://some-dir/some.jar", 100, 3, - false, createDriverCommand()) + def createDriverDesc(): DriverDescription = + new DriverDescription("hdfs://some-dir/some.jar", 100, 3, false, createDriverCommand()) def createDriverInfo(): DriverInfo = new DriverInfo(3, "driver-3", createDriverDesc(), new Date()) diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index 54dd7c9c45c61..9cdb42814ca32 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -56,7 +56,7 @@ class LogUrlsStandaloneSuite extends FunSuite with LocalSparkContext { test("verify that log urls reflect SPARK_PUBLIC_DNS (SPARK-6175)") { val SPARK_PUBLIC_DNS = "public_dns" class MySparkConf extends SparkConf(false) { - override def getenv(name: String) = { + override def getenv(name: String): String = { if (name == "SPARK_PUBLIC_DNS") SPARK_PUBLIC_DNS else super.getenv(name) } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 3a9963a5ce7b7..20de46fdab909 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -42,10 +42,10 @@ class HistoryServerSuite extends FunSuite with Matchers with MockitoSugar { when(historyServer.getProviderConfig()).thenReturn(Map[String, String]()) val page = new HistoryPage(historyServer) - //when + // when val response = page.render(request) - //then + // then val links = response \\ "a" val justHrefs = for { l <- links diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index 1d64ec201e647..61071ee17256c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -129,7 +129,8 @@ class SubmitRestProtocolSuite extends FunSuite { assert(newMessage.sparkProperties("spark.files") === "fireball.png") assert(newMessage.sparkProperties("spark.driver.memory") === "512m") assert(newMessage.sparkProperties("spark.driver.cores") === "180") - assert(newMessage.sparkProperties("spark.driver.extraJavaOptions") === " -Dslices=5 -Dcolor=mostly_red") + assert(newMessage.sparkProperties("spark.driver.extraJavaOptions") === + " -Dslices=5 -Dcolor=mostly_red") assert(newMessage.sparkProperties("spark.driver.extraClassPath") === "food-coloring.jar") assert(newMessage.sparkProperties("spark.driver.extraLibraryPath") === "pickle.jar") assert(newMessage.sparkProperties("spark.driver.supervise") === "false") diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index 6fca6321e5a1b..a8b9df227c996 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -35,7 +35,8 @@ class ExecutorRunnerTest extends FunSuite { val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", 123, "publicAddr", new File(sparkHome), new File("ooga"), "blah", new SparkConf, Seq("localDir"), ExecutorState.RUNNING) - val builder = CommandUtils.buildProcessBuilder(appDesc.command, 512, sparkHome, er.substituteVariables) + val builder = CommandUtils.buildProcessBuilder( + appDesc.command, 512, sparkHome, er.substituteVariables) assert(builder.command().last === appId) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala index 372d7aa453008..7cc2104281464 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala @@ -37,7 +37,7 @@ class WorkerArgumentsTest extends FunSuite { val args = Array("spark://localhost:0000 ") class MySparkConf extends SparkConf(false) { - override def getenv(name: String) = { + override def getenv(name: String): String = { if (name == "SPARK_WORKER_MEMORY") "50000" else super.getenv(name) } @@ -56,7 +56,7 @@ class WorkerArgumentsTest extends FunSuite { val args = Array("spark://localhost:0000 ") class MySparkConf extends SparkConf(false) { - override def getenv(name: String) = { + override def getenv(name: String): String = { if (name == "SPARK_WORKER_MEMORY") "5G" else super.getenv(name) } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index 84e2fd7ad936d..450fba21f4b5c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -24,8 +24,10 @@ import org.scalatest.{Matchers, FunSuite} class WorkerSuite extends FunSuite with Matchers { - def cmd(javaOpts: String*) = Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts:_*)) - def conf(opts: (String, String)*) = new SparkConf(loadDefaults = false).setAll(opts) + def cmd(javaOpts: String*): Command = { + Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts:_*)) + } + def conf(opts: (String, String)*): SparkConf = new SparkConf(loadDefaults = false).setAll(opts) test("test isUseLocalNodeSSLConfig") { Worker.isUseLocalNodeSSLConfig(cmd("-Dasdf=dfgh")) shouldBe false diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 78fa98a3b9065..190b08d950a02 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -238,7 +238,7 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext sc.textFile(tmpFilePath, 4) .map(key => (key, 1)) - .reduceByKey(_+_) + .reduceByKey(_ + _) .saveAsTextFile("file://" + tmpFile.getAbsolutePath) sc.listenerBus.waitUntilEmpty(500) diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala index 37e528435aa5d..100ac77dec1f7 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala @@ -35,7 +35,8 @@ class MetricsConfigSuite extends FunSuite with BeforeAndAfter { val property = conf.getInstance("random") assert(property.size() === 2) - assert(property.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet") + assert(property.getProperty("sink.servlet.class") === + "org.apache.spark.metrics.sink.MetricsServlet") assert(property.getProperty("sink.servlet.path") === "/metrics/json") } @@ -47,16 +48,20 @@ class MetricsConfigSuite extends FunSuite with BeforeAndAfter { assert(masterProp.size() === 5) assert(masterProp.getProperty("sink.console.period") === "20") assert(masterProp.getProperty("sink.console.unit") === "minutes") - assert(masterProp.getProperty("source.jvm.class") === "org.apache.spark.metrics.source.JvmSource") - assert(masterProp.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet") + assert(masterProp.getProperty("source.jvm.class") === + "org.apache.spark.metrics.source.JvmSource") + assert(masterProp.getProperty("sink.servlet.class") === + "org.apache.spark.metrics.sink.MetricsServlet") assert(masterProp.getProperty("sink.servlet.path") === "/metrics/master/json") val workerProp = conf.getInstance("worker") assert(workerProp.size() === 5) assert(workerProp.getProperty("sink.console.period") === "10") assert(workerProp.getProperty("sink.console.unit") === "seconds") - assert(workerProp.getProperty("source.jvm.class") === "org.apache.spark.metrics.source.JvmSource") - assert(workerProp.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet") + assert(workerProp.getProperty("source.jvm.class") === + "org.apache.spark.metrics.source.JvmSource") + assert(workerProp.getProperty("sink.servlet.class") === + "org.apache.spark.metrics.sink.MetricsServlet") assert(workerProp.getProperty("sink.servlet.path") === "/metrics/json") } diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala index 716f875d30b8a..02424c59d6831 100644 --- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala @@ -260,8 +260,8 @@ class ConnectionManagerSuite extends FunSuite { test("sendMessageReliably timeout") { val clientConf = new SparkConf clientConf.set("spark.authenticate", "false") - val ackTimeout = 30 - clientConf.set("spark.core.connection.ack.wait.timeout", s"${ackTimeout}") + val ackTimeoutS = 30 + clientConf.set("spark.core.connection.ack.wait.timeout", s"${ackTimeoutS}s") val clientSecurityManager = new SecurityManager(clientConf) val manager = new ConnectionManager(0, clientConf, clientSecurityManager) @@ -272,7 +272,7 @@ class ConnectionManagerSuite extends FunSuite { val managerServer = new ConnectionManager(0, serverConf, serverSecurityManager) managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { // sleep 60 sec > ack timeout for simulating server slow down or hang up - Thread.sleep(ackTimeout * 3 * 1000) + Thread.sleep(ackTimeoutS * 3 * 1000) None }) @@ -287,7 +287,7 @@ class ConnectionManagerSuite extends FunSuite { // Otherwise TimeoutExcepton is thrown from Await.result. // We expect TimeoutException is not thrown. intercept[IOException] { - Await.result(future, (ackTimeout * 2) second) + Await.result(future, (ackTimeoutS * 2) second) } manager.stop() diff --git a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala index 0dc59888f7304..be8467354b222 100644 --- a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala @@ -80,7 +80,7 @@ class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { (r: ResultSet) => { r.getInt(1) } ).cache() assert(rdd.count === 100) - assert(rdd.reduce(_+_) === 10100) + assert(rdd.reduce(_ + _) === 10100) } test("large id overflow") { @@ -92,7 +92,7 @@ class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { 1131544775L, 567279358897692673L, 20, (r: ResultSet) => { r.getInt(1) } ).cache() assert(rdd.count === 100) - assert(rdd.reduce(_+_) === 5050) + assert(rdd.reduce(_ + _) === 5050) } after { diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 108f70af43f37..ca0d953d306d8 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -168,13 +168,13 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { test("reduceByKey") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_).collect() + val sums = pairs.reduceByKey(_ + _).collect() assert(sums.toSet === Set((1, 7), (2, 1))) } test("reduceByKey with collectAsMap") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_).collectAsMap() + val sums = pairs.reduceByKey(_ + _).collectAsMap() assert(sums.size === 2) assert(sums(1) === 7) assert(sums(2) === 1) @@ -182,7 +182,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { test("reduceByKey with many output partitons") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_, 10).collect() + val sums = pairs.reduceByKey(_ + _, 10).collect() assert(sums.toSet === Set((1, 7), (2, 1))) } @@ -192,7 +192,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { def getPartition(key: Any) = key.asInstanceOf[Int] } val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p) - val sums = pairs.reduceByKey(_+_) + val sums = pairs.reduceByKey(_ + _) assert(sums.collect().toSet === Set((1, 4), (0, 1))) assert(sums.partitioner === Some(p)) // count the dependencies to make sure there is only 1 ShuffledRDD @@ -208,7 +208,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { } test("countApproxDistinctByKey") { - def error(est: Long, size: Long) = math.abs(est - size) / size.toDouble + def error(est: Long, size: Long): Double = math.abs(est - size) / size.toDouble /* Since HyperLogLog unique counting is approximate, and the relative standard deviation is * only a statistical bound, the tests can fail for large values of relativeSD. We will be using @@ -465,7 +465,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { test("foldByKey") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.foldByKey(0)(_+_).collect() + val sums = pairs.foldByKey(0)(_ + _).collect() assert(sums.toSet === Set((1, 7), (2, 1))) } @@ -505,7 +505,8 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { conf.setOutputCommitter(classOf[FakeOutputCommitter]) FakeOutputCommitter.ran = false - pairs.saveAsHadoopFile("ignored", pairs.keyClass, pairs.valueClass, classOf[FakeOutputFormat], conf) + pairs.saveAsHadoopFile( + "ignored", pairs.keyClass, pairs.valueClass, classOf[FakeOutputFormat], conf) assert(FakeOutputCommitter.ran, "OutputCommitter was never called") } @@ -552,7 +553,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { } private object StratifiedAuxiliary { - def stratifier (fractionPositive: Double) = { + def stratifier (fractionPositive: Double): (Int) => String = { (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0" } @@ -572,7 +573,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { def testSampleExact(stratifiedData: RDD[(String, Int)], samplingRate: Double, seed: Long, - n: Long) = { + n: Long): Unit = { testBernoulli(stratifiedData, true, samplingRate, seed, n) testPoisson(stratifiedData, true, samplingRate, seed, n) } @@ -580,7 +581,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { def testSample(stratifiedData: RDD[(String, Int)], samplingRate: Double, seed: Long, - n: Long) = { + n: Long): Unit = { testBernoulli(stratifiedData, false, samplingRate, seed, n) testPoisson(stratifiedData, false, samplingRate, seed, n) } @@ -590,7 +591,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { exact: Boolean, samplingRate: Double, seed: Long, - n: Long) = { + n: Long): Unit = { val expectedSampleSize = stratifiedData.countByKey() .mapValues(count => math.ceil(count * samplingRate).toInt) val fractions = Map("1" -> samplingRate, "0" -> samplingRate) @@ -612,7 +613,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { exact: Boolean, samplingRate: Double, seed: Long, - n: Long) = { + n: Long): Unit = { val expectedSampleSize = stratifiedData.countByKey().mapValues(count => math.ceil(count * samplingRate).toInt) val fractions = Map("1" -> samplingRate, "0" -> samplingRate) @@ -701,27 +702,27 @@ class FakeOutputFormat() extends OutputFormat[Integer, Integer]() { */ class NewFakeWriter extends NewRecordWriter[Integer, Integer] { - def close(p1: NewTaskAttempContext) = () + def close(p1: NewTaskAttempContext): Unit = () - def write(p1: Integer, p2: Integer) = () + def write(p1: Integer, p2: Integer): Unit = () } class NewFakeCommitter extends NewOutputCommitter { - def setupJob(p1: NewJobContext) = () + def setupJob(p1: NewJobContext): Unit = () def needsTaskCommit(p1: NewTaskAttempContext): Boolean = false - def setupTask(p1: NewTaskAttempContext) = () + def setupTask(p1: NewTaskAttempContext): Unit = () - def commitTask(p1: NewTaskAttempContext) = () + def commitTask(p1: NewTaskAttempContext): Unit = () - def abortTask(p1: NewTaskAttempContext) = () + def abortTask(p1: NewTaskAttempContext): Unit = () } class NewFakeFormat() extends NewOutputFormat[Integer, Integer]() { - def checkOutputSpecs(p1: NewJobContext) = () + def checkOutputSpecs(p1: NewJobContext): Unit = () def getRecordWriter(p1: NewTaskAttempContext): NewRecordWriter[Integer, Integer] = { new NewFakeWriter() @@ -735,7 +736,7 @@ class NewFakeFormat() extends NewOutputFormat[Integer, Integer]() { class ConfigTestFormat() extends NewFakeFormat() with Configurable { var setConfCalled = false - def setConf(p1: Configuration) = { + def setConf(p1: Configuration): Unit = { setConfCalled = true () } diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala index cd193ae4f5238..1880364581c1a 100644 --- a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala @@ -100,7 +100,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1 until 100 val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 99) + assert(slices.map(_.size).reduceLeft(_ + _) === 99) assert(slices.forall(_.isInstanceOf[Range])) } @@ -108,7 +108,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1 to 100 val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 100) + assert(slices.map(_.size).reduceLeft(_ + _) === 100) assert(slices.forall(_.isInstanceOf[Range])) } @@ -139,7 +139,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices(i).isInstanceOf[Range]) val range = slices(i).asInstanceOf[Range] assert(range.start === i * (N / 40), "slice " + i + " start") - assert(range.end === (i+1) * (N / 40), "slice " + i + " end") + assert(range.end === (i + 1) * (N / 40), "slice " + i + " end") assert(range.step === 1, "slice " + i + " step") } } @@ -156,7 +156,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val slices = ParallelCollectionRDD.slice(d, n) ("n slices" |: slices.size == n) && ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && - ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) + ("equal sizes" |: slices.map(_.size).forall(x => x == d.size / n || x == d.size /n + 1)) } check(prop) } @@ -174,7 +174,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { ("n slices" |: slices.size == n) && ("all ranges" |: slices.forall(_.isInstanceOf[Range])) && ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && - ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) + ("equal sizes" |: slices.map(_.size).forall(x => x == d.size / n || x == d.size / n + 1)) } check(prop) } @@ -192,7 +192,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { ("n slices" |: slices.size == n) && ("all ranges" |: slices.forall(_.isInstanceOf[Range])) && ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && - ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) + ("equal sizes" |: slices.map(_.size).forall(x => x == d.size / n || x == d.size / n + 1)) } check(prop) } @@ -201,7 +201,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1L until 100L val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 99) + assert(slices.map(_.size).reduceLeft(_ + _) === 99) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } @@ -209,7 +209,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1L to 100L val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 100) + assert(slices.map(_.size).reduceLeft(_ + _) === 100) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } @@ -217,7 +217,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1.0 until 100.0 by 1.0 val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 99) + assert(slices.map(_.size).reduceLeft(_ + _) === 99) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } @@ -225,7 +225,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1.0 to 100.0 by 1.0 val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 100) + assert(slices.map(_.size).reduceLeft(_ + _) === 100) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala index 8408d7e785c65..465068c6cbb16 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala @@ -23,7 +23,6 @@ import org.apache.spark.{Partition, SharedSparkContext, TaskContext} class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { - test("Pruned Partitions inherit locality prefs correctly") { val rdd = new RDD[Int](sc, Nil) { @@ -74,8 +73,6 @@ class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { } class TestPartition(i: Int, value: Int) extends Partition with Serializable { - def index = i - - def testValue = this.value - + def index: Int = i + def testValue: Int = this.value } diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala index a0483886f8db3..0d1369c19c69e 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala @@ -35,7 +35,7 @@ class MockSampler extends RandomSampler[Long, Long] { Iterator(s) } - override def clone = new MockSampler + override def clone: MockSampler = new MockSampler } class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext { diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index bede1ffb3e2d0..df42faab64505 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -82,7 +82,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("countApproxDistinct") { - def error(est: Long, size: Long) = math.abs(est - size) / size.toDouble + def error(est: Long, size: Long): Double = math.abs(est - size) / size.toDouble val size = 1000 val uniformDistro = for (i <- 1 to 5000) yield i % size @@ -100,7 +100,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { } test("partitioner aware union") { - def makeRDDWithPartitioner(seq: Seq[Int]) = { + def makeRDDWithPartitioner(seq: Seq[Int]): RDD[Int] = { sc.makeRDD(seq, 1) .map(x => (x, null)) .partitionBy(new HashPartitioner(2)) @@ -159,8 +159,8 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("treeAggregate") { val rdd = sc.makeRDD(-1000 until 1000, 10) - def seqOp = (c: Long, x: Int) => c + x - def combOp = (c1: Long, c2: Long) => c1 + c2 + def seqOp: (Long, Int) => Long = (c: Long, x: Int) => c + x + def combOp: (Long, Long) => Long = (c1: Long, c2: Long) => c1 + c2 for (depth <- 1 until 10) { val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth) assert(sum === -1000L) @@ -204,7 +204,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(empty.collect().size === 0) val thrown = intercept[UnsupportedOperationException]{ - empty.reduce(_+_) + empty.reduce(_ + _) } assert(thrown.getMessage.contains("empty")) @@ -321,7 +321,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(list3.sorted === Array("a","b","c"), "Locality preferences are dropped") // RDD with locality preferences spread (non-randomly) over 6 machines, m0 through m5 - val data = sc.makeRDD((1 to 9).map(i => (i, (i to (i+2)).map{ j => "m" + (j%6)}))) + val data = sc.makeRDD((1 to 9).map(i => (i, (i to (i + 2)).map{ j => "m" + (j%6)}))) val coalesced1 = data.coalesce(3) assert(coalesced1.collect().toList.sorted === (1 to 9).toList, "Data got *lost* in coalescing") @@ -921,15 +921,17 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("task serialization exception should not hang scheduler") { class BadSerializable extends Serializable { @throws(classOf[IOException]) - private def writeObject(out: ObjectOutputStream): Unit = throw new KryoException("Bad serialization") + private def writeObject(out: ObjectOutputStream): Unit = + throw new KryoException("Bad serialization") @throws(classOf[IOException]) private def readObject(in: ObjectInputStream): Unit = {} } - // Note that in the original bug, SPARK-4349, that this verifies, the job would only hang if there were - // more threads in the Spark Context than there were number of objects in this sequence. + // Note that in the original bug, SPARK-4349, that this verifies, the job would only hang if + // there were more threads in the Spark Context than there were number of objects in this + // sequence. intercept[Throwable] { - sc.parallelize(Seq(new BadSerializable, new BadSerializable)).collect + sc.parallelize(Seq(new BadSerializable, new BadSerializable)).collect() } // Check that the context has not crashed sc.parallelize(1 to 100).map(x => x*2).collect diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala index 4762fc17855ce..fe695d85e29dd 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala @@ -21,11 +21,11 @@ object RDDSuiteUtils { case class Person(first: String, last: String, age: Int) object AgeOrdering extends Ordering[Person] { - def compare(a:Person, b:Person) = a.age compare b.age + def compare(a:Person, b:Person): Int = a.age.compare(b.age) } object NameOrdering extends Ordering[Person] { - def compare(a:Person, b:Person) = + def compare(a:Person, b:Person): Int = implicitly[Ordering[Tuple2[String,String]]].compare((a.last, a.first), (b.last, b.first)) } } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 5a734ec5ba5ec..ada07ef11cd7a 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -70,7 +70,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { env.setupEndpoint("send-remotely", new RpcEndpoint { override val rpcEnv = env - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case msg: String => message = msg } }) @@ -109,7 +109,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val rpcEndpointRef = env.setupEndpoint("ask-locally", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(context: RpcCallContext) = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case msg: String => { context.reply(msg) } @@ -123,7 +123,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { env.setupEndpoint("ask-remotely", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(context: RpcCallContext) = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case msg: String => { context.reply(msg) } @@ -146,7 +146,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { env.setupEndpoint("ask-timeout", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(context: RpcCallContext) = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case msg: String => { Thread.sleep(100) context.reply(msg) @@ -182,7 +182,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { calledMethods += "start" } - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case msg: String => } @@ -206,7 +206,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { throw new RuntimeException("Oops!") } - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case m => } @@ -225,7 +225,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpointRef = env.setupEndpoint("onError-onStop", new RpcEndpoint { override val rpcEnv = env - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case m => } @@ -250,8 +250,8 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpointRef = env.setupEndpoint("onError-receive", new RpcEndpoint { override val rpcEnv = env - override def receive = { - case m => throw new RuntimeException("Oops!") + override def receive: PartialFunction[Any, Unit] = { + case m => throw new RuntimeException("Oops!") } override def onError(cause: Throwable): Unit = { @@ -277,7 +277,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { callSelfSuccessfully = true } - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case m => } }) @@ -294,7 +294,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpointRef = env.setupEndpoint("self-receive", new RpcEndpoint { override val rpcEnv = env - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case m => { self callSelfSuccessfully = true @@ -316,7 +316,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpointRef = env.setupEndpoint("self-onStop", new RpcEndpoint { override val rpcEnv = env - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case m => } @@ -343,7 +343,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpointRef = env.setupEndpoint(s"receive-in-sequence-$i", new ThreadSafeRpcEndpoint { override val rpcEnv = env - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case m => result += 1 } @@ -372,7 +372,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpointRef = env.setupEndpoint("stop-reentrant", new RpcEndpoint { override val rpcEnv = env - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case m => } @@ -394,7 +394,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpointRef = env.setupEndpoint("sendWithReply", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(context: RpcCallContext) = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case m => context.reply("ack") } }) @@ -410,7 +410,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { env.setupEndpoint("sendWithReply-remotely", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(context: RpcCallContext) = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case m => context.reply("ack") } }) @@ -432,7 +432,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpointRef = env.setupEndpoint("sendWithReply-error", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(context: RpcCallContext) = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case m => context.sendFailure(new SparkException("Oops")) } }) @@ -450,7 +450,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { env.setupEndpoint("sendWithReply-remotely-error", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(context: RpcCallContext) = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case msg: String => context.sendFailure(new SparkException("Oops")) } }) @@ -476,7 +476,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { env.setupEndpoint("network-events", new ThreadSafeRpcEndpoint { override val rpcEnv = env - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case "hello" => case m => events += "receive" -> m } @@ -519,7 +519,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { env.setupEndpoint("sendWithReply-unserializable-error", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(context: RpcCallContext) = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case msg: String => context.sendFailure(new UnserializableException) } }) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 63360a0f189a3..3c52a8c4460c6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -57,20 +57,18 @@ class MyRDD( locations: Seq[Seq[String]] = Nil) extends RDD[(Int, Int)](sc, dependencies) with Serializable { override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = throw new RuntimeException("should not be reached") - override def getPartitions = (0 until numPartitions).map(i => new Partition { - override def index = i + override def getPartitions: Array[Partition] = (0 until numPartitions).map(i => new Partition { + override def index: Int = i }).toArray override def getPreferredLocations(split: Partition): Seq[String] = - if (locations.isDefinedAt(split.index)) - locations(split.index) - else - Nil + if (locations.isDefinedAt(split.index)) locations(split.index) else Nil override def toString: String = "DAGSchedulerSuiteRDD " + id } class DAGSchedulerSuiteDummyException extends Exception -class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSparkContext with Timeouts { +class DAGSchedulerSuite + extends FunSuiteLike with BeforeAndAfter with LocalSparkContext with Timeouts { val conf = new SparkConf /** Set of TaskSets the DAGScheduler has requested executed. */ @@ -209,7 +207,8 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(taskSet.tasks.size >= results.size) for ((result, i) <- results.zipWithIndex) { if (i < taskSet.tasks.size) { - runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSet.tasks(i), result._1, result._2, null, createFakeTaskInfo(), null)) } } } @@ -269,21 +268,23 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar submit(new MyRDD(sc, 1, Nil), Array(0)) complete(taskSets(0), List((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("local job") { val rdd = new PairOfIntsRDD(sc, Nil) { override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = Array(42 -> 0).iterator - override def getPartitions = Array( new Partition { override def index = 0 } ) - override def getPreferredLocations(split: Partition) = Nil - override def toString = "DAGSchedulerSuite Local RDD" + override def getPartitions: Array[Partition] = + Array( new Partition { override def index: Int = 0 } ) + override def getPreferredLocations(split: Partition): List[String] = Nil + override def toString: String = "DAGSchedulerSuite Local RDD" } val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) + runEvent( + JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("local job oom") { @@ -295,9 +296,10 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar override def toString = "DAGSchedulerSuite Local RDD" } val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) + runEvent( + JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) assert(results.size == 0) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("run trivial job w/ dependency") { @@ -306,7 +308,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar submit(finalRdd, Array(0)) complete(taskSets(0), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("cache location preferences w/ dependency") { @@ -319,7 +321,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assertLocations(taskSet, Seq(Seq("hostA", "hostB"))) complete(taskSet, Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("regression test for getCacheLocs") { @@ -335,7 +337,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar } test("avoid exponential blowup when getting preferred locs list") { - // Build up a complex dependency graph with repeated zip operations, without preferred locations. + // Build up a complex dependency graph with repeated zip operations, without preferred locations var rdd: RDD[_] = new MyRDD(sc, 1, Nil) (1 to 30).foreach(_ => rdd = rdd.zip(rdd)) // getPreferredLocs runs quickly, indicating that exponential graph traversal is avoided. @@ -357,7 +359,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.contains(0)) assert(sparkListener.failedStages.size === 1) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("trivial job failure") { @@ -367,7 +369,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.contains(0)) assert(sparkListener.failedStages.size === 1) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("trivial job cancellation") { @@ -378,7 +380,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.contains(0)) assert(sparkListener.failedStages.size === 1) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("job cancellation no-kill backend") { @@ -387,18 +389,20 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar val noKillTaskScheduler = new TaskScheduler() { override def rootPool: Pool = null override def schedulingMode: SchedulingMode = SchedulingMode.NONE - override def start() = {} - override def stop() = {} - override def submitTasks(taskSet: TaskSet) = { + override def start(): Unit = {} + override def stop(): Unit = {} + override def submitTasks(taskSet: TaskSet): Unit = { taskSets += taskSet } override def cancelTasks(stageId: Int, interruptThread: Boolean) { throw new UnsupportedOperationException } - override def setDAGScheduler(dagScheduler: DAGScheduler) = {} - override def defaultParallelism() = 2 - override def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)], - blockManagerId: BlockManagerId): Boolean = true + override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} + override def defaultParallelism(): Int = 2 + override def executorHeartbeatReceived( + execId: String, + taskMetrics: Array[(Long, TaskMetrics)], + blockManagerId: BlockManagerId): Boolean = true override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} } val noKillScheduler = new DAGScheduler( @@ -422,7 +426,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar // When the task set completes normally, state should be correctly updated. complete(taskSets(0), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.isEmpty) @@ -442,7 +446,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("run trivial shuffle with fetch failure") { @@ -465,10 +469,11 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar // have the 2nd attempt pass complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) // we can see both result blocks now - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === + Array("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("trivial shuffle with multiple fetch failures") { @@ -521,19 +526,23 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(newEpoch > oldEpoch) val taskSet = taskSets(0) // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) // should work because it's a non-failed host - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, createFakeTaskInfo(), null)) // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) // should work because it's a new epoch taskSet.tasks(1).epoch = newEpoch - runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) complete(taskSets(1), Seq((Success, 42), (Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("run shuffle with map stage failure") { @@ -552,7 +561,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.toSet === Set(0)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } /** @@ -586,7 +595,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar class FailureRecordingJobListener() extends JobListener { var failureMessage: String = _ override def taskSucceeded(index: Int, result: Any) {} - override def jobFailed(exception: Exception) = { failureMessage = exception.getMessage } + override def jobFailed(exception: Exception): Unit = { failureMessage = exception.getMessage } } val listener1 = new FailureRecordingJobListener() val listener2 = new FailureRecordingJobListener() @@ -606,7 +615,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(listener1.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage") assert(listener2.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage") - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("run trivial shuffle with out-of-band failure and retry") { @@ -629,7 +638,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) complete(taskSets(2), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("recursive shuffle failures") { @@ -658,7 +667,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1)))) complete(taskSets(5), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("cached post-shuffle") { @@ -690,7 +699,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1)))) complete(taskSets(4), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("misbehaved accumulator should not crash DAGScheduler and SparkContext") { @@ -742,7 +751,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar } test("accumulator not calculated for resubmitted result stage") { - //just for register + // just for register val accum = new Accumulator[Int](0, AccumulatorParam.IntAccumulatorParam) val finalRdd = new MyRDD(sc, 1, Nil) submit(finalRdd, Array(0)) @@ -754,7 +763,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(accVal === 1) - assertDataStructuresEmpty + assertDataStructuresEmpty() } /** @@ -774,7 +783,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar private def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) - private def assertDataStructuresEmpty = { + private def assertDataStructuresEmpty(): Unit = { assert(scheduler.activeJobs.isEmpty) assert(scheduler.failedStages.isEmpty) assert(scheduler.jobIdToActiveJob.isEmpty) @@ -783,6 +792,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(scheduler.runningStages.isEmpty) assert(scheduler.shuffleToMapStage.isEmpty) assert(scheduler.waitingStages.isEmpty) + assert(scheduler.outputCommitCoordinator.isEmpty) } // Nothing in this test should break if the task info's fields are null, but diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 30ee63e78d9d8..6d25edb7d20dc 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -268,7 +268,7 @@ class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with Bef object EventLoggingListenerSuite { /** Get a SparkConf with event logging enabled. */ - def getLoggingConf(logDir: Path, compressionCodec: Option[String] = None) = { + def getLoggingConf(logDir: Path, compressionCodec: Option[String] = None): SparkConf = { val conf = new SparkConf conf.set("spark.eventLog.enabled", "true") conf.set("spark.eventLog.testing", "true") @@ -280,5 +280,5 @@ object EventLoggingListenerSuite { conf } - def getUniqueApplicationId = "test-" + System.currentTimeMillis + def getUniqueApplicationId: String = "test-" + System.currentTimeMillis } diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala index 6b75c98839e03..9b92f8de56759 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala @@ -24,7 +24,9 @@ import org.apache.spark.TaskContext /** * A Task implementation that fails to serialize. */ -private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) extends Task[Array[Byte]](stageId, 0) { +private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) + extends Task[Array[Byte]](stageId, 0) { + override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte] override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]() diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 627c9a4ddfffc..825c616c0c3e0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -85,7 +85,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers val stopperReturned = new Semaphore(0) class BlockingListener extends SparkListener { - override def onJobEnd(jobEnd: SparkListenerJobEnd) = { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { listenerStarted.release() listenerWait.acquire() drained = true @@ -206,8 +206,9 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers sc.addSparkListener(new StatsReportListener) // just to make sure some of the tasks take a noticeable amount of time val w = { i: Int => - if (i == 0) + if (i == 0) { Thread.sleep(100) + } i } @@ -247,12 +248,12 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers */ taskInfoMetrics.foreach { case (taskInfo, taskMetrics) => - taskMetrics.resultSize should be > (0l) + taskMetrics.resultSize should be > (0L) if (stageInfo.rddInfos.exists(info => info.name == d2.name || info.name == d3.name)) { taskMetrics.inputMetrics should not be ('defined) taskMetrics.outputMetrics should not be ('defined) taskMetrics.shuffleWriteMetrics should be ('defined) - taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0l) + taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0L) } if (stageInfo.rddInfos.exists(_.name == d4.name)) { taskMetrics.shuffleReadMetrics should be ('defined) @@ -260,7 +261,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers sm.totalBlocksFetched should be (128) sm.localBlocksFetched should be (128) sm.remoteBlocksFetched should be (0) - sm.remoteBytesRead should be (0l) + sm.remoteBytesRead should be (0L) } } } @@ -406,12 +407,12 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers val startedGettingResultTasks = new mutable.HashSet[Int]() val endedTasks = new mutable.HashSet[Int]() - override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { startedTasks += taskStart.taskInfo.index notify() } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { endedTasks += taskEnd.taskInfo.index notify() } @@ -425,7 +426,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers * A simple listener that throws an exception on job end. */ private class BadListener extends SparkListener { - override def onJobEnd(jobEnd: SparkListenerJobEnd) = { throw new Exception } + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { throw new Exception } } } @@ -438,10 +439,10 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers */ private class BasicJobCounter extends SparkListener { var count = 0 - override def onJobEnd(job: SparkListenerJobEnd) = count += 1 + override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1 } private class ListenerThatAcceptsSparkConf(conf: SparkConf) extends SparkListener { var count = 0 - override def onJobEnd(job: SparkListenerJobEnd) = count += 1 + override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1 } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index add13f5b21765..ffa4381969b68 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.scheduler -import java.util.Properties - import org.scalatest.FunSuite import org.apache.spark._ @@ -27,7 +25,7 @@ class FakeSchedulerBackend extends SchedulerBackend { def start() {} def stop() {} def reviveOffers() {} - def defaultParallelism() = 1 + def defaultParallelism(): Int = 1 } class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Logging { @@ -115,7 +113,8 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin } val numFreeCores = 1 taskScheduler.setDAGScheduler(dagScheduler) - var taskSet = new TaskSet(Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) + val taskSet = new TaskSet( + Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) val multiCoreWorkerOffers = Seq(new WorkerOffer("executor0", "host0", taskCpus), new WorkerOffer("executor1", "host1", numFreeCores)) taskScheduler.submitTasks(taskSet) @@ -123,7 +122,8 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin assert(0 === taskDescriptions.length) // Now check that we can still submit tasks - // Even if one of the tasks has not-serializable tasks, the other task set should still be processed without error + // Even if one of the tasks has not-serializable tasks, the other task set should + // still be processed without error taskScheduler.submitTasks(taskSet) taskScheduler.submitTasks(FakeTask.createTaskSet(1)) taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 12330d8f63c40..6198cea46ddf8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.scheduler -import java.io.{ObjectInputStream, ObjectOutputStream, IOException} import java.util.Random import scala.collection.mutable.ArrayBuffer @@ -27,7 +26,7 @@ import org.scalatest.FunSuite import org.apache.spark._ import org.apache.spark.executor.TaskMetrics -import org.apache.spark.util.ManualClock +import org.apache.spark.util.{ManualClock, Utils} class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) extends DAGScheduler(sc) { @@ -67,7 +66,7 @@ object FakeRackUtil { hostToRack(host) = rack } - def getRackForHost(host: String) = { + def getRackForHost(host: String): Option[String] = { hostToRack.get(host) } } @@ -152,7 +151,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { private val conf = new SparkConf - val LOCALITY_WAIT = conf.getLong("spark.locality.wait", 3000) + val LOCALITY_WAIT_MS = conf.getTimeAsMs("spark.locality.wait", "3s") val MAX_TASK_FAILURES = 4 override def beforeEach() { @@ -240,7 +239,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) == None) - clock.advance(LOCALITY_WAIT) + clock.advance(LOCALITY_WAIT_MS) // Offer host1, exec1 again, at NODE_LOCAL level: the node local (task 2) should // get chosen before the noPref task assert(manager.resourceOffer("exec1", "host1", NODE_LOCAL).get.index == 2) @@ -251,7 +250,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // Offer host2, exec3 again, at NODE_LOCAL level: we should get noPref task // after failing to find a node_Local task assert(manager.resourceOffer("exec2", "host2", NODE_LOCAL) == None) - clock.advance(LOCALITY_WAIT) + clock.advance(LOCALITY_WAIT_MS) assert(manager.resourceOffer("exec2", "host2", NO_PREF).get.index == 3) } @@ -292,7 +291,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // Offer host1 again: nothing should get chosen assert(manager.resourceOffer("exec1", "host1", ANY) === None) - clock.advance(LOCALITY_WAIT) + clock.advance(LOCALITY_WAIT_MS) // Offer host1 again: second task (on host2) should get chosen assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 1) @@ -306,7 +305,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // Now that we've launched a local task, we should no longer launch the task for host3 assert(manager.resourceOffer("exec2", "host2", ANY) === None) - clock.advance(LOCALITY_WAIT) + clock.advance(LOCALITY_WAIT_MS) // After another delay, we can go ahead and launch that task non-locally assert(manager.resourceOffer("exec2", "host2", ANY).get.index === 3) @@ -327,8 +326,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // First offer host1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) - // After this, nothing should get chosen, because we have separated tasks with unavailable preference - // from the noPrefPendingTasks + // After this, nothing should get chosen, because we have separated tasks with unavailable + // preference from the noPrefPendingTasks assert(manager.resourceOffer("exec1", "host1", ANY) === None) // Now mark host2 as dead @@ -338,7 +337,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // nothing should be chosen assert(manager.resourceOffer("exec1", "host1", ANY) === None) - clock.advance(LOCALITY_WAIT * 2) + clock.advance(LOCALITY_WAIT_MS * 2) // task 1 and 2 would be scheduled as nonLocal task assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 1) @@ -499,7 +498,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { sched.addExecutor("execC", "host2") manager.executorAdded() // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL and ANY - assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY))) + assert(manager.myLocalityLevels.sameElements( + Array(PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY))) // test if the valid locality is recomputed when the executor is lost sched.removeExecutor("execC") manager.executorLost("execC", "host2") @@ -527,7 +527,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY))) // Set allowed locality to ANY - clock.advance(LOCALITY_WAIT * 3) + clock.advance(LOCALITY_WAIT_MS * 3) // Offer host3 // No task is scheduled if we restrict locality to RACK_LOCAL assert(manager.resourceOffer("execC", "host3", RACK_LOCAL) === None) @@ -569,7 +569,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) - val taskSet = new TaskSet(Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) + val taskSet = new TaskSet( + Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) intercept[TaskNotSerializableException] { @@ -582,7 +583,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { val conf = new SparkConf().set("spark.driver.maxResultSize", "2m") sc = new SparkContext("local", "test", conf) - def genBytes(size: Int) = { (x: Int) => + def genBytes(size: Int): (Int) => Array[Byte] = { (x: Int) => val bytes = Array.ofDim[Byte](size) scala.util.Random.nextBytes(bytes) bytes @@ -605,7 +606,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("speculative and noPref task should be scheduled after node-local") { sc = new SparkContext("local", "test") - val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) + val sched = new FakeTaskScheduler( + sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) val taskSet = FakeTask.createTaskSet(4, Seq(TaskLocation("host1", "execA")), Seq(TaskLocation("host2"), TaskLocation("host1")), @@ -619,19 +621,21 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.resourceOffer("execA", "host1", NO_PREF).get.index == 1) manager.speculatableTasks += 1 - clock.advance(LOCALITY_WAIT) + clock.advance(LOCALITY_WAIT_MS) // schedule the nonPref task assert(manager.resourceOffer("execA", "host1", NO_PREF).get.index === 2) // schedule the speculative task assert(manager.resourceOffer("execB", "host2", NO_PREF).get.index === 1) - clock.advance(LOCALITY_WAIT * 3) + clock.advance(LOCALITY_WAIT_MS * 3) // schedule non-local tasks assert(manager.resourceOffer("execB", "host2", ANY).get.index === 3) } - test("node-local tasks should be scheduled right away when there are only node-local and no-preference tasks") { + test("node-local tasks should be scheduled right away " + + "when there are only node-local and no-preference tasks") { sc = new SparkContext("local", "test") - val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) + val sched = new FakeTaskScheduler( + sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) val taskSet = FakeTask.createTaskSet(4, Seq(TaskLocation("host1")), Seq(TaskLocation("host2")), @@ -650,7 +654,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.resourceOffer("execA", "host3", NO_PREF).get.index === 2) } - test("SPARK-4939: node-local tasks should be scheduled right after process-local tasks finished") { + test("SPARK-4939: node-local tasks should be scheduled right after process-local tasks finished") + { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) val taskSet = FakeTask.createTaskSet(4, @@ -710,13 +715,13 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL and ANY assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) assert(manager.resourceOffer("execA", "host1", ANY) !== None) - clock.advance(LOCALITY_WAIT * 4) + clock.advance(LOCALITY_WAIT_MS * 4) assert(manager.resourceOffer("execB.2", "host2", ANY) !== None) sched.removeExecutor("execA") sched.removeExecutor("execB.2") manager.executorLost("execA", "host1") manager.executorLost("execB.2", "host2") - clock.advance(LOCALITY_WAIT * 4) + clock.advance(LOCALITY_WAIT_MS * 4) sched.addExecutor("execC", "host3") manager.executorAdded() // Prior to the fix, this line resulted in an ArrayIndexOutOfBoundsException: diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index f1a4380d349b3..a311512e82c5e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -63,16 +63,18 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Mo // uri is null. val executorInfo = mesosSchedulerBackend.createExecutorInfo("test-id") - assert(executorInfo.getCommand.getValue === s" /mesos-home/bin/spark-class ${classOf[MesosExecutorBackend].getName}") + assert(executorInfo.getCommand.getValue === + s" /mesos-home/bin/spark-class ${classOf[MesosExecutorBackend].getName}") // uri exists. conf.set("spark.executor.uri", "hdfs:///test-app-1.0.0.tgz") val executorInfo1 = mesosSchedulerBackend.createExecutorInfo("test-id") - assert(executorInfo1.getCommand.getValue === s"cd test-app-1*; ./bin/spark-class ${classOf[MesosExecutorBackend].getName}") + assert(executorInfo1.getCommand.getValue === + s"cd test-app-1*; ./bin/spark-class ${classOf[MesosExecutorBackend].getName}") } test("mesos resource offers result in launching tasks") { - def createOffer(id: Int, mem: Int, cpu: Int) = { + def createOffer(id: Int, mem: Int, cpu: Int): Offer = { val builder = Offer.newBuilder() builder.addResourcesBuilder() .setName("mem") @@ -82,8 +84,10 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Mo .setName("cpus") .setType(Value.Type.SCALAR) .setScalar(Scalar.newBuilder().setValue(cpu)) - builder.setId(OfferID.newBuilder().setValue(s"o${id.toString}").build()).setFrameworkId(FrameworkID.newBuilder().setValue("f1")) - .setSlaveId(SlaveID.newBuilder().setValue(s"s${id.toString}")).setHostname(s"host${id.toString}").build() + builder.setId(OfferID.newBuilder().setValue(s"o${id.toString}").build()) + .setFrameworkId(FrameworkID.newBuilder().setValue("f1")) + .setSlaveId(SlaveID.newBuilder().setValue(s"s${id.toString}")) + .setHostname(s"host${id.toString}").build() } val driver = mock[SchedulerDriver] diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 6198df84fab3d..b070a54aa989b 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -106,7 +106,9 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { check(mutable.HashMap(1 -> "one", 2 -> "two")) check(mutable.HashMap("one" -> 1, "two" -> 2)) check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) - check(List(mutable.HashMap("one" -> 1, "two" -> 2),mutable.HashMap(1->"one",2->"two",3->"three"))) + check(List( + mutable.HashMap("one" -> 1, "two" -> 2), + mutable.HashMap(1->"one",2->"two",3->"three"))) } test("ranges") { @@ -169,7 +171,10 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { test("kryo with collect") { val control = 1 :: 2 :: Nil - val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)).collect().map(_.x) + val result = sc.parallelize(control, 2) + .map(new ClassWithoutNoArgConstructor(_)) + .collect() + .map(_.x) assert(control === result.toSeq) } @@ -237,7 +242,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { // Set a special, broken ClassLoader and make sure we get an exception on deserialization ser.setDefaultClassLoader(new ClassLoader() { - override def loadClass(name: String) = throw new UnsupportedOperationException + override def loadClass(name: String): Class[_] = throw new UnsupportedOperationException }) intercept[UnsupportedOperationException] { ser.newInstance().deserialize[ClassLoaderTestingObject](bytes) @@ -287,14 +292,14 @@ object KryoTest { class ClassWithNoArgConstructor { var x: Int = 0 - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case c: ClassWithNoArgConstructor => x == c.x case _ => false } } class ClassWithoutNoArgConstructor(val x: Int) { - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case c: ClassWithoutNoArgConstructor => x == c.x case _ => false } diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala index d037e2c19a64d..433fd6bb4a11d 100644 --- a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala @@ -24,14 +24,16 @@ import org.apache.spark.rdd.RDD /* A trivial (but unserializable) container for trivial functions */ class UnserializableClass { - def op[T](x: T) = x.toString + def op[T](x: T): String = x.toString - def pred[T](x: T) = x.toString.length % 2 == 0 + def pred[T](x: T): Boolean = x.toString.length % 2 == 0 } class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContext { - def fixture = (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass) + def fixture: (RDD[String], UnserializableClass) = { + (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass) + } test("throws expected serialization exceptions on actions") { val (data, uc) = fixture diff --git a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala index 0ade1bab18d7e..963264cef3a71 100644 --- a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala +++ b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala @@ -27,7 +27,7 @@ import scala.reflect.ClassTag * A serializer implementation that always return a single element in a deserialization stream. */ class TestSerializer extends Serializer { - override def newInstance() = new TestSerializerInstance + override def newInstance(): TestSerializerInstance = new TestSerializerInstance } @@ -36,7 +36,8 @@ class TestSerializerInstance extends SerializerInstance { override def serializeStream(s: OutputStream): SerializationStream = ??? - override def deserializeStream(s: InputStream) = new TestDeserializationStream + override def deserializeStream(s: InputStream): TestDeserializationStream = + new TestDeserializationStream override def deserialize[T: ClassTag](bytes: ByteBuffer): T = ??? diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala index b834dc0e735eb..7d76435cd75e7 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala @@ -85,8 +85,8 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { // Now comes the test : // Write to shuffle 3; and close it, but before registering it, check if the file lengths for // previous task (forof shuffle1) is the same as 'segments'. Earlier, we were inferring length - // of block based on remaining data in file : which could mess things up when there is concurrent read - // and writes happening to the same shuffle group. + // of block based on remaining data in file : which could mess things up when there is + // concurrent read and writes happening to the same shuffle group. val shuffle3 = shuffleBlockManager.forMapTask(1, 3, 1, new JavaSerializer(testConf), new ShuffleWriteMetrics) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index b4de90b65d545..ffa5162a31841 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -76,7 +76,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd conf.set("spark.storage.unrollMemoryThreshold", "512") // to make a replication attempt to inactive store fail fast - conf.set("spark.core.connection.ack.wait.timeout", "1") + conf.set("spark.core.connection.ack.wait.timeout", "1s") // to make cached peers refresh frequently conf.set("spark.storage.cachedPeersTtl", "10") diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 6dc5bc4cb08c4..545722b050ee8 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -60,7 +60,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach // Implicitly convert strings to BlockIds for test clarity. implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) - def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId) + def rdd(rddId: Int, splitId: Int): RDDBlockId = RDDBlockId(rddId, splitId) private def makeBlockManager( maxMem: Long, @@ -107,8 +107,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach test("StorageLevel object caching") { val level1 = StorageLevel(false, false, false, false, 3) - val level2 = StorageLevel(false, false, false, false, 3) // this should return the same object as level1 - val level3 = StorageLevel(false, false, false, false, 2) // this should return a different object + // this should return the same object as level1 + val level2 = StorageLevel(false, false, false, false, 3) + // this should return a different object + val level3 = StorageLevel(false, false, false, false, 2) assert(level2 === level1, "level2 is not same as level1") assert(level2.eq(level1), "level2 is not the same object as level1") assert(level3 != level1, "level3 is same as level1") @@ -802,7 +804,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach // Create a non-trivial (not all zeros) byte array var counter = 0.toByte - def incr = {counter = (counter + 1).toByte; counter;} + def incr: Byte = {counter = (counter + 1).toByte; counter;} val bytes = Array.fill[Byte](1000)(incr) val byteBuffer = ByteBuffer.wrap(bytes) @@ -956,8 +958,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach store.putIterator("list3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) // getLocations and getBlockStatus should yield the same locations - assert(store.master.getMatchingBlockIds(_.toString.contains("list"), askSlaves = false).size === 3) - assert(store.master.getMatchingBlockIds(_.toString.contains("list1"), askSlaves = false).size === 1) + assert(store.master.getMatchingBlockIds(_.toString.contains("list"), askSlaves = false).size + === 3) + assert(store.master.getMatchingBlockIds(_.toString.contains("list1"), askSlaves = false).size + === 1) // insert some more blocks store.putIterator("newlist1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) @@ -965,8 +969,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach store.putIterator("newlist3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) // getLocations and getBlockStatus should yield the same locations - assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = false).size === 1) - assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = true).size === 3) + assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = false).size + === 1) + assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = true).size + === 3) val blockIds = Seq(RDDBlockId(1, 0), RDDBlockId(1, 1), RDDBlockId(2, 0)) blockIds.foreach { blockId => @@ -1090,8 +1096,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach val memoryStore = store.memoryStore val smallList = List.fill(40)(new Array[Byte](100)) val bigList = List.fill(40)(new Array[Byte](1000)) - def smallIterator = smallList.iterator.asInstanceOf[Iterator[Any]] - def bigIterator = bigList.iterator.asInstanceOf[Iterator[Any]] + def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] + def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] assert(memoryStore.currentUnrollMemoryForThisThread === 0) // Unroll with plenty of space. This should succeed and cache both blocks. @@ -1144,8 +1150,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach val diskStore = store.diskStore val smallList = List.fill(40)(new Array[Byte](100)) val bigList = List.fill(40)(new Array[Byte](1000)) - def smallIterator = smallList.iterator.asInstanceOf[Iterator[Any]] - def bigIterator = bigList.iterator.asInstanceOf[Iterator[Any]] + def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] + def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] assert(memoryStore.currentUnrollMemoryForThisThread === 0) store.putIterator("b1", smallIterator, memAndDisk) @@ -1187,7 +1193,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach val memOnly = StorageLevel.MEMORY_ONLY val memoryStore = store.memoryStore val smallList = List.fill(40)(new Array[Byte](100)) - def smallIterator = smallList.iterator.asInstanceOf[Iterator[Any]] + def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] assert(memoryStore.currentUnrollMemoryForThisThread === 0) // All unroll memory used is released because unrollSafely returned an array diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index 82a82e23eecf2..b47157f8331cc 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -47,7 +47,7 @@ class LocalDirsSuite extends FunSuite with BeforeAndAfter { assert(!new File("/NONEXISTENT_DIR").exists()) // SPARK_LOCAL_DIRS is a valid directory: class MySparkConf extends SparkConf(false) { - override def getenv(name: String) = { + override def getenv(name: String): String = { if (name == "SPARK_LOCAL_DIRS") System.getProperty("java.io.tmpdir") else super.getenv(name) } diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 0d155982a8c54..1cb594633f331 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -137,7 +137,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before new SparkContext(conf) } - def hasKillLink = find(className("kill-link")).isDefined + def hasKillLink: Boolean = find(className("kill-link")).isDefined def runSlowJob(sc: SparkContext) { sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index c0c28cb60e21d..21d8267114133 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -269,7 +269,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val taskType = Utils.getFormattedClassName(new ShuffleMapTask(0)) val execId = "exe-1" - def makeTaskMetrics(base: Int) = { + def makeTaskMetrics(base: Int): TaskMetrics = { val taskMetrics = new TaskMetrics() val shuffleReadMetrics = new ShuffleReadMetrics() val shuffleWriteMetrics = new ShuffleWriteMetrics() @@ -291,7 +291,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskMetrics } - def makeTaskInfo(taskId: Long, finishTime: Int = 0) = { + def makeTaskInfo(taskId: Long, finishTime: Int = 0): TaskInfo = { val taskInfo = new TaskInfo(taskId, 0, 1, 0L, execId, "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = finishTime diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index e1bc1379b5d80..3744e479d2f05 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -107,7 +107,8 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { val myRddInfo0 = rddInfo0 val myRddInfo1 = rddInfo1 val myRddInfo2 = rddInfo2 - val stageInfo0 = new StageInfo(0, 0, "0", 100, Seq(myRddInfo0, myRddInfo1, myRddInfo2), "details") + val stageInfo0 = new StageInfo( + 0, 0, "0", 100, Seq(myRddInfo0, myRddInfo1, myRddInfo2), "details") bus.postToAll(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) assert(storageListener._rddInfoMap.size === 3) diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 054ef54e746a5..c47162779bbba 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -83,7 +83,7 @@ object TestObject { class TestClass extends Serializable { var x = 5 - def getX = x + def getX: Int = x def run(): Int = { var nonSer = new NonSerializable @@ -95,7 +95,7 @@ class TestClass extends Serializable { } class TestClassWithoutDefaultConstructor(x: Int) extends Serializable { - def getX = x + def getX: Int = x def run(): Int = { var nonSer = new NonSerializable @@ -164,7 +164,7 @@ object TestObjectWithNesting { } class TestClassWithNesting(val y: Int) extends Serializable { - def getY = y + def getY: Int = y def run(): Int = { var nonSer = new NonSerializable diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 43b6a405cb68c..c05317534cddf 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -109,7 +109,8 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { // verify whether the earliest file has been deleted val rolledOverFiles = allGeneratedFiles.filter { _ != testFile.toString }.toArray.sorted - logInfo(s"All rolled over files generated:${rolledOverFiles.size}\n" + rolledOverFiles.mkString("\n")) + logInfo(s"All rolled over files generated:${rolledOverFiles.size}\n" + + rolledOverFiles.mkString("\n")) assert(rolledOverFiles.size > 2) val earliestRolledOverFile = rolledOverFiles.head val existingRolledOverFiles = RollingFileAppender.getSortedRolledOverFiles( @@ -135,7 +136,7 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { val testOutputStream = new PipedOutputStream() val testInputStream = new PipedInputStream(testOutputStream) val appender = FileAppender(testInputStream, testFile, conf) - //assert(appender.getClass === classTag[ExpectedAppender].getClass) + // assert(appender.getClass === classTag[ExpectedAppender].getClass) assert(appender.getClass.getSimpleName === classTag[ExpectedAppender].runtimeClass.getSimpleName) if (appender.isInstanceOf[RollingFileAppender]) { @@ -153,9 +154,11 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { import RollingFileAppender._ - def rollingStrategy(strategy: String) = Seq(STRATEGY_PROPERTY -> strategy) - def rollingSize(size: String) = Seq(SIZE_PROPERTY -> size) - def rollingInterval(interval: String) = Seq(INTERVAL_PROPERTY -> interval) + def rollingStrategy(strategy: String): Seq[(String, String)] = + Seq(STRATEGY_PROPERTY -> strategy) + def rollingSize(size: String): Seq[(String, String)] = Seq(SIZE_PROPERTY -> size) + def rollingInterval(interval: String): Seq[(String, String)] = + Seq(INTERVAL_PROPERTY -> interval) val msInDay = 24 * 60 * 60 * 1000L val msInHour = 60 * 60 * 1000L diff --git a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala index 72e81f3f1a884..403dcb03bd6e5 100644 --- a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala @@ -71,7 +71,7 @@ class NextIteratorSuite extends FunSuite with Matchers { class StubIterator(ints: Buffer[Int]) extends NextIterator[Int] { var closeCalled = 0 - override def getNext() = { + override def getNext(): Int = { if (ints.size == 0) { finished = true 0 diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 7424c2e91d4f2..67a9f75ff2187 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -98,8 +98,10 @@ class SizeEstimatorSuite // If an array contains the *same* element many times, we should only count it once. val d1 = new DummyClass1 - assertResult(72)(SizeEstimator.estimate(Array.fill(10)(d1))) // 10 pointers plus 8-byte object - assertResult(432)(SizeEstimator.estimate(Array.fill(100)(d1))) // 100 pointers plus 8-byte object + // 10 pointers plus 8-byte object + assertResult(72)(SizeEstimator.estimate(Array.fill(10)(d1))) + // 100 pointers plus 8-byte object + assertResult(432)(SizeEstimator.estimate(Array.fill(100)(d1))) // Same thing with huge array containing the same element many times. Note that this won't // return exactly 4032 because it can't tell that *all* the elements will equal the first diff --git a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala index c1c605cdb487b..8b72fe665c214 100644 --- a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala @@ -63,7 +63,7 @@ class TimeStampedHashMapSuite extends FunSuite { assert(map1.getTimestamp("k1").get < threshTime1) assert(map1.getTimestamp("k2").isDefined) assert(map1.getTimestamp("k2").get >= threshTime1) - map1.clearOldValues(threshTime1) //should only clear k1 + map1.clearOldValues(threshTime1) // should only clear k1 assert(map1.get("k1") === None) assert(map1.get("k2").isDefined) } @@ -93,7 +93,7 @@ class TimeStampedHashMapSuite extends FunSuite { assert(map1.getTimestamp("k1").get < threshTime1) assert(map1.getTimestamp("k2").isDefined) assert(map1.getTimestamp("k2").get >= threshTime1) - map1.clearOldValues(threshTime1) //should only clear k1 + map1.clearOldValues(threshTime1) // should only clear k1 assert(map1.get("k1") === None) assert(map1.get("k2").isDefined) } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 5d93086082189..fb97e650ff95c 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -23,6 +23,7 @@ import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStr import java.net.{BindException, ServerSocket, URI} import java.nio.{ByteBuffer, ByteOrder} import java.text.DecimalFormatSymbols +import java.util.concurrent.TimeUnit import java.util.Locale import com.google.common.base.Charsets.UTF_8 @@ -35,7 +36,50 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkConf class UtilsSuite extends FunSuite with ResetSystemProperties { + + test("timeConversion") { + // Test -1 + assert(Utils.timeStringAsSeconds("-1") === -1) + + // Test zero + assert(Utils.timeStringAsSeconds("0") === 0) + + assert(Utils.timeStringAsSeconds("1") === 1) + assert(Utils.timeStringAsSeconds("1s") === 1) + assert(Utils.timeStringAsSeconds("1000ms") === 1) + assert(Utils.timeStringAsSeconds("1000000us") === 1) + assert(Utils.timeStringAsSeconds("1m") === TimeUnit.MINUTES.toSeconds(1)) + assert(Utils.timeStringAsSeconds("1min") === TimeUnit.MINUTES.toSeconds(1)) + assert(Utils.timeStringAsSeconds("1h") === TimeUnit.HOURS.toSeconds(1)) + assert(Utils.timeStringAsSeconds("1d") === TimeUnit.DAYS.toSeconds(1)) + + assert(Utils.timeStringAsMs("1") === 1) + assert(Utils.timeStringAsMs("1ms") === 1) + assert(Utils.timeStringAsMs("1000us") === 1) + assert(Utils.timeStringAsMs("1s") === TimeUnit.SECONDS.toMillis(1)) + assert(Utils.timeStringAsMs("1m") === TimeUnit.MINUTES.toMillis(1)) + assert(Utils.timeStringAsMs("1min") === TimeUnit.MINUTES.toMillis(1)) + assert(Utils.timeStringAsMs("1h") === TimeUnit.HOURS.toMillis(1)) + assert(Utils.timeStringAsMs("1d") === TimeUnit.DAYS.toMillis(1)) + + // Test invalid strings + intercept[NumberFormatException] { + Utils.timeStringAsMs("This breaks 600s") + } + + intercept[NumberFormatException] { + Utils.timeStringAsMs("This breaks 600ds") + } + intercept[NumberFormatException] { + Utils.timeStringAsMs("600s This breaks") + } + + intercept[NumberFormatException] { + Utils.timeStringAsMs("This 123s breaks") + } + } + test("bytesToString") { assert(Utils.bytesToString(10) === "10.0 B") assert(Utils.bytesToString(1500) === "1500.0 B") @@ -106,7 +150,7 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { val second = 1000 val minute = second * 60 val hour = minute * 60 - def str = Utils.msDurationToString(_) + def str: (Long) => String = Utils.msDurationToString(_) val sep = new DecimalFormatSymbols(Locale.getDefault()).getDecimalSeparator() @@ -199,7 +243,8 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { test("doesDirectoryContainFilesNewerThan") { // create some temporary directories and files val parent: File = Utils.createTempDir() - val child1: File = Utils.createTempDir(parent.getCanonicalPath) // The parent directory has two child directories + // The parent directory has two child directories + val child1: File = Utils.createTempDir(parent.getCanonicalPath) val child2: File = Utils.createTempDir(parent.getCanonicalPath) val child3: File = Utils.createTempDir(child1.getCanonicalPath) // set the last modified time of child1 to 30 secs old diff --git a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala index 794a55d61750b..ce2968728a996 100644 --- a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.FunSuite @deprecated("suppress compile time deprecation warning", "1.0.0") class VectorSuite extends FunSuite { - def verifyVector(vector: Vector, expectedLength: Int) = { + def verifyVector(vector: Vector, expectedLength: Int): Unit = { assert(vector.length == expectedLength) assert(vector.elements.min > 0.0) assert(vector.elements.max < 1.0) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 48f79ea651018..dff8f3ddc816f 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -185,7 +185,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { // reduceByKey val rdd = sc.parallelize(1 to 10).map(i => (i%2, 1)) - val result1 = rdd.reduceByKey(_+_).collect() + val result1 = rdd.reduceByKey(_ + _).collect() assert(result1.toSet === Set[(Int, Int)]((0, 5), (1, 5))) // groupByKey diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 72d96798b1141..9ff067f86af44 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -553,10 +553,10 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) - def createCombiner(i: String) = ArrayBuffer[String](i) - def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i - def mergeCombiners(buffer1: ArrayBuffer[String], buffer2: ArrayBuffer[String]) = - buffer1 ++= buffer2 + def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) + def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i + def mergeCombiners(buffer1: ArrayBuffer[String], buffer2: ArrayBuffer[String]) + : ArrayBuffer[String] = buffer1 ++= buffer2 val agg = new Aggregator[String, String, ArrayBuffer[String]]( createCombiner _, mergeValue _, mergeCombiners _) @@ -633,14 +633,17 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) - def createCombiner(i: Int) = ArrayBuffer[Int](i) - def mergeValue(buffer: ArrayBuffer[Int], i: Int) = buffer += i - def mergeCombiners(buf1: ArrayBuffer[Int], buf2: ArrayBuffer[Int]) = buf1 ++= buf2 + def createCombiner(i: Int): ArrayBuffer[Int] = ArrayBuffer[Int](i) + def mergeValue(buffer: ArrayBuffer[Int], i: Int): ArrayBuffer[Int] = buffer += i + def mergeCombiners(buf1: ArrayBuffer[Int], buf2: ArrayBuffer[Int]): ArrayBuffer[Int] = { + buf1 ++= buf2 + } val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners) val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), None, None, None) - sorter.insertAll((1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) + sorter.insertAll( + (1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) val it = sorter.iterator while (it.hasNext) { @@ -654,9 +657,10 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) - def createCombiner(i: String) = ArrayBuffer[String](i) - def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i - def mergeCombiners(buf1: ArrayBuffer[String], buf2: ArrayBuffer[String]) = buf1 ++= buf2 + def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) + def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i + def mergeCombiners(buf1: ArrayBuffer[String], buf2: ArrayBuffer[String]): ArrayBuffer[String] = + buf1 ++= buf2 val agg = new Aggregator[String, String, ArrayBuffer[String]]( createCombiner, mergeValue, mergeCombiners) @@ -720,7 +724,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe // Using wrongOrdering to show integer overflow introduced exception. val rand = new Random(100L) val wrongOrdering = new Ordering[String] { - override def compare(a: String, b: String) = { + override def compare(a: String, b: String): Int = { val h1 = if (a == null) 0 else a.hashCode() val h2 = if (b == null) 0 else b.hashCode() h1 - h2 @@ -742,9 +746,10 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe // Using aggregation and external spill to make sure ExternalSorter using // partitionKeyComparator. - def createCombiner(i: String) = ArrayBuffer(i) - def mergeValue(c: ArrayBuffer[String], i: String) = c += i - def mergeCombiners(c1: ArrayBuffer[String], c2: ArrayBuffer[String]) = c1 ++= c2 + def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer(i) + def mergeValue(c: ArrayBuffer[String], i: String): ArrayBuffer[String] = c += i + def mergeCombiners(c1: ArrayBuffer[String], c2: ArrayBuffer[String]): ArrayBuffer[String] = + c1 ++= c2 val agg = new Aggregator[String, String, ArrayBuffer[String]]( createCombiner, mergeValue, mergeCombiners) diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala index ef7178bcdf5c2..03f5f2d1b8528 100644 --- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala @@ -28,7 +28,7 @@ import scala.language.reflectiveCalls class XORShiftRandomSuite extends FunSuite with Matchers { - def fixture = new { + def fixture: Object {val seed: Long; val hundMil: Int; val xorRand: XORShiftRandom} = new { val seed = 1L val xorRand = new XORShiftRandom(seed) val hundMil = 1e8.toInt diff --git a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala index d888de929fdda..cc86ef45858c9 100644 --- a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala +++ b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala @@ -36,8 +36,10 @@ object SparkSqlExample { val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ import sqlContext._ - val people = sc.makeRDD(1 to 100, 10).map(x => Person(s"Name$x", x)) + + val people = sc.makeRDD(1 to 100, 10).map(x => Person(s"Name$x", x)).toDF() people.registerTempTable("people") val teenagers = sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") val teenagerNames = teenagers.map(t => "Name: " + t(0)).collect() diff --git a/dev/check-license b/dev/check-license index 39943f882b6ca..10740cfdc5242 100755 --- a/dev/check-license +++ b/dev/check-license @@ -24,29 +24,27 @@ acquire_rat_jar () { JAR="$rat_jar" - if [[ ! -f "$rat_jar" ]]; then - # Download rat launch jar if it hasn't been downloaded yet - if [ ! -f "$JAR" ]; then - # Download - printf "Attempting to fetch rat\n" - JAR_DL="${JAR}.part" - if [ $(command -v curl) ]; then - curl -L --silent "${URL}" > "$JAR_DL" && mv "$JAR_DL" "$JAR" - elif [ $(command -v wget) ]; then - wget --quiet ${URL} -O "$JAR_DL" && mv "$JAR_DL" "$JAR" - else - printf "You do not have curl or wget installed, please install rat manually.\n" - exit -1 - fi - fi - - unzip -tq $JAR &> /dev/null - if [ $? -ne 0 ]; then - # We failed to download - printf "Our attempt to download rat locally to ${JAR} failed. Please install rat manually.\n" + # Download rat launch jar if it hasn't been downloaded yet + if [ ! -f "$JAR" ]; then + # Download + printf "Attempting to fetch rat\n" + JAR_DL="${JAR}.part" + if [ $(command -v curl) ]; then + curl -L --silent "${URL}" > "$JAR_DL" && mv "$JAR_DL" "$JAR" + elif [ $(command -v wget) ]; then + wget --quiet ${URL} -O "$JAR_DL" && mv "$JAR_DL" "$JAR" + else + printf "You do not have curl or wget installed, please install rat manually.\n" exit -1 fi - printf "Launching rat from ${JAR}\n" + fi + + unzip -tq "$JAR" &> /dev/null + if [ $? -ne 0 ]; then + # We failed to download + rm "$JAR" + printf "Our attempt to download rat locally to ${JAR} failed. Please install rat manually.\n" + exit -1 fi } @@ -71,6 +69,11 @@ mkdir -p "$FWDIR"/lib $java_cmd -jar "$rat_jar" -E "$FWDIR"/.rat-excludes -d "$FWDIR" > rat-results.txt +if [ $? -ne 0 ]; then + echo "RAT exited abnormally" + exit 1 +fi + ERRORS="$(cat rat-results.txt | grep -e "??")" if test ! -z "$ERRORS"; then diff --git a/dev/run-tests b/dev/run-tests index 561d7fc9e7b1f..bb21ab6c9aa04 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -173,7 +173,7 @@ CURRENT_BLOCK=$BLOCK_BUILD build/mvn $HIVE_BUILD_ARGS clean package -DskipTests else echo -e "q\n" \ - | build/sbt $HIVE_BUILD_ARGS package assembly/assembly \ + | build/sbt $HIVE_BUILD_ARGS package assembly/assembly streaming-kafka-assembly/assembly \ | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" fi } @@ -236,3 +236,18 @@ echo "=========================================================================" CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS ./python/run-tests + +echo "" +echo "=========================================================================" +echo "Running SparkR tests" +echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_SPARKR_UNIT_TESTS + +if [ $(command -v R) ]; then + ./R/install-dev.sh + ./R/run-tests.sh +else + echo "Ignoring SparkR tests as R was not found in PATH" +fi + diff --git a/dev/run-tests-codes.sh b/dev/run-tests-codes.sh index 8ab6db6925d6e..154e01255b2ef 100644 --- a/dev/run-tests-codes.sh +++ b/dev/run-tests-codes.sh @@ -25,3 +25,4 @@ readonly BLOCK_BUILD=14 readonly BLOCK_MIMA=15 readonly BLOCK_SPARK_UNIT_TESTS=16 readonly BLOCK_PYSPARK_UNIT_TESTS=17 +readonly BLOCK_SPARKR_UNIT_TESTS=18 diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index f10aa6b59e1af..3c1c91a111357 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -161,6 +161,10 @@ pr_message="" # Ensure we save off the current HEAD to revert to current_pr_head="`git rev-parse HEAD`" +echo "HEAD: `git rev-parse HEAD`" +echo "GHPRB: $ghprbActualCommit" +echo "SHA1: $sha1" + # Run pull request tests for t in "${PR_TESTS[@]}"; do this_test="${FWDIR}/dev/tests/${t}.sh" @@ -210,6 +214,8 @@ done failing_test="Spark unit tests" elif [ "$test_result" -eq "$BLOCK_PYSPARK_UNIT_TESTS" ]; then failing_test="PySpark unit tests" + elif [ "$test_result" -eq "$BLOCK_SPARKR_UNIT_TESTS" ]; then + failing_test="SparkR unit tests" else failing_test="some tests" fi diff --git a/dev/scalastyle b/dev/scalastyle index 86919227ed1ab..4e03f89ed5d5d 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -18,9 +18,10 @@ # echo -e "q\n" | build/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt +echo -e "q\n" | build/sbt -Phive -Phive-thriftserver test:scalastyle >> scalastyle.txt # Check style with YARN built too -echo -e "q\n" | build/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 scalastyle \ - >> scalastyle.txt +echo -e "q\n" | build/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 scalastyle >> scalastyle.txt +echo -e "q\n" | build/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 test:scalastyle >> scalastyle.txt ERRORS=$(cat scalastyle.txt | awk '{if($1~/error/)print}') rm scalastyle.txt diff --git a/dev/tests/pr_new_dependencies.sh b/dev/tests/pr_new_dependencies.sh index 370c7cc737bbd..fdfb3c62aff58 100755 --- a/dev/tests/pr_new_dependencies.sh +++ b/dev/tests/pr_new_dependencies.sh @@ -39,12 +39,12 @@ CURR_CP_FILE="my-classpath.txt" MASTER_CP_FILE="master-classpath.txt" # First switch over to the master branch -git checkout master &>/dev/null +git checkout -f master # Find and copy all pom.xml files into a *.gate file that we can check # against through various `git` changes find -name "pom.xml" -exec cp {} {}.gate \; # Switch back to the current PR -git checkout "${current_pr_head}" &>/dev/null +git checkout -f "${current_pr_head}" # Check if any *.pom files from the current branch are different from the master difference_q="" @@ -71,7 +71,7 @@ else sort > ${CURR_CP_FILE} # Checkout the master branch to compare against - git checkout master &>/dev/null + git checkout -f master ${MVN_BIN} clean package dependency:build-classpath -DskipTests 2>/dev/null | \ sed -n -e '/Building Spark Project Assembly/,$p' | \ @@ -84,7 +84,7 @@ else rev | \ sort > ${MASTER_CP_FILE} - DIFF_RESULTS="`diff my-classpath.txt master-classpath.txt`" + DIFF_RESULTS="`diff ${CURR_CP_FILE} ${MASTER_CP_FILE}`" if [ -z "${DIFF_RESULTS}" ]; then echo " * This patch does not change any dependencies." diff --git a/docs/README.md b/docs/README.md index 3773ea25c8b67..5852f972a051d 100644 --- a/docs/README.md +++ b/docs/README.md @@ -58,13 +58,19 @@ phase, use the following sytax: We use Sphinx to generate Python API docs, so you will need to install it by running `sudo pip install sphinx`. -## API Docs (Scaladoc and Sphinx) +## knitr, devtools + +SparkR documentation is written using `roxygen2` and we use `knitr`, `devtools` to generate +documentation. To install these packages you can run `install.packages(c("knitr", "devtools"))` from a +R console. + +## API Docs (Scaladoc, Sphinx, roxygen2) You can build just the Spark scaladoc by running `build/sbt unidoc` from the SPARK_PROJECT_ROOT directory. Similarly, you can build just the PySpark docs by running `make html` from the SPARK_PROJECT_ROOT/python/docs directory. Documentation is only generated for classes that are listed as -public in `__init__.py`. +public in `__init__.py`. The SparkR docs can be built by running SPARK_PROJECT_ROOT/R/create-docs.sh. When you run `jekyll` in the `docs` directory, it will also copy over the scaladoc for the various Spark subprojects into the `docs` directory (and then also into the `_site` directory). We use a @@ -72,5 +78,5 @@ jekyll plugin to run `build/sbt unidoc` before building the site so if you haven may take some time as it generates all of the scaladoc. The jekyll plugin also generates the PySpark docs [Sphinx](http://sphinx-doc.org/). -NOTE: To skip the step of building and copying over the Scala and Python API docs, run `SKIP_API=1 +NOTE: To skip the step of building and copying over the Scala, Python, R API docs, run `SKIP_API=1 jekyll`. diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 2e88b3093652d..b92c75f90b11c 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -84,6 +84,7 @@
  • Scala
  • Java
  • Python
  • +
  • R
  • diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 3c626a0b7f54b..0ea3f8eab461b 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -78,5 +78,18 @@ puts "cp -r python/docs/_build/html/. docs/api/python" cp_r("python/docs/_build/html/.", "docs/api/python") - cd("..") + # Build SparkR API docs + puts "Moving to R directory and building roxygen docs." + cd("R") + puts `./create-docs.sh` + + puts "Moving back into home dir." + cd("../") + + puts "Making directory api/R" + mkdir_p "docs/api/R" + + puts "cp -r R/pkg/html/. docs/api/R" + cp_r("R/pkg/html/.", "docs/api/R") + end diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index 6a75d5c457f02..7079de546e2f5 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -33,7 +33,11 @@ There are several useful things to note about this architecture: 2. Spark is agnostic to the underlying cluster manager. As long as it can acquire executor processes, and these communicate with each other, it is relatively easy to run it even on a cluster manager that also supports other applications (e.g. Mesos/YARN). -3. Because the driver schedules tasks on the cluster, it should be run close to the worker +3. The driver program must listen for and accept incoming connections from its executors throughout + its lifetime (e.g., see [spark.driver.port and spark.fileserver.port in the network config + section](configuration.html#networking)). As such, the driver program must be network + addressable from the worker nodes. +4. Because the driver schedules tasks on the cluster, it should be run close to the worker nodes, preferably on the same local area network. If you'd like to send requests to the cluster remotely, it's better to open an RPC to the driver and have it submit operations from nearby than to run a driver far away from the worker nodes. diff --git a/docs/configuration.md b/docs/configuration.md index 7fe11475212b3..7169ec295ef7f 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -35,9 +35,19 @@ val conf = new SparkConf() val sc = new SparkContext(conf) {% endhighlight %} -Note that we can have more than 1 thread in local mode, and in cases like spark streaming, we may actually -require one to prevent any sort of starvation issues. +Note that we can have more than 1 thread in local mode, and in cases like Spark Streaming, we may +actually require one to prevent any sort of starvation issues. +Properties that specify some time duration should be configured with a unit of time. +The following format is accepted: + + 25ms (milliseconds) + 5s (seconds) + 10m or 10min (minutes) + 3h (hours) + 5d (days) + 1y (years) + ## Dynamically Loading Spark Properties In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For instance, if you'd like to run the same application with different masters or different @@ -429,10 +439,10 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.io.retryWait - 5 + 5s - (Netty only) Seconds to wait between retries of fetches. The maximum delay caused by retrying - is simply maxRetries * retryWait, by default 15 seconds. + (Netty only) How long to wait between retries of fetches. The maximum delay caused by retrying + is 15 seconds by default, calculated as maxRetries * retryWait. @@ -732,17 +742,17 @@ Apart from these, the following properties are also available, and may be useful spark.executor.heartbeatInterval - 10000 - Interval (milliseconds) between each executor's heartbeats to the driver. Heartbeats let + 10s + Interval between each executor's heartbeats to the driver. Heartbeats let the driver know that the executor is still alive and update it with metrics for in-progress tasks. spark.files.fetchTimeout - 60 + 60s Communication timeout to use when fetching files added through SparkContext.addFile() from - the driver, in seconds. + the driver. @@ -853,11 +863,11 @@ Apart from these, the following properties are also available, and may be useful spark.akka.heartbeat.interval - 1000 + 1000s This is set to a larger value to disable the transport failure detector that comes built in to Akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger - interval value in seconds reduces network overhead and a smaller value ( ~ 1 s) might be more + interval value reduces network overhead and a smaller value ( ~ 1 s) might be more informative for Akka's failure detector. Tune this in combination of `spark.akka.heartbeat.pauses` if you need to. A likely positive use case for using failure detector would be: a sensistive failure detector can help evict rogue executors quickly. However this is usually not the case @@ -868,11 +878,11 @@ Apart from these, the following properties are also available, and may be useful spark.akka.heartbeat.pauses - 6000 + 6000s This is set to a larger value to disable the transport failure detector that comes built in to Akka. It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart - beat pause in seconds for Akka. This can be used to control sensitivity to GC pauses. Tune + beat pause for Akka. This can be used to control sensitivity to GC pauses. Tune this along with `spark.akka.heartbeat.interval` if you need to. @@ -886,9 +896,9 @@ Apart from these, the following properties are also available, and may be useful spark.akka.timeout - 100 + 100s - Communication timeout between Spark nodes, in seconds. + Communication timeout between Spark nodes. @@ -938,10 +948,10 @@ Apart from these, the following properties are also available, and may be useful spark.network.timeout - 120 + 120s - Default timeout for all network interactions, in seconds. This config will be used in - place of spark.core.connection.ack.wait.timeout, spark.akka.timeout, + Default timeout for all network interactions. This config will be used in place of + spark.core.connection.ack.wait.timeout, spark.akka.timeout, spark.storage.blockManagerSlaveTimeoutMs or spark.shuffle.io.connectionTimeout, if they are not configured. @@ -989,9 +999,9 @@ Apart from these, the following properties are also available, and may be useful spark.locality.wait - 3000 + 3s - Number of milliseconds to wait to launch a data-local task before giving up and launching it + How long to wait to launch a data-local task before giving up and launching it on a less-local node. The same wait will be used to step through multiple locality levels (process-local, node-local, rack-local and then any). It is also possible to customize the waiting time for each level by setting spark.locality.wait.node, etc. @@ -1024,10 +1034,9 @@ Apart from these, the following properties are also available, and may be useful spark.scheduler.maxRegisteredResourcesWaitingTime - 30000 + 30s - Maximum amount of time to wait for resources to register before scheduling begins - (in milliseconds). + Maximum amount of time to wait for resources to register before scheduling begins. @@ -1054,10 +1063,9 @@ Apart from these, the following properties are also available, and may be useful spark.scheduler.revive.interval - 1000 + 1s - The interval length for the scheduler to revive the worker resource offers to run tasks - (in milliseconds). + The interval length for the scheduler to revive the worker resource offers to run tasks. @@ -1070,9 +1078,9 @@ Apart from these, the following properties are also available, and may be useful spark.speculation.interval - 100 + 100ms - How often Spark will check for tasks to speculate, in milliseconds. + How often Spark will check for tasks to speculate. @@ -1127,10 +1135,10 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.executorIdleTimeout - 600 + 600s - If dynamic allocation is enabled and an executor has been idle for more than this duration - (in seconds), the executor will be removed. For more detail, see this + If dynamic allocation is enabled and an executor has been idle for more than this duration, + the executor will be removed. For more detail, see this description. @@ -1157,10 +1165,10 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.schedulerBacklogTimeout - 5 + 5s If dynamic allocation is enabled and there have been pending tasks backlogged for more than - this duration (in seconds), new executors will be requested. For more detail, see this + this duration, new executors will be requested. For more detail, see this description. @@ -1215,18 +1223,18 @@ Apart from these, the following properties are also available, and may be useful spark.core.connection.ack.wait.timeout - 60 + 60s - Number of seconds for the connection to wait for ack to occur before timing + How long for the connection to wait for ack to occur before timing out and giving up. To avoid unwilling timeout caused by long pause like GC, you can set larger value. spark.core.connection.auth.wait.timeout - 30 + 30s - Number of seconds for the connection to wait for authentication to occur before timing + How long for the connection to wait for authentication to occur before timing out and giving up. @@ -1347,9 +1355,9 @@ Apart from these, the following properties are also available, and may be useful Property NameDefaultMeaning spark.streaming.blockInterval - 200 + 200ms - Interval (milliseconds) at which data received by Spark Streaming receivers is chunked + Interval at which data received by Spark Streaming receivers is chunked into blocks of data before storing them in Spark. Minimum recommended - 50 ms. See the performance tuning section in the Spark Streaming programing guide for more details. diff --git a/docs/img/cluster-overview.png b/docs/img/cluster-overview.png index 368274068e754..317554c5f2a5b 100644 Binary files a/docs/img/cluster-overview.png and b/docs/img/cluster-overview.png differ diff --git a/docs/img/cluster-overview.pptx b/docs/img/cluster-overview.pptx index af3c462cd904d..1b90d7ec5a7ae 100644 Binary files a/docs/img/cluster-overview.pptx and b/docs/img/cluster-overview.pptx differ diff --git a/docs/ml-guide.md b/docs/ml-guide.md index c08c76d226713..771a07183e26f 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -493,7 +493,7 @@ from pyspark.ml.feature import HashingTF, Tokenizer from pyspark.sql import Row, SQLContext sc = SparkContext(appName="SimpleTextClassificationPipeline") -sqlCtx = SQLContext(sc) +sqlContext = SQLContext(sc) # Prepare training documents, which are labeled. LabeledDocument = Row("id", "text", "label") diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index b7e68d4f71714..853c9f26b0ec9 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -48,9 +48,9 @@ Most of the configs are the same for Spark on YARN as for other deployment modes spark.yarn.am.waitTime - 100000 + 100s - In yarn-cluster mode, time in milliseconds for the application master to wait for the + In yarn-cluster mode, time for the application master to wait for the SparkContext to be initialized. In yarn-client mode, time for the application master to wait for the driver to connect to it. @@ -87,7 +87,8 @@ Most of the configs are the same for Spark on YARN as for other deployment modes spark.yarn.historyServer.address (none) - The address of the Spark history server (i.e. host.com:18080). The address should not contain a scheme (http://). Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. + The address of the Spark history server (i.e. host.com:18080). The address should not contain a scheme (http://). Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. + For this property, YARN properties can be used as variables, and these are substituted by Spark at runtime. For eg, if the Spark history server runs on the same node as the YARN ResourceManager, it can be set to `${hadoopconf-yarn.resourcemanager.hostname}:18080`. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 4441d6a000a02..332618edf0c55 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1642,7 +1642,7 @@ moved into the udf object in `SQLContext`.
    {% highlight java %} -sqlCtx.udf.register("strLen", (s: String) => s.length()) +sqlContext.udf.register("strLen", (s: String) => s.length()) {% endhighlight %}
    @@ -1650,7 +1650,7 @@ sqlCtx.udf.register("strLen", (s: String) => s.length())
    {% highlight java %} -sqlCtx.udf().register("strLen", (String s) -> { s.length(); }); +sqlContext.udf().register("strLen", (String s) -> { s.length(); }); {% endhighlight %}
    @@ -1784,6 +1784,7 @@ in Hive deployments. **Esoteric Hive Features** + * `UNION` type * Unique join * Column statistics collecting: Spark SQL does not piggyback scans to collect column statistics at diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 879a52cef8ff0..0c1f24761d0de 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -282,6 +282,10 @@ def parse_args(): parser.add_option( "--vpc-id", default=None, help="VPC to launch instances in") + parser.add_option( + "--private-ips", action="store_true", default=False, + help="Use private IPs for instances rather than public if VPC/subnet " + + "requires that.") (opts, args) = parser.parse_args() if len(args) != 2: @@ -707,7 +711,7 @@ def get_instances(group_names): # Deploy configuration files and run setup scripts on a newly launched # or started EC2 cluster. def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): - master = master_nodes[0].public_dns_name + master = get_dns_name(master_nodes[0], opts.private_ips) if deploy_ssh_key: print "Generating cluster's SSH key on master..." key_setup = """ @@ -719,8 +723,9 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh']) print "Transferring cluster's SSH key to slaves..." for slave in slave_nodes: - print slave.public_dns_name - ssh_write(slave.public_dns_name, opts, ['tar', 'x'], dot_ssh_tar) + slave_address = get_dns_name(slave, opts.private_ips) + print slave_address + ssh_write(slave_address, opts, ['tar', 'x'], dot_ssh_tar) modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs', 'mapreduce', 'spark-standalone', 'tachyon'] @@ -809,7 +814,8 @@ def is_cluster_ssh_available(cluster_instances, opts): Check if SSH is available on all the instances in a cluster. """ for i in cluster_instances: - if not is_ssh_available(host=i.public_dns_name, opts=opts): + dns_name = get_dns_name(i, opts.private_ips) + if not is_ssh_available(host=dns_name, opts=opts): return False else: return True @@ -923,7 +929,7 @@ def get_num_disks(instance_type): # # root_dir should be an absolute path to the directory with the files we want to deploy. def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): - active_master = master_nodes[0].public_dns_name + active_master = get_dns_name(master_nodes[0], opts.private_ips) num_disks = get_num_disks(opts.instance_type) hdfs_data_dirs = "/mnt/ephemeral-hdfs/data" @@ -948,10 +954,12 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): print "Deploying Spark via git hash; Tachyon won't be set up" modules = filter(lambda x: x != "tachyon", modules) + master_addresses = [get_dns_name(i, opts.private_ips) for i in master_nodes] + slave_addresses = [get_dns_name(i, opts.private_ips) for i in slave_nodes] template_vars = { - "master_list": '\n'.join([i.public_dns_name for i in master_nodes]), + "master_list": '\n'.join(master_addresses), "active_master": active_master, - "slave_list": '\n'.join([i.public_dns_name for i in slave_nodes]), + "slave_list": '\n'.join(slave_addresses), "cluster_url": cluster_url, "hdfs_data_dirs": hdfs_data_dirs, "mapred_local_dirs": mapred_local_dirs, @@ -1011,7 +1019,7 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): # # root_dir should be an absolute path. def deploy_user_files(root_dir, opts, master_nodes): - active_master = master_nodes[0].public_dns_name + active_master = get_dns_name(master_nodes[0], opts.private_ips) command = [ 'rsync', '-rv', '-e', stringify_command(ssh_command(opts)), @@ -1122,6 +1130,20 @@ def get_partition(total, num_partitions, current_partitions): return num_slaves_this_zone +# Gets the IP address, taking into account the --private-ips flag +def get_ip_address(instance, private_ips=False): + ip = instance.ip_address if not private_ips else \ + instance.private_ip_address + return ip + + +# Gets the DNS name, taking into account the --private-ips flag +def get_dns_name(instance, private_ips=False): + dns = instance.public_dns_name if not private_ips else \ + instance.private_ip_address + return dns + + def real_main(): (opts, action, cluster_name) = parse_args() @@ -1230,7 +1252,7 @@ def real_main(): if any(master_nodes + slave_nodes): print "The following instances will be terminated:" for inst in master_nodes + slave_nodes: - print "> %s" % inst.public_dns_name + print "> %s" % get_dns_name(inst, opts.private_ips) print "ALL DATA ON ALL NODES WILL BE LOST!!" msg = "Are you sure you want to destroy the cluster {c}? (y/N) ".format(c=cluster_name) @@ -1294,13 +1316,17 @@ def real_main(): elif action == "login": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - master = master_nodes[0].public_dns_name - print "Logging into master " + master + "..." - proxy_opt = [] - if opts.proxy_port is not None: - proxy_opt = ['-D', opts.proxy_port] - subprocess.check_call( - ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)]) + if not master_nodes[0].public_dns_name and not opts.private_ips: + print "Master has no public DNS name. Maybe you meant to specify " \ + "--private-ips?" + else: + master = get_dns_name(master_nodes[0], opts.private_ips) + print "Logging into master " + master + "..." + proxy_opt = [] + if opts.proxy_port is not None: + proxy_opt = ['-D', opts.proxy_port] + subprocess.check_call( + ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)]) elif action == "reboot-slaves": response = raw_input( @@ -1318,7 +1344,11 @@ def real_main(): elif action == "get-master": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - print master_nodes[0].public_dns_name + if not master_nodes[0].public_dns_name and not opts.private_ips: + print "Master has no public DNS name. Maybe you meant to specify " \ + "--private-ips?" + else: + print get_dns_name(master_nodes[0], opts.private_ips) elif action == "stop": response = raw_input( diff --git a/examples/pom.xml b/examples/pom.xml index 7e93f0eec0b91..afd7c6d52f0dd 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -90,6 +90,12 @@ org.apache.spark spark-streaming-zeromq_${scala.binary.version} ${project.version} + + + org.spark-project.protobuf + protobuf-java + + org.apache.hbase @@ -234,6 +240,7 @@ org.apache.commons commons-math3 + provided com.twitter @@ -262,6 +269,22 @@ com.ning compress-lzf + + commons-cli + commons-cli + + + commons-codec + commons-codec + + + commons-lang + commons-lang + + + commons-logging + commons-logging + io.netty netty @@ -270,10 +293,22 @@ jline jline + + net.jpountz.lz4 + lz4 + org.apache.cassandra.deps avro + + org.apache.commons + commons-math3 + + + org.apache.thrift + libthrift + @@ -281,6 +316,17 @@ scopt_${scala.binary.version} 3.2.0 + + + + org.scala-lang + scala-library + provided + + @@ -322,12 +368,6 @@ - - - org.apache.commons.math3 - org.spark-project.commons.math3 - - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index 19d0eb216848e..eaf00d09f550d 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -116,7 +116,7 @@ class MyJavaLogisticRegression */ IntParam maxIter = new IntParam(this, "maxIter", "max number of iterations"); - int getMaxIter() { return (Integer) get(maxIter); } + int getMaxIter() { return (Integer) getOrDefault(maxIter); } public MyJavaLogisticRegression() { setMaxIter(100); @@ -211,7 +211,7 @@ public Vector predictRaw(Vector features) { public MyJavaLogisticRegressionModel copy() { MyJavaLogisticRegressionModel m = new MyJavaLogisticRegressionModel(parent_, fittingParamMap_, weights_); - Params$.MODULE$.inheritValues(this.paramMap(), this, m); + Params$.MODULE$.inheritValues(this.extractParamMap(), this, m); return m; } } diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index dee794840a3e1..8159ffbe2d269 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -55,7 +55,7 @@ public void setAge(int age) { public static void main(String[] args) throws Exception { SparkConf sparkConf = new SparkConf().setAppName("JavaSparkSQL"); JavaSparkContext ctx = new JavaSparkContext(sparkConf); - SQLContext sqlCtx = new SQLContext(ctx); + SQLContext sqlContext = new SQLContext(ctx); System.out.println("=== Data source: RDD ==="); // Load a text file and convert each line to a Java Bean. @@ -74,11 +74,11 @@ public Person call(String line) { }); // Apply a schema to an RDD of Java Beans and register it as a table. - DataFrame schemaPeople = sqlCtx.createDataFrame(people, Person.class); + DataFrame schemaPeople = sqlContext.createDataFrame(people, Person.class); schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. - DataFrame teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); // The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. @@ -99,12 +99,12 @@ public String call(Row row) { // Read in the parquet file created above. // Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a DataFrame. - DataFrame parquetFile = sqlCtx.parquetFile("people.parquet"); + DataFrame parquetFile = sqlContext.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); DataFrame teenagers2 = - sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); + sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); teenagerNames = teenagers2.toJavaRDD().map(new Function() { @Override public String call(Row row) { @@ -120,7 +120,7 @@ public String call(Row row) { // The path can be either a single text file or a directory storing text files. String path = "examples/src/main/resources/people.json"; // Create a DataFrame from the file(s) pointed by path - DataFrame peopleFromJsonFile = sqlCtx.jsonFile(path); + DataFrame peopleFromJsonFile = sqlContext.jsonFile(path); // Because the schema of a JSON dataset is automatically inferred, to write queries, // it is better to take a look at what is the schema. @@ -133,8 +133,8 @@ public String call(Row row) { // Register this DataFrame as a table. peopleFromJsonFile.registerTempTable("people"); - // SQL statements can be run by using the sql methods provided by sqlCtx. - DataFrame teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + // SQL statements can be run by using the sql methods provided by sqlContext. + DataFrame teenagers3 = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); // The results of SQL queries are DataFrame and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. @@ -151,7 +151,7 @@ public String call(Row row) { List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = ctx.parallelize(jsonData); - DataFrame peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD.rdd()); + DataFrame peopleFromJsonRDD = sqlContext.jsonRDD(anotherPeopleRDD.rdd()); // Take a look at the schema of this new DataFrame. peopleFromJsonRDD.printSchema(); @@ -164,7 +164,7 @@ public String call(Row row) { peopleFromJsonRDD.registerTempTable("people2"); - DataFrame peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); + DataFrame peopleWithCity = sqlContext.sql("SELECT name, address.city FROM people2"); List nameAndCity = peopleWithCity.toJavaRDD().map(new Function() { @Override public String call(Row row) { diff --git a/examples/src/main/python/ml/simple_text_classification_pipeline.py b/examples/src/main/python/ml/simple_text_classification_pipeline.py index d281f4fa44282..c73edb7fd6b20 100644 --- a/examples/src/main/python/ml/simple_text_classification_pipeline.py +++ b/examples/src/main/python/ml/simple_text_classification_pipeline.py @@ -33,7 +33,7 @@ if __name__ == "__main__": sc = SparkContext(appName="SimpleTextClassificationPipeline") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) # Prepare training documents, which are labeled. LabeledDocument = Row("id", "text", "label") diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py index b5a70db2b9a3c..fcbf56cbf0c52 100644 --- a/examples/src/main/python/mllib/dataset_example.py +++ b/examples/src/main/python/mllib/dataset_example.py @@ -44,19 +44,19 @@ def summarize(dataset): print >> sys.stderr, "Usage: dataset_example.py " exit(-1) sc = SparkContext(appName="DatasetExample") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) if len(sys.argv) == 2: input = sys.argv[1] else: input = "data/mllib/sample_libsvm_data.txt" points = MLUtils.loadLibSVMFile(sc, input) - dataset0 = sqlCtx.inferSchema(points).setName("dataset0").cache() + dataset0 = sqlContext.inferSchema(points).setName("dataset0").cache() summarize(dataset0) tempdir = tempfile.NamedTemporaryFile(delete=False).name os.unlink(tempdir) print "Save dataset as a Parquet file to %s." % tempdir dataset0.saveAsParquetFile(tempdir) print "Load it back and summarize it again." - dataset1 = sqlCtx.parquetFile(tempdir).setName("dataset1").cache() + dataset1 = sqlContext.parquetFile(tempdir).setName("dataset1").cache() summarize(dataset1) shutil.rmtree(tempdir) diff --git a/examples/src/main/r/kmeans.R b/examples/src/main/r/kmeans.R new file mode 100644 index 0000000000000..6e6b5cb93789c --- /dev/null +++ b/examples/src/main/r/kmeans.R @@ -0,0 +1,93 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +# Logistic regression in Spark. +# Note: unlike the example in Scala, a point here is represented as a vector of +# doubles. + +parseVectors <- function(lines) { + lines <- strsplit(as.character(lines) , " ", fixed = TRUE) + list(matrix(as.numeric(unlist(lines)), ncol = length(lines[[1]]))) +} + +dist.fun <- function(P, C) { + apply( + C, + 1, + function(x) { + colSums((t(P) - x)^2) + } + ) +} + +closestPoint <- function(P, C) { + max.col(-dist.fun(P, C)) +} +# Main program + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 3) { + print("Usage: kmeans ") + q("no") +} + +sc <- sparkR.init(appName = "RKMeans") +K <- as.integer(args[[2]]) +convergeDist <- as.double(args[[3]]) + +lines <- textFile(sc, args[[1]]) +points <- cache(lapplyPartition(lines, parseVectors)) +# kPoints <- take(points, K) +kPoints <- do.call(rbind, takeSample(points, FALSE, K, 16189L)) +tempDist <- 1.0 + +while (tempDist > convergeDist) { + closest <- lapplyPartition( + lapply(points, + function(p) { + cp <- closestPoint(p, kPoints); + mapply(list, unique(cp), split.data.frame(cbind(1, p), cp), SIMPLIFY=FALSE) + }), + function(x) {do.call(c, x) + }) + + pointStats <- reduceByKey(closest, + function(p1, p2) { + t(colSums(rbind(p1, p2))) + }, + 2L) + + newPoints <- do.call( + rbind, + collect(lapply(pointStats, + function(tup) { + point.sum <- tup[[2]][, -1] + point.count <- tup[[2]][, 1] + point.sum/point.count + }))) + + D <- dist.fun(kPoints, newPoints) + tempDist <- sum(D[cbind(1:3, max.col(-D))]) + kPoints <- newPoints + cat("Finished iteration (delta = ", tempDist, ")\n") +} + +cat("Final centers:\n") +writeLines(unlist(lapply(kPoints, paste, collapse = " "))) diff --git a/examples/src/main/r/linear_solver_mnist.R b/examples/src/main/r/linear_solver_mnist.R new file mode 100644 index 0000000000000..c864a4232d010 --- /dev/null +++ b/examples/src/main/r/linear_solver_mnist.R @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Instructions: https://github.com/amplab-extras/SparkR-pkg/wiki/SparkR-Example:-Digit-Recognition-on-EC2 + +library(SparkR) +library(Matrix) + +args <- commandArgs(trailing = TRUE) + +# number of random features; default to 1100 +D <- ifelse(length(args) > 0, as.integer(args[[1]]), 1100) +# number of partitions for training dataset +trainParts <- 12 +# dimension of digits +d <- 784 +# number of test examples +NTrain <- 60000 +# number of training examples +NTest <- 10000 +# scale of features +gamma <- 4e-4 + +sc <- sparkR.init(appName = "SparkR-LinearSolver") + +# You can also use HDFS path to speed things up: +# hdfs:///train-mnist-dense-with-labels.data +file <- textFile(sc, "/data/train-mnist-dense-with-labels.data", trainParts) + +W <- gamma * matrix(nrow=D, ncol=d, data=rnorm(D*d)) +b <- 2 * pi * matrix(nrow=D, ncol=1, data=runif(D)) +broadcastW <- broadcast(sc, W) +broadcastB <- broadcast(sc, b) + +includePackage(sc, Matrix) +numericLines <- lapplyPartitionsWithIndex(file, + function(split, part) { + matList <- sapply(part, function(line) { + as.numeric(strsplit(line, ",", fixed=TRUE)[[1]]) + }, simplify=FALSE) + mat <- Matrix(ncol=d+1, data=unlist(matList, F, F), + sparse=T, byrow=T) + mat + }) + +featureLabels <- cache(lapplyPartition( + numericLines, + function(part) { + label <- part[,1] + mat <- part[,-1] + ones <- rep(1, nrow(mat)) + features <- cos( + mat %*% t(value(broadcastW)) + (matrix(ncol=1, data=ones) %*% t(value(broadcastB)))) + onesMat <- Matrix(ones) + featuresPlus <- cBind(features, onesMat) + labels <- matrix(nrow=nrow(mat), ncol=10, data=-1) + for (i in 1:nrow(mat)) { + labels[i, label[i]] <- 1 + } + list(label=labels, features=featuresPlus) + })) + +FTF <- Reduce("+", collect(lapplyPartition(featureLabels, + function(part) { + t(part$features) %*% part$features + }), flatten=F)) + +FTY <- Reduce("+", collect(lapplyPartition(featureLabels, + function(part) { + t(part$features) %*% part$label + }), flatten=F)) + +# solve for the coefficient matrix +C <- solve(FTF, FTY) + +test <- Matrix(as.matrix(read.csv("/data/test-mnist-dense-with-labels.data", + header=F), sparse=T)) +testData <- test[,-1] +testLabels <- matrix(ncol=1, test[,1]) + +err <- 0 + +# contstruct the feature maps for all examples from this digit +featuresTest <- cos(testData %*% t(value(broadcastW)) + + (matrix(ncol=1, data=rep(1, NTest)) %*% t(value(broadcastB)))) +featuresTest <- cBind(featuresTest, Matrix(rep(1, NTest))) + +# extract the one vs. all assignment +results <- featuresTest %*% C +labelsGot <- apply(results, 1, which.max) +err <- sum(testLabels != labelsGot) / nrow(testLabels) + +cat("\nFinished running. The error rate is: ", err, ".\n") diff --git a/examples/src/main/r/logistic_regression.R b/examples/src/main/r/logistic_regression.R new file mode 100644 index 0000000000000..2a86aa98160d3 --- /dev/null +++ b/examples/src/main/r/logistic_regression.R @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 3) { + print("Usage: logistic_regression ") + q("no") +} + +# Initialize Spark context +sc <- sparkR.init(appName = "LogisticRegressionR") +iterations <- as.integer(args[[2]]) +D <- as.integer(args[[3]]) + +readPartition <- function(part){ + part = strsplit(part, " ", fixed = T) + list(matrix(as.numeric(unlist(part)), ncol = length(part[[1]]))) +} + +# Read data points and convert each partition to a matrix +points <- cache(lapplyPartition(textFile(sc, args[[1]]), readPartition)) + +# Initialize w to a random value +w <- runif(n=D, min = -1, max = 1) +cat("Initial w: ", w, "\n") + +# Compute logistic regression gradient for a matrix of data points +gradient <- function(partition) { + partition = partition[[1]] + Y <- partition[, 1] # point labels (first column of input file) + X <- partition[, -1] # point coordinates + + # For each point (x, y), compute gradient function + dot <- X %*% w + logit <- 1 / (1 + exp(-Y * dot)) + grad <- t(X) %*% ((logit - 1) * Y) + list(grad) +} + +for (i in 1:iterations) { + cat("On iteration ", i, "\n") + w <- w - reduce(lapplyPartition(points, gradient), "+") +} + +cat("Final w: ", w, "\n") diff --git a/examples/src/main/r/pi.R b/examples/src/main/r/pi.R new file mode 100644 index 0000000000000..aa7a833e147a0 --- /dev/null +++ b/examples/src/main/r/pi.R @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +sc <- sparkR.init(appName = "PiR") + +slices <- ifelse(length(args) > 1, as.integer(args[[2]]), 2) + +n <- 100000 * slices + +piFunc <- function(elem) { + rands <- runif(n = 2, min = -1, max = 1) + val <- ifelse((rands[1]^2 + rands[2]^2) < 1, 1.0, 0.0) + val +} + + +piFuncVec <- function(elems) { + message(length(elems)) + rands1 <- runif(n = length(elems), min = -1, max = 1) + rands2 <- runif(n = length(elems), min = -1, max = 1) + val <- ifelse((rands1^2 + rands2^2) < 1, 1.0, 0.0) + sum(val) +} + +rdd <- parallelize(sc, 1:n, slices) +count <- reduce(lapplyPartition(rdd, piFuncVec), sum) +cat("Pi is roughly", 4.0 * count / n, "\n") +cat("Num elements in RDD ", count(rdd), "\n") diff --git a/examples/src/main/r/wordcount.R b/examples/src/main/r/wordcount.R new file mode 100644 index 0000000000000..b734cb0ecf55b --- /dev/null +++ b/examples/src/main/r/wordcount.R @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 1) { + print("Usage: wordcount ") + q("no") +} + +# Initialize Spark context +sc <- sparkR.init(appName = "RwordCount") +lines <- textFile(sc, args[[1]]) + +words <- flatMap(lines, + function(line) { + strsplit(line, " ")[[1]] + }) +wordCount <- lapply(words, function(word) { list(word, 1L) }) + +counts <- reduceByKey(wordCount, "+", 2L) +output <- collect(counts) + +for (wordcount in output) { + cat(wordcount[[1]], ": ", wordcount[[2]], "\n") +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index df26798e41b7b..2245fa429fda3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -99,7 +99,7 @@ private trait MyLogisticRegressionParams extends ClassifierParams { * class since the maxIter parameter is only used during training (not in the Model). */ val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") - def getMaxIter: Int = get(maxIter) + def getMaxIter: Int = getOrDefault(maxIter) } /** @@ -174,11 +174,11 @@ private class MyLogisticRegressionModel( * Create a copy of the model. * The copy is shallow, except for the embedded paramMap, which gets a deep copy. * - * This is used for the defaul implementation of [[transform()]]. + * This is used for the default implementation of [[transform()]]. */ override protected def copy(): MyLogisticRegressionModel = { val m = new MyLogisticRegressionModel(parent, fittingParamMap, weights) - Params.inheritValues(this.paramMap, this, m) + Params.inheritValues(extractParamMap(), this, m) m } } diff --git a/external/flume-sink/src/test/resources/log4j.properties b/external/flume-sink/src/test/resources/log4j.properties index 2a58e99817224..42df8792f147f 100644 --- a/external/flume-sink/src/test/resources/log4j.properties +++ b/external/flume-sink/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/flume/src/test/resources/log4j.properties b/external/flume/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/external/flume/src/test/resources/log4j.properties +++ b/external/flume/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index e04d4088df7dc..2edea9b5b69ba 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -1,21 +1,20 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + package org.apache.spark.streaming.flume import java.net.InetSocketAddress @@ -213,7 +212,7 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging assert(counter === totalEventsPerChannel * channels.size) } - def assertChannelIsEmpty(channel: MemoryChannel) = { + def assertChannelIsEmpty(channel: MemoryChannel): Unit = { val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") queueRemaining.setAccessible(true) val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 51d273af8da84..39e6754c81dbf 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -151,7 +151,9 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L } /** Class to create socket channel with compression */ - private class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { + private class CompressionChannelFactory(compressionLevel: Int) + extends NioClientSocketChannelFactory { + override def newChannel(pipeline: ChannelPipeline): SocketChannel = { val encoder = new ZlibEncoder(compressionLevel) pipeline.addFirst("deflater", encoder) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index 2f7e0ab39fefd..bd767031c1849 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -123,9 +123,17 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { val errs = new Err withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer => val resp: TopicMetadataResponse = consumer.send(req) - // error codes here indicate missing / just created topic, - // repeating on a different broker wont be useful - return Right(resp.topicsMetadata.toSet) + val respErrs = resp.topicsMetadata.filter(m => m.errorCode != ErrorMapping.NoError) + + if (respErrs.isEmpty) { + return Right(resp.topicsMetadata.toSet) + } else { + respErrs.foreach { m => + val cause = ErrorMapping.exceptionFor(m.errorCode) + val msg = s"Error getting partition metadata for '${m.topic}'. Does the topic exist?" + errs.append(new SparkException(msg, cause)) + } + } } Left(errs) } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala new file mode 100644 index 0000000000000..13e9475065979 --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import java.io.File +import java.lang.{Integer => JInt} +import java.net.InetSocketAddress +import java.util.{Map => JMap} +import java.util.Properties +import java.util.concurrent.TimeoutException + +import scala.annotation.tailrec +import scala.language.postfixOps +import scala.util.control.NonFatal + +import kafka.admin.AdminUtils +import kafka.producer.{KeyedMessage, Producer, ProducerConfig} +import kafka.serializer.StringEncoder +import kafka.server.{KafkaConfig, KafkaServer} +import kafka.utils.ZKStringSerializer +import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} +import org.I0Itec.zkclient.ZkClient + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.streaming.Time +import org.apache.spark.util.Utils + +/** + * This is a helper class for Kafka test suites. This has the functionality to set up + * and tear down local Kafka servers, and to push data using Kafka producers. + * + * The reason to put Kafka test utility class in src is to test Python related Kafka APIs. + */ +private class KafkaTestUtils extends Logging { + + // Zookeeper related configurations + private val zkHost = "localhost" + private var zkPort: Int = 0 + private val zkConnectionTimeout = 6000 + private val zkSessionTimeout = 6000 + + private var zookeeper: EmbeddedZookeeper = _ + + private var zkClient: ZkClient = _ + + // Kafka broker related configurations + private val brokerHost = "localhost" + private var brokerPort = 9092 + private var brokerConf: KafkaConfig = _ + + // Kafka broker server + private var server: KafkaServer = _ + + // Kafka producer + private var producer: Producer[String, String] = _ + + // Flag to test whether the system is correctly started + private var zkReady = false + private var brokerReady = false + + def zkAddress: String = { + assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address") + s"$zkHost:$zkPort" + } + + def brokerAddress: String = { + assert(brokerReady, "Kafka not setup yet or already torn down, cannot get broker address") + s"$brokerHost:$brokerPort" + } + + def zookeeperClient: ZkClient = { + assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper client") + Option(zkClient).getOrElse( + throw new IllegalStateException("Zookeeper client is not yet initialized")) + } + + // Set up the Embedded Zookeeper server and get the proper Zookeeper port + private def setupEmbeddedZookeeper(): Unit = { + // Zookeeper server startup + zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort") + // Get the actual zookeeper binding port + zkPort = zookeeper.actualPort + zkClient = new ZkClient(s"$zkHost:$zkPort", zkSessionTimeout, zkConnectionTimeout, + ZKStringSerializer) + zkReady = true + } + + // Set up the Embedded Kafka server + private def setupEmbeddedKafkaServer(): Unit = { + assert(zkReady, "Zookeeper should be set up beforehand") + + // Kafka broker startup + Utils.startServiceOnPort(brokerPort, port => { + brokerPort = port + brokerConf = new KafkaConfig(brokerConfiguration) + server = new KafkaServer(brokerConf) + server.startup() + (server, port) + }, new SparkConf(), "KafkaBroker") + + brokerReady = true + } + + /** setup the whole embedded servers, including Zookeeper and Kafka brokers */ + def setup(): Unit = { + setupEmbeddedZookeeper() + setupEmbeddedKafkaServer() + } + + /** Teardown the whole servers, including Kafka broker and Zookeeper */ + def teardown(): Unit = { + brokerReady = false + zkReady = false + + if (producer != null) { + producer.close() + producer = null + } + + if (server != null) { + server.shutdown() + server = null + } + + brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } + + if (zkClient != null) { + zkClient.close() + zkClient = null + } + + if (zookeeper != null) { + zookeeper.shutdown() + zookeeper = null + } + } + + /** Create a Kafka topic and wait until it propagated to the whole cluster */ + def createTopic(topic: String): Unit = { + AdminUtils.createTopic(zkClient, topic, 1, 1) + // wait until metadata is propagated + waitUntilMetadataIsPropagated(topic, 0) + } + + /** Java-friendly function for sending messages to the Kafka broker */ + def sendMessages(topic: String, messageToFreq: JMap[String, JInt]): Unit = { + import scala.collection.JavaConversions._ + sendMessages(topic, Map(messageToFreq.mapValues(_.intValue()).toSeq: _*)) + } + + /** Send the messages to the Kafka broker */ + def sendMessages(topic: String, messageToFreq: Map[String, Int]): Unit = { + val messages = messageToFreq.flatMap { case (s, freq) => Seq.fill(freq)(s) }.toArray + sendMessages(topic, messages) + } + + /** Send the array of messages to the Kafka broker */ + def sendMessages(topic: String, messages: Array[String]): Unit = { + producer = new Producer[String, String](new ProducerConfig(producerConfiguration)) + producer.send(messages.map { new KeyedMessage[String, String](topic, _ ) }: _*) + producer.close() + producer = null + } + + private def brokerConfiguration: Properties = { + val props = new Properties() + props.put("broker.id", "0") + props.put("host.name", "localhost") + props.put("port", brokerPort.toString) + props.put("log.dir", Utils.createTempDir().getAbsolutePath) + props.put("zookeeper.connect", zkAddress) + props.put("log.flush.interval.messages", "1") + props.put("replica.socket.timeout.ms", "1500") + props + } + + private def producerConfiguration: Properties = { + val props = new Properties() + props.put("metadata.broker.list", brokerAddress) + props.put("serializer.class", classOf[StringEncoder].getName) + props + } + + // A simplified version of scalatest eventually, rewritten here to avoid adding extra test + // dependency + def eventually[T](timeout: Time, interval: Time)(func: => T): T = { + def makeAttempt(): Either[Throwable, T] = { + try { + Right(func) + } catch { + case e if NonFatal(e) => Left(e) + } + } + + val startTime = System.currentTimeMillis() + @tailrec + def tryAgain(attempt: Int): T = { + makeAttempt() match { + case Right(result) => result + case Left(e) => + val duration = System.currentTimeMillis() - startTime + if (duration < timeout.milliseconds) { + Thread.sleep(interval.milliseconds) + } else { + throw new TimeoutException(e.getMessage) + } + + tryAgain(attempt + 1) + } + } + + tryAgain(1) + } + + private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = { + eventually(Time(10000), Time(100)) { + assert( + server.apis.metadataCache.containsTopicAndPartition(topic, partition), + s"Partition [$topic, $partition] metadata not propagated after timeout" + ) + } + } + + private class EmbeddedZookeeper(val zkConnect: String) { + val snapshotDir = Utils.createTempDir() + val logDir = Utils.createTempDir() + + val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500) + val (ip, port) = { + val splits = zkConnect.split(":") + (splits(0), splits(1).toInt) + } + val factory = new NIOServerCnxnFactory() + factory.configure(new InetSocketAddress(ip, port), 16) + factory.startup(zookeeper) + + val actualPort = factory.getLocalPort + + def shutdown() { + factory.shutdown() + Utils.deleteRecursively(snapshotDir) + Utils.deleteRecursively(logDir) + } + } +} + diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java index d6ca6d58b5665..4c1d6a03eb2b8 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java @@ -41,24 +41,28 @@ public class JavaDirectKafkaStreamSuite implements Serializable { private transient JavaStreamingContext ssc = null; - private transient KafkaStreamSuiteBase suiteBase = null; + private transient KafkaTestUtils kafkaTestUtils = null; @Before public void setUp() { - suiteBase = new KafkaStreamSuiteBase() { }; - suiteBase.setupKafka(); - System.clearProperty("spark.driver.port"); - SparkConf sparkConf = new SparkConf() - .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); - ssc = new JavaStreamingContext(sparkConf, Durations.milliseconds(200)); + kafkaTestUtils = new KafkaTestUtils(); + kafkaTestUtils.setup(); + SparkConf sparkConf = new SparkConf() + .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); + ssc = new JavaStreamingContext(sparkConf, Durations.milliseconds(200)); } @After public void tearDown() { + if (ssc != null) { ssc.stop(); ssc = null; - System.clearProperty("spark.driver.port"); - suiteBase.tearDownKafka(); + } + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown(); + kafkaTestUtils = null; + } } @Test @@ -74,7 +78,7 @@ public void testKafkaStream() throws InterruptedException { sent.addAll(Arrays.asList(topic2data)); HashMap kafkaParams = new HashMap(); - kafkaParams.put("metadata.broker.list", suiteBase.brokerAddress()); + kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); kafkaParams.put("auto.offset.reset", "smallest"); JavaDStream stream1 = KafkaUtils.createDirectStream( @@ -147,8 +151,8 @@ private HashMap topicOffsetToMap(String topic, Long off private String[] createTopicAndSendData(String topic) { String[] data = { topic + "-1", topic + "-2", topic + "-3"}; - suiteBase.createTopic(topic); - suiteBase.sendMessages(topic, data); + kafkaTestUtils.createTopic(topic); + kafkaTestUtils.sendMessages(topic, data); return data; } } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java index 4477b81827c70..a9dc6e50613ca 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java @@ -37,13 +37,12 @@ public class JavaKafkaRDDSuite implements Serializable { private transient JavaSparkContext sc = null; - private transient KafkaStreamSuiteBase suiteBase = null; + private transient KafkaTestUtils kafkaTestUtils = null; @Before public void setUp() { - suiteBase = new KafkaStreamSuiteBase() { }; - suiteBase.setupKafka(); - System.clearProperty("spark.driver.port"); + kafkaTestUtils = new KafkaTestUtils(); + kafkaTestUtils.setup(); SparkConf sparkConf = new SparkConf() .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); sc = new JavaSparkContext(sparkConf); @@ -51,10 +50,15 @@ public void setUp() { @After public void tearDown() { - sc.stop(); - sc = null; - System.clearProperty("spark.driver.port"); - suiteBase.tearDownKafka(); + if (sc != null) { + sc.stop(); + sc = null; + } + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown(); + kafkaTestUtils = null; + } } @Test @@ -66,7 +70,7 @@ public void testKafkaRDD() throws InterruptedException { String[] topic2data = createTopicAndSendData(topic2); HashMap kafkaParams = new HashMap(); - kafkaParams.put("metadata.broker.list", suiteBase.brokerAddress()); + kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); OffsetRange[] offsetRanges = { OffsetRange.create(topic1, 0, 0, 1), @@ -75,7 +79,7 @@ public void testKafkaRDD() throws InterruptedException { HashMap emptyLeaders = new HashMap(); HashMap leaders = new HashMap(); - String[] hostAndPort = suiteBase.brokerAddress().split(":"); + String[] hostAndPort = kafkaTestUtils.brokerAddress().split(":"); Broker broker = Broker.create(hostAndPort[0], Integer.parseInt(hostAndPort[1])); leaders.put(new TopicAndPartition(topic1, 0), broker); leaders.put(new TopicAndPartition(topic2, 0), broker); @@ -144,8 +148,8 @@ public String call(MessageAndMetadata msgAndMd) throws Exception private String[] createTopicAndSendData(String topic) { String[] data = { topic + "-1", topic + "-2", topic + "-3"}; - suiteBase.createTopic(topic); - suiteBase.sendMessages(topic, data); + kafkaTestUtils.createTopic(topic); + kafkaTestUtils.sendMessages(topic, data); return data; } } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index bad0a93eb2e84..540f4ceabab47 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -22,9 +22,7 @@ import java.util.List; import java.util.Random; -import scala.Predef; import scala.Tuple2; -import scala.collection.JavaConverters; import kafka.serializer.StringDecoder; import org.junit.After; @@ -44,13 +42,12 @@ public class JavaKafkaStreamSuite implements Serializable { private transient JavaStreamingContext ssc = null; private transient Random random = new Random(); - private transient KafkaStreamSuiteBase suiteBase = null; + private transient KafkaTestUtils kafkaTestUtils = null; @Before public void setUp() { - suiteBase = new KafkaStreamSuiteBase() { }; - suiteBase.setupKafka(); - System.clearProperty("spark.driver.port"); + kafkaTestUtils = new KafkaTestUtils(); + kafkaTestUtils.setup(); SparkConf sparkConf = new SparkConf() .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); ssc = new JavaStreamingContext(sparkConf, new Duration(500)); @@ -58,10 +55,15 @@ public void setUp() { @After public void tearDown() { - ssc.stop(); - ssc = null; - System.clearProperty("spark.driver.port"); - suiteBase.tearDownKafka(); + if (ssc != null) { + ssc.stop(); + ssc = null; + } + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown(); + kafkaTestUtils = null; + } } @Test @@ -75,15 +77,11 @@ public void testKafkaStream() throws InterruptedException { sent.put("b", 3); sent.put("c", 10); - suiteBase.createTopic(topic); - HashMap tmp = new HashMap(sent); - suiteBase.sendMessages(topic, - JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap( - Predef.>conforms()) - ); + kafkaTestUtils.createTopic(topic); + kafkaTestUtils.sendMessages(topic, sent); HashMap kafkaParams = new HashMap(); - kafkaParams.put("zookeeper.connect", suiteBase.zkAddress()); + kafkaParams.put("zookeeper.connect", kafkaTestUtils.zkAddress()); kafkaParams.put("group.id", "test-consumer-" + random.nextInt(10000)); kafkaParams.put("auto.offset.reset", "smallest"); @@ -126,6 +124,7 @@ public Void call(JavaPairRDD rdd) throws Exception { ); ssc.start(); + long startTime = System.currentTimeMillis(); boolean sizeMatches = false; while (!sizeMatches && System.currentTimeMillis() - startTime < 20000) { @@ -136,6 +135,5 @@ public Void call(JavaPairRDD rdd) throws Exception { for (String k : sent.keySet()) { Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue()); } - ssc.stop(); } } diff --git a/external/kafka/src/test/resources/log4j.properties b/external/kafka/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/external/kafka/src/test/resources/log4j.properties +++ b/external/kafka/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 17ca9d145d665..415730f5559c5 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -27,31 +27,41 @@ import scala.language.postfixOps import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata import kafka.serializer.StringDecoder -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.Utils -class DirectKafkaStreamSuite extends KafkaStreamSuiteBase - with BeforeAndAfter with BeforeAndAfterAll with Eventually { +class DirectKafkaStreamSuite + extends FunSuite + with BeforeAndAfter + with BeforeAndAfterAll + with Eventually + with Logging { val sparkConf = new SparkConf() .setMaster("local[4]") .setAppName(this.getClass.getSimpleName) - var sc: SparkContext = _ - var ssc: StreamingContext = _ - var testDir: File = _ + private var sc: SparkContext = _ + private var ssc: StreamingContext = _ + private var testDir: File = _ + + private var kafkaTestUtils: KafkaTestUtils = _ override def beforeAll { - setupKafka() + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() } override def afterAll { - tearDownKafka() + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } } after { @@ -72,12 +82,12 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase val topics = Set("basic1", "basic2", "basic3") val data = Map("a" -> 7, "b" -> 9) topics.foreach { t => - createTopic(t) - sendMessages(t, data) + kafkaTestUtils.createTopic(t) + kafkaTestUtils.sendMessages(t, data) } val totalSent = data.values.sum * topics.size val kafkaParams = Map( - "metadata.broker.list" -> s"$brokerAddress", + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, "auto.offset.reset" -> "smallest" ) @@ -121,9 +131,9 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase val topic = "largest" val topicPartition = TopicAndPartition(topic, 0) val data = Map("a" -> 10) - createTopic(topic) + kafkaTestUtils.createTopic(topic) val kafkaParams = Map( - "metadata.broker.list" -> s"$brokerAddress", + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, "auto.offset.reset" -> "largest" ) val kc = new KafkaCluster(kafkaParams) @@ -132,7 +142,7 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase } // Send some initial messages before starting context - sendMessages(topic, data) + kafkaTestUtils.sendMessages(topic, data) eventually(timeout(10 seconds), interval(20 milliseconds)) { assert(getLatestOffset() > 3) } @@ -154,7 +164,7 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase stream.map { _._2 }.foreachRDD { rdd => collectedData ++= rdd.collect() } ssc.start() val newData = Map("b" -> 10) - sendMessages(topic, newData) + kafkaTestUtils.sendMessages(topic, newData) eventually(timeout(10 seconds), interval(50 milliseconds)) { collectedData.contains("b") } @@ -166,9 +176,9 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase val topic = "offset" val topicPartition = TopicAndPartition(topic, 0) val data = Map("a" -> 10) - createTopic(topic) + kafkaTestUtils.createTopic(topic) val kafkaParams = Map( - "metadata.broker.list" -> s"$brokerAddress", + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, "auto.offset.reset" -> "largest" ) val kc = new KafkaCluster(kafkaParams) @@ -177,7 +187,7 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase } // Send some initial messages before starting context - sendMessages(topic, data) + kafkaTestUtils.sendMessages(topic, data) eventually(timeout(10 seconds), interval(20 milliseconds)) { assert(getLatestOffset() >= 10) } @@ -200,7 +210,7 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase stream.foreachRDD { rdd => collectedData ++= rdd.collect() } ssc.start() val newData = Map("b" -> 10) - sendMessages(topic, newData) + kafkaTestUtils.sendMessages(topic, newData) eventually(timeout(10 seconds), interval(50 milliseconds)) { collectedData.contains("b") } @@ -210,18 +220,18 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase // Test to verify the offset ranges can be recovered from the checkpoints test("offset recovery") { val topic = "recovery" - createTopic(topic) + kafkaTestUtils.createTopic(topic) testDir = Utils.createTempDir() val kafkaParams = Map( - "metadata.broker.list" -> s"$brokerAddress", + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, "auto.offset.reset" -> "smallest" ) // Send data to Kafka and wait for it to be received def sendDataAndWaitForReceive(data: Seq[Int]) { val strings = data.map { _.toString} - sendMessages(topic, strings.map { _ -> 1}.toMap) + kafkaTestUtils.sendMessages(topic, strings.map { _ -> 1}.toMap) eventually(timeout(10 seconds), interval(50 milliseconds)) { assert(strings.forall { DirectKafkaStreamSuite.collectedData.contains }) } diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala index fc9275b7207be..7fb841b79cb65 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala @@ -20,31 +20,41 @@ package org.apache.spark.streaming.kafka import scala.util.Random import kafka.common.TopicAndPartition -import org.scalatest.BeforeAndAfterAll +import org.scalatest.{BeforeAndAfterAll, FunSuite} -class KafkaClusterSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll { - val topic = "kcsuitetopic" + Random.nextInt(10000) - val topicAndPartition = TopicAndPartition(topic, 0) - var kc: KafkaCluster = null +class KafkaClusterSuite extends FunSuite with BeforeAndAfterAll { + private val topic = "kcsuitetopic" + Random.nextInt(10000) + private val topicAndPartition = TopicAndPartition(topic, 0) + private var kc: KafkaCluster = null + + private var kafkaTestUtils: KafkaTestUtils = _ override def beforeAll() { - setupKafka() - createTopic(topic) - sendMessages(topic, Map("a" -> 1)) - kc = new KafkaCluster(Map("metadata.broker.list" -> s"$brokerAddress")) + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() + + kafkaTestUtils.createTopic(topic) + kafkaTestUtils.sendMessages(topic, Map("a" -> 1)) + kc = new KafkaCluster(Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress)) } override def afterAll() { - tearDownKafka() + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } } test("metadata apis") { val leader = kc.findLeaders(Set(topicAndPartition)).right.get(topicAndPartition) val leaderAddress = s"${leader._1}:${leader._2}" - assert(leaderAddress === brokerAddress, "didn't get leader") + assert(leaderAddress === kafkaTestUtils.brokerAddress, "didn't get leader") val parts = kc.getPartitions(Set(topic)).right.get assert(parts(topicAndPartition), "didn't get partitions") + + val err = kc.getPartitions(Set(topic + "BAD")) + assert(err.isLeft, "getPartitions for a nonexistant topic should be an error") } test("leader offset apis") { diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala index a223da70b043f..7d26ce50875b3 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -22,18 +22,22 @@ import scala.util.Random import kafka.serializer.StringDecoder import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata -import org.scalatest.BeforeAndAfterAll +import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.apache.spark._ -import org.apache.spark.SparkContext._ -class KafkaRDDSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll { - val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName) - var sc: SparkContext = _ +class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { + + private var kafkaTestUtils: KafkaTestUtils = _ + + private val sparkConf = new SparkConf().setMaster("local[4]") + .setAppName(this.getClass.getSimpleName) + private var sc: SparkContext = _ + override def beforeAll { sc = new SparkContext(sparkConf) - - setupKafka() + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() } override def afterAll { @@ -41,17 +45,21 @@ class KafkaRDDSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll { sc.stop sc = null } - tearDownKafka() + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } } test("basic usage") { val topic = "topicbasic" - createTopic(topic) + kafkaTestUtils.createTopic(topic) val messages = Set("the", "quick", "brown", "fox") - sendMessages(topic, messages.toArray) + kafkaTestUtils.sendMessages(topic, messages.toArray) - val kafkaParams = Map("metadata.broker.list" -> brokerAddress, + val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress, "group.id" -> s"test-consumer-${Random.nextInt(10000)}") val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size)) @@ -67,15 +75,15 @@ class KafkaRDDSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll { // the idea is to find e.g. off-by-one errors between what kafka has available and the rdd val topic = "topic1" val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) - createTopic(topic) + kafkaTestUtils.createTopic(topic) - val kafkaParams = Map("metadata.broker.list" -> brokerAddress, + val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress, "group.id" -> s"test-consumer-${Random.nextInt(10000)}") val kc = new KafkaCluster(kafkaParams) // this is the "lots of messages" case - sendMessages(topic, sent) + kafkaTestUtils.sendMessages(topic, sent) // rdd defined from leaders after sending messages, should get the number sent val rdd = getRdd(kc, Set(topic)) @@ -92,14 +100,14 @@ class KafkaRDDSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll { // shouldn't get anything, since message is sent after rdd was defined val sentOnlyOne = Map("d" -> 1) - sendMessages(topic, sentOnlyOne) + kafkaTestUtils.sendMessages(topic, sentOnlyOne) assert(rdd2.isDefined) assert(rdd2.get.count === 0, "got messages when there shouldn't be any") // this is the "exactly 1 message" case, namely the single message from sentOnlyOne above val rdd3 = getRdd(kc, Set(topic)) // send lots of messages after rdd was defined, they shouldn't show up - sendMessages(topic, Map("extra" -> 22)) + kafkaTestUtils.sendMessages(topic, Map("extra" -> 22)) assert(rdd3.isDefined) assert(rdd3.get.count === sentOnlyOne.values.sum, "didn't get exactly one message") diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index e4966eebb9b34..24699dfc33adb 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -17,209 +17,38 @@ package org.apache.spark.streaming.kafka -import java.io.File -import java.net.InetSocketAddress -import java.util.Properties - import scala.collection.mutable import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import kafka.admin.AdminUtils -import kafka.common.{KafkaException, TopicAndPartition} -import kafka.producer.{KeyedMessage, Producer, ProducerConfig} -import kafka.serializer.{StringDecoder, StringEncoder} -import kafka.server.{KafkaConfig, KafkaServer} -import kafka.utils.ZKStringSerializer -import org.I0Itec.zkclient.ZkClient -import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} -import org.scalatest.{BeforeAndAfter, FunSuite} +import kafka.serializer.StringDecoder +import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.scalatest.concurrent.Eventually -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.SparkConf import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} -import org.apache.spark.util.Utils - -/** - * This is an abstract base class for Kafka testsuites. This has the functionality to set up - * and tear down local Kafka servers, and to push data using Kafka producers. - */ -abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Logging { - - private val zkHost = "localhost" - private var zkPort: Int = 0 - private val zkConnectionTimeout = 6000 - private val zkSessionTimeout = 6000 - private var zookeeper: EmbeddedZookeeper = _ - private val brokerHost = "localhost" - private var brokerPort = 9092 - private var brokerConf: KafkaConfig = _ - private var server: KafkaServer = _ - private var producer: Producer[String, String] = _ - private var zkReady = false - private var brokerReady = false - - protected var zkClient: ZkClient = _ - - def zkAddress: String = { - assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address") - s"$zkHost:$zkPort" - } - def brokerAddress: String = { - assert(brokerReady, "Kafka not setup yet or already torn down, cannot get broker address") - s"$brokerHost:$brokerPort" - } - - def setupKafka() { - // Zookeeper server startup - zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort") - // Get the actual zookeeper binding port - zkPort = zookeeper.actualPort - zkReady = true - logInfo("==================== Zookeeper Started ====================") +class KafkaStreamSuite extends FunSuite with Eventually with BeforeAndAfterAll { + private var ssc: StreamingContext = _ + private var kafkaTestUtils: KafkaTestUtils = _ - zkClient = new ZkClient(zkAddress, zkSessionTimeout, zkConnectionTimeout, ZKStringSerializer) - logInfo("==================== Zookeeper Client Created ====================") - - // Kafka broker startup - var bindSuccess: Boolean = false - while(!bindSuccess) { - try { - val brokerProps = getBrokerConfig() - brokerConf = new KafkaConfig(brokerProps) - server = new KafkaServer(brokerConf) - server.startup() - logInfo("==================== Kafka Broker Started ====================") - bindSuccess = true - } catch { - case e: KafkaException => - if (e.getMessage != null && e.getMessage.contains("Socket server failed to bind to")) { - brokerPort += 1 - } - case e: Exception => throw new Exception("Kafka server create failed", e) - } - } - - Thread.sleep(2000) - logInfo("==================== Kafka + Zookeeper Ready ====================") - brokerReady = true + override def beforeAll(): Unit = { + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() } - def tearDownKafka() { - brokerReady = false - zkReady = false - if (producer != null) { - producer.close() - producer = null - } - - if (server != null) { - server.shutdown() - server = null - } - - brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } - - if (zkClient != null) { - zkClient.close() - zkClient = null - } - - if (zookeeper != null) { - zookeeper.shutdown() - zookeeper = null - } - } - - def createTopic(topic: String) { - AdminUtils.createTopic(zkClient, topic, 1, 1) - // wait until metadata is propagated - waitUntilMetadataIsPropagated(topic, 0) - logInfo(s"==================== Topic $topic Created ====================") - } - - def sendMessages(topic: String, messageToFreq: Map[String, Int]) { - val messages = messageToFreq.flatMap { case (s, freq) => Seq.fill(freq)(s) }.toArray - sendMessages(topic, messages) - } - - def sendMessages(topic: String, messages: Array[String]) { - producer = new Producer[String, String](new ProducerConfig(getProducerConfig())) - producer.send(messages.map { new KeyedMessage[String, String](topic, _ ) }: _*) - producer.close() - logInfo(s"==================== Sent Messages: ${messages.mkString(", ")} ====================") - } - - private def getBrokerConfig(): Properties = { - val props = new Properties() - props.put("broker.id", "0") - props.put("host.name", "localhost") - props.put("port", brokerPort.toString) - props.put("log.dir", Utils.createTempDir().getAbsolutePath) - props.put("zookeeper.connect", zkAddress) - props.put("log.flush.interval.messages", "1") - props.put("replica.socket.timeout.ms", "1500") - props - } - - private def getProducerConfig(): Properties = { - val brokerAddr = brokerConf.hostName + ":" + brokerConf.port - val props = new Properties() - props.put("metadata.broker.list", brokerAddr) - props.put("serializer.class", classOf[StringEncoder].getName) - props - } - - private def waitUntilMetadataIsPropagated(topic: String, partition: Int) { - eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { - assert( - server.apis.metadataCache.containsTopicAndPartition(topic, partition), - s"Partition [$topic, $partition] metadata not propagated after timeout" - ) - } - } - - class EmbeddedZookeeper(val zkConnect: String) { - val random = new Random() - val snapshotDir = Utils.createTempDir() - val logDir = Utils.createTempDir() - - val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500) - val (ip, port) = { - val splits = zkConnect.split(":") - (splits(0), splits(1).toInt) - } - val factory = new NIOServerCnxnFactory() - factory.configure(new InetSocketAddress(ip, port), 16) - factory.startup(zookeeper) - - val actualPort = factory.getLocalPort - - def shutdown() { - factory.shutdown() - Utils.deleteRecursively(snapshotDir) - Utils.deleteRecursively(logDir) - } - } -} - - -class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter { - var ssc: StreamingContext = _ - - before { - setupKafka() - } - - after { + override def afterAll(): Unit = { if (ssc != null) { ssc.stop() ssc = null } - tearDownKafka() + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } } test("Kafka input stream") { @@ -227,10 +56,10 @@ class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter { ssc = new StreamingContext(sparkConf, Milliseconds(500)) val topic = "topic1" val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) - createTopic(topic) - sendMessages(topic, sent) + kafkaTestUtils.createTopic(topic) + kafkaTestUtils.sendMessages(topic, sent) - val kafkaParams = Map("zookeeper.connect" -> zkAddress, + val kafkaParams = Map("zookeeper.connect" -> kafkaTestUtils.zkAddress, "group.id" -> s"test-consumer-${Random.nextInt(10000)}", "auto.offset.reset" -> "smallest") @@ -244,14 +73,14 @@ class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter { result.put(kv._1, count) } } + ssc.start() + eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { assert(sent.size === result.size) sent.keys.foreach { k => assert(sent(k) === result(k).toInt) } } - ssc.stop() } } - diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala index 3cd960d1fd1d4..38548dd73b82c 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.streaming.kafka - import java.io.File import scala.collection.mutable @@ -27,7 +26,7 @@ import scala.util.Random import kafka.serializer.StringDecoder import kafka.utils.{ZKGroupTopicDirs, ZkUtils} -import org.scalatest.BeforeAndAfter +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.scalatest.concurrent.Eventually import org.apache.spark.SparkConf @@ -35,47 +34,61 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.util.Utils -class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter with Eventually { +class ReliableKafkaStreamSuite extends FunSuite + with BeforeAndAfterAll with BeforeAndAfter with Eventually { - val sparkConf = new SparkConf() + private val sparkConf = new SparkConf() .setMaster("local[4]") .setAppName(this.getClass.getSimpleName) .set("spark.streaming.receiver.writeAheadLog.enable", "true") - val data = Map("a" -> 10, "b" -> 10, "c" -> 10) + private val data = Map("a" -> 10, "b" -> 10, "c" -> 10) + private var kafkaTestUtils: KafkaTestUtils = _ - var groupId: String = _ - var kafkaParams: Map[String, String] = _ - var ssc: StreamingContext = _ - var tempDirectory: File = null + private var groupId: String = _ + private var kafkaParams: Map[String, String] = _ + private var ssc: StreamingContext = _ + private var tempDirectory: File = null + + override def beforeAll() : Unit = { + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() - before { - setupKafka() groupId = s"test-consumer-${Random.nextInt(10000)}" kafkaParams = Map( - "zookeeper.connect" -> zkAddress, + "zookeeper.connect" -> kafkaTestUtils.zkAddress, "group.id" -> groupId, "auto.offset.reset" -> "smallest" ) - ssc = new StreamingContext(sparkConf, Milliseconds(500)) tempDirectory = Utils.createTempDir() + } + + override def afterAll(): Unit = { + Utils.deleteRecursively(tempDirectory) + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } + + before { + ssc = new StreamingContext(sparkConf, Milliseconds(500)) ssc.checkpoint(tempDirectory.getAbsolutePath) } after { if (ssc != null) { ssc.stop() + ssc = null } - Utils.deleteRecursively(tempDirectory) - tearDownKafka() } - test("Reliable Kafka input stream with single topic") { - var topic = "test-topic" - createTopic(topic) - sendMessages(topic, data) + val topic = "test-topic" + kafkaTestUtils.createTopic(topic) + kafkaTestUtils.sendMessages(topic, data) // Verify whether the offset of this group/topic/partition is 0 before starting. assert(getCommitOffset(groupId, topic, 0) === None) @@ -91,6 +104,7 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter } } ssc.start() + eventually(timeout(20000 milliseconds), interval(200 milliseconds)) { // A basic process verification for ReliableKafkaReceiver. // Verify whether received message number is equal to the sent message number. @@ -100,14 +114,13 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter // Verify the offset number whether it is equal to the total message number. assert(getCommitOffset(groupId, topic, 0) === Some(29L)) } - ssc.stop() } test("Reliable Kafka input stream with multiple topics") { val topics = Map("topic1" -> 1, "topic2" -> 1, "topic3" -> 1) topics.foreach { case (t, _) => - createTopic(t) - sendMessages(t, data) + kafkaTestUtils.createTopic(t) + kafkaTestUtils.sendMessages(t, data) } // Before started, verify all the group/topic/partition offsets are 0. @@ -118,19 +131,18 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter ssc, kafkaParams, topics, StorageLevel.MEMORY_ONLY) stream.foreachRDD(_ => Unit) ssc.start() + eventually(timeout(20000 milliseconds), interval(100 milliseconds)) { // Verify the offset for each group/topic to see whether they are equal to the expected one. topics.foreach { case (t, _) => assert(getCommitOffset(groupId, t, 0) === Some(29L)) } } - ssc.stop() } /** Getting partition offset from Zookeeper. */ private def getCommitOffset(groupId: String, topic: String, partition: Int): Option[Long] = { - assert(zkClient != null, "Zookeeper client is not initialized") val topicDirs = new ZKGroupTopicDirs(groupId, topic) val zkPath = s"${topicDirs.consumerOffsetDir}/$partition" - ZkUtils.readDataMaybeNull(zkClient, zkPath)._1.map(_.toLong) + ZkUtils.readDataMaybeNull(kafkaTestUtils.zookeeperClient, zkPath)._1.map(_.toLong) } } diff --git a/external/mqtt/src/test/resources/log4j.properties b/external/mqtt/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/external/mqtt/src/test/resources/log4j.properties +++ b/external/mqtt/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index 24d78ecb3a97d..a19a72c58a705 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -139,7 +139,8 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { msgTopic.publish(message) } catch { case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => - Thread.sleep(50) // wait for Spark streaming to consume something from the message queue + // wait for Spark streaming to consume something from the message queue + Thread.sleep(50) } } } diff --git a/external/twitter/src/test/resources/log4j.properties b/external/twitter/src/test/resources/log4j.properties index 64bfc5745088f..9a3569789d2e0 100644 --- a/external/twitter/src/test/resources/log4j.properties +++ b/external/twitter/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/zeromq/src/test/resources/log4j.properties b/external/zeromq/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/external/zeromq/src/test/resources/log4j.properties +++ b/external/zeromq/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/extras/java8-tests/src/test/resources/log4j.properties b/extras/java8-tests/src/test/resources/log4j.properties index 287c8e3563503..eb3b1999eb996 100644 --- a/extras/java8-tests/src/test/resources/log4j.properties +++ b/extras/java8-tests/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/extras/kinesis-asl/src/main/resources/log4j.properties b/extras/kinesis-asl/src/main/resources/log4j.properties index 97348fb5b6123..6cdc9286c5d76 100644 --- a/extras/kinesis-asl/src/main/resources/log4j.properties +++ b/extras/kinesis-asl/src/main/resources/log4j.properties @@ -31,7 +31,7 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n # Settings to quiet third party logs that are too verbose -log4j.logger.org.eclipse.jetty=WARN -log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO \ No newline at end of file diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 1bd1f324298e7..a7fe4476cacb8 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -23,6 +23,7 @@ import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.Duration import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.util.Utils import com.amazonaws.auth.AWSCredentialsProvider import com.amazonaws.auth.DefaultAWSCredentialsProviderChain @@ -118,7 +119,7 @@ private[kinesis] class KinesisReceiver( * method. */ override def onStart() { - workerId = InetAddress.getLocalHost.getHostAddress() + ":" + UUID.randomUUID() + workerId = Utils.localHostName() + ":" + UUID.randomUUID() credentialsProvider = new DefaultAWSCredentialsProviderChain() kinesisClientLibConfiguration = new KinesisClientLibConfiguration(appName, streamName, credentialsProvider, workerId).withKinesisEndpoint(endpointUrl) diff --git a/extras/kinesis-asl/src/test/resources/log4j.properties b/extras/kinesis-asl/src/test/resources/log4j.properties index 853ef0ed2986f..edbecdae92096 100644 --- a/extras/kinesis-asl/src/test/resources/log4j.properties +++ b/extras/kinesis-asl/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index 1a7178b82e3af..3b0e1628d86b5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -93,7 +93,7 @@ object SVDPlusPlus { val gJoinT0 = g.outerJoinVertices(t0) { (vid: VertexId, vd: (Array[Double], Array[Double], Double, Double), msg: Option[(Long, Double)]) => - (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1)) + (vd._1, vd._2, msg.get._2 / msg.get._1 - u, 1.0 / scala.math.sqrt(msg.get._1)) }.cache() materialize(gJoinT0) g.unpersist() diff --git a/graphx/src/test/resources/log4j.properties b/graphx/src/test/resources/log4j.properties index 287c8e3563503..eb3b1999eb996 100644 --- a/graphx/src/test/resources/log4j.properties +++ b/graphx/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 8d15150458d26..a570e4ed75fc3 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -38,12 +38,12 @@ class GraphSuite extends FunSuite with LocalSparkContext { val doubleRing = ring ++ ring val graph = Graph.fromEdgeTuples(sc.parallelize(doubleRing), 1) assert(graph.edges.count() === doubleRing.size) - assert(graph.edges.collect.forall(e => e.attr == 1)) + assert(graph.edges.collect().forall(e => e.attr == 1)) // uniqueEdges option should uniquify edges and store duplicate count in edge attributes val uniqueGraph = Graph.fromEdgeTuples(sc.parallelize(doubleRing), 1, Some(RandomVertexCut)) assert(uniqueGraph.edges.count() === ring.size) - assert(uniqueGraph.edges.collect.forall(e => e.attr == 2)) + assert(uniqueGraph.edges.collect().forall(e => e.attr == 2)) } } @@ -64,7 +64,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { assert( graph.edges.count() === rawEdges.size ) // Vertices not explicitly provided but referenced by edges should be created automatically assert( graph.vertices.count() === 100) - graph.triplets.collect.map { et => + graph.triplets.collect().map { et => assert((et.srcId < 10 && et.srcAttr) || (et.srcId >= 10 && !et.srcAttr)) assert((et.dstId < 10 && et.dstAttr) || (et.dstId >= 10 && !et.dstAttr)) } @@ -75,15 +75,17 @@ class GraphSuite extends FunSuite with LocalSparkContext { withSpark { sc => val n = 5 val star = starGraph(sc, n) - assert(star.triplets.map(et => (et.srcId, et.dstId, et.srcAttr, et.dstAttr)).collect.toSet === - (1 to n).map(x => (0: VertexId, x: VertexId, "v", "v")).toSet) + assert(star.triplets.map(et => (et.srcId, et.dstId, et.srcAttr, et.dstAttr)).collect().toSet + === (1 to n).map(x => (0: VertexId, x: VertexId, "v", "v")).toSet) } } test("partitionBy") { withSpark { sc => - def mkGraph(edges: List[(Long, Long)]) = Graph.fromEdgeTuples(sc.parallelize(edges, 2), 0) - def nonemptyParts(graph: Graph[Int, Int]) = { + def mkGraph(edges: List[(Long, Long)]): Graph[Int, Int] = { + Graph.fromEdgeTuples(sc.parallelize(edges, 2), 0) + } + def nonemptyParts(graph: Graph[Int, Int]): RDD[List[Edge[Int]]] = { graph.edges.partitionsRDD.mapPartitions { iter => Iterator(iter.next()._2.iterator.toList) }.filter(_.nonEmpty) @@ -102,7 +104,8 @@ class GraphSuite extends FunSuite with LocalSparkContext { assert(nonemptyParts(mkGraph(sameSrcEdges).partitionBy(EdgePartition1D)).count === 1) // partitionBy(CanonicalRandomVertexCut) puts edges that are identical modulo direction into // the same partition - assert(nonemptyParts(mkGraph(canonicalEdges).partitionBy(CanonicalRandomVertexCut)).count === 1) + assert( + nonemptyParts(mkGraph(canonicalEdges).partitionBy(CanonicalRandomVertexCut)).count === 1) // partitionBy(EdgePartition2D) puts identical edges in the same partition assert(nonemptyParts(mkGraph(identicalEdges).partitionBy(EdgePartition2D)).count === 1) @@ -140,10 +143,10 @@ class GraphSuite extends FunSuite with LocalSparkContext { val g = Graph( sc.parallelize(List((0L, "a"), (1L, "b"), (2L, "c"))), sc.parallelize(List(Edge(0L, 1L, 1), Edge(0L, 2L, 1)), 2)) - assert(g.triplets.collect.map(_.toTuple).toSet === + assert(g.triplets.collect().map(_.toTuple).toSet === Set(((0L, "a"), (1L, "b"), 1), ((0L, "a"), (2L, "c"), 1))) val gPart = g.partitionBy(EdgePartition2D) - assert(gPart.triplets.collect.map(_.toTuple).toSet === + assert(gPart.triplets.collect().map(_.toTuple).toSet === Set(((0L, "a"), (1L, "b"), 1), ((0L, "a"), (2L, "c"), 1))) } } @@ -154,10 +157,10 @@ class GraphSuite extends FunSuite with LocalSparkContext { val star = starGraph(sc, n) // mapVertices preserving type val mappedVAttrs = star.mapVertices((vid, attr) => attr + "2") - assert(mappedVAttrs.vertices.collect.toSet === (0 to n).map(x => (x: VertexId, "v2")).toSet) + assert(mappedVAttrs.vertices.collect().toSet === (0 to n).map(x => (x: VertexId, "v2")).toSet) // mapVertices changing type val mappedVAttrs2 = star.mapVertices((vid, attr) => attr.length) - assert(mappedVAttrs2.vertices.collect.toSet === (0 to n).map(x => (x: VertexId, 1)).toSet) + assert(mappedVAttrs2.vertices.collect().toSet === (0 to n).map(x => (x: VertexId, 1)).toSet) } } @@ -177,12 +180,12 @@ class GraphSuite extends FunSuite with LocalSparkContext { // Trigger initial vertex replication graph0.triplets.foreach(x => {}) // Change type of replicated vertices, but preserve erased type - val graph1 = graph0.mapVertices { - case (vid, integerOpt) => integerOpt.map((x: java.lang.Integer) => (x.toDouble): java.lang.Double) + val graph1 = graph0.mapVertices { case (vid, integerOpt) => + integerOpt.map((x: java.lang.Integer) => x.toDouble: java.lang.Double) } // Access replicated vertices, exposing the erased type val graph2 = graph1.mapTriplets(t => t.srcAttr.get) - assert(graph2.edges.map(_.attr).collect.toSet === Set[java.lang.Double](1.0, 2.0, 3.0)) + assert(graph2.edges.map(_.attr).collect().toSet === Set[java.lang.Double](1.0, 2.0, 3.0)) } } @@ -202,7 +205,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { withSpark { sc => val n = 5 val star = starGraph(sc, n) - assert(star.mapTriplets(et => et.srcAttr + et.dstAttr).edges.collect.toSet === + assert(star.mapTriplets(et => et.srcAttr + et.dstAttr).edges.collect().toSet === (1L to n).map(x => Edge(0, x, "vv")).toSet) } } @@ -211,7 +214,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { withSpark { sc => val n = 5 val star = starGraph(sc, n) - assert(star.reverse.outDegrees.collect.toSet === (1 to n).map(x => (x: VertexId, 1)).toSet) + assert(star.reverse.outDegrees.collect().toSet === (1 to n).map(x => (x: VertexId, 1)).toSet) } } @@ -221,7 +224,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { val edges: RDD[Edge[Int]] = sc.parallelize(Array(Edge(1L, 2L, 0))) val graph = Graph(vertices, edges).reverse val result = graph.mapReduceTriplets[Int](et => Iterator((et.dstId, et.srcAttr)), _ + _) - assert(result.collect.toSet === Set((1L, 2))) + assert(result.collect().toSet === Set((1L, 2))) } } @@ -237,7 +240,8 @@ class GraphSuite extends FunSuite with LocalSparkContext { assert(subgraph.vertices.collect().toSet === (0 to n by 2).map(x => (x, "v")).toSet) // And 4 edges. - assert(subgraph.edges.map(_.copy()).collect().toSet === (2 to n by 2).map(x => Edge(0, x, 1)).toSet) + assert(subgraph.edges.map(_.copy()).collect().toSet === + (2 to n by 2).map(x => Edge(0, x, 1)).toSet) } } @@ -273,9 +277,9 @@ class GraphSuite extends FunSuite with LocalSparkContext { sc.parallelize((1 to n).flatMap(x => List((0: VertexId, x: VertexId), (0: VertexId, x: VertexId))), 1), "v") val star2 = doubleStar.groupEdges { (a, b) => a} - assert(star2.edges.collect.toArray.sorted(Edge.lexicographicOrdering[Int]) === - star.edges.collect.toArray.sorted(Edge.lexicographicOrdering[Int])) - assert(star2.vertices.collect.toSet === star.vertices.collect.toSet) + assert(star2.edges.collect().toArray.sorted(Edge.lexicographicOrdering[Int]) === + star.edges.collect().toArray.sorted(Edge.lexicographicOrdering[Int])) + assert(star2.vertices.collect().toSet === star.vertices.collect().toSet) } } @@ -300,21 +304,23 @@ class GraphSuite extends FunSuite with LocalSparkContext { throw new Exception("map ran on edge with dst vid %d, which is odd".format(et.dstId)) } Iterator((et.srcId, 1)) - }, (a: Int, b: Int) => a + b, Some((active, EdgeDirection.In))).collect.toSet + }, (a: Int, b: Int) => a + b, Some((active, EdgeDirection.In))).collect().toSet assert(numEvenNeighbors === (1 to n).map(x => (x: VertexId, n / 2)).toSet) // outerJoinVertices followed by mapReduceTriplets(activeSetOpt) - val ringEdges = sc.parallelize((0 until n).map(x => (x: VertexId, (x+1) % n: VertexId)), 3) + val ringEdges = sc.parallelize((0 until n).map(x => (x: VertexId, (x + 1) % n: VertexId)), 3) val ring = Graph.fromEdgeTuples(ringEdges, 0) .mapVertices((vid, attr) => vid).cache() val changed = ring.vertices.filter { case (vid, attr) => attr % 2 == 1 }.mapValues(-_).cache() - val changedGraph = ring.outerJoinVertices(changed) { (vid, old, newOpt) => newOpt.getOrElse(old) } + val changedGraph = ring.outerJoinVertices(changed) { (vid, old, newOpt) => + newOpt.getOrElse(old) + } val numOddNeighbors = changedGraph.mapReduceTriplets(et => { // Map function should only run on edges with source in the active set if (et.srcId % 2 != 1) { throw new Exception("map ran on edge with src vid %d, which is even".format(et.dstId)) } Iterator((et.dstId, 1)) - }, (a: Int, b: Int) => a + b, Some(changed, EdgeDirection.Out)).collect.toSet + }, (a: Int, b: Int) => a + b, Some(changed, EdgeDirection.Out)).collect().toSet assert(numOddNeighbors === (2 to n by 2).map(x => (x: VertexId, 1)).toSet) } @@ -340,17 +346,18 @@ class GraphSuite extends FunSuite with LocalSparkContext { val n = 5 val reverseStar = starGraph(sc, n).reverse.cache() // outerJoinVertices changing type - val reverseStarDegrees = - reverseStar.outerJoinVertices(reverseStar.outDegrees) { (vid, a, bOpt) => bOpt.getOrElse(0) } + val reverseStarDegrees = reverseStar.outerJoinVertices(reverseStar.outDegrees) { + (vid, a, bOpt) => bOpt.getOrElse(0) + } val neighborDegreeSums = reverseStarDegrees.mapReduceTriplets( et => Iterator((et.srcId, et.dstAttr), (et.dstId, et.srcAttr)), - (a: Int, b: Int) => a + b).collect.toSet + (a: Int, b: Int) => a + b).collect().toSet assert(neighborDegreeSums === Set((0: VertexId, n)) ++ (1 to n).map(x => (x: VertexId, 0))) // outerJoinVertices preserving type val messages = reverseStar.vertices.mapValues { (vid, attr) => vid.toString } val newReverseStar = reverseStar.outerJoinVertices(messages) { (vid, a, bOpt) => a + bOpt.getOrElse("") } - assert(newReverseStar.vertices.map(_._2).collect.toSet === + assert(newReverseStar.vertices.map(_._2).collect().toSet === (0 to n).map(x => "v%d".format(x)).toSet) } } @@ -361,7 +368,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { val edges = sc.parallelize(List(Edge(1, 2, 0), Edge(2, 1, 0)), 2) val graph = Graph(verts, edges) val triplets = graph.triplets.map(et => (et.srcId, et.dstId, et.srcAttr, et.dstAttr)) - .collect.toSet + .collect().toSet assert(triplets === Set((1: VertexId, 2: VertexId, "a", "b"), (2: VertexId, 1: VertexId, "b", "a"))) } @@ -417,7 +424,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { val graph = Graph.fromEdgeTuples(edges, 1) val neighborAttrSums = graph.mapReduceTriplets[Int]( et => Iterator((et.dstId, et.srcAttr)), _ + _) - assert(neighborAttrSums.collect.toSet === Set((0: VertexId, n))) + assert(neighborAttrSums.collect().toSet === Set((0: VertexId, n))) } finally { sc.stop() } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala index a3e28efc75a98..d2ad9be555770 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala @@ -26,7 +26,7 @@ import org.apache.spark.SparkContext */ trait LocalSparkContext { /** Runs `f` on a new SparkContext and ensures that it is stopped afterwards. */ - def withSpark[T](f: SparkContext => T) = { + def withSpark[T](f: SparkContext => T): T = { val conf = new SparkConf() GraphXUtils.registerKryoClasses(conf) val sc = new SparkContext("local", "test", conf) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala index c9443d11c76cf..d0a7198d691d7 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.storage.StorageLevel class VertexRDDSuite extends FunSuite with LocalSparkContext { - def vertices(sc: SparkContext, n: Int) = { + private def vertices(sc: SparkContext, n: Int) = { VertexRDD(sc.parallelize((0 to n).map(x => (x.toLong, x)), 5)) } @@ -52,7 +52,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val vertexA = VertexRDD(sc.parallelize(0 until 75, 2).map(i => (i.toLong, 0))).cache() val vertexB = VertexRDD(sc.parallelize(25 until 100, 2).map(i => (i.toLong, 1))).cache() val vertexC = vertexA.minus(vertexB) - assert(vertexC.map(_._1).collect.toSet === (0 until 25).toSet) + assert(vertexC.map(_._1).collect().toSet === (0 until 25).toSet) } } @@ -62,7 +62,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val vertexB: RDD[(VertexId, Int)] = sc.parallelize(25 until 100, 2).map(i => (i.toLong, 1)).cache() val vertexC = vertexA.minus(vertexB) - assert(vertexC.map(_._1).collect.toSet === (0 until 25).toSet) + assert(vertexC.map(_._1).collect().toSet === (0 until 25).toSet) } } @@ -72,7 +72,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val vertexB = VertexRDD(sc.parallelize(50 until 100, 2).map(i => (i.toLong, 1))) assert(vertexA.partitions.size != vertexB.partitions.size) val vertexC = vertexA.minus(vertexB) - assert(vertexC.map(_._1).collect.toSet === (0 until 50).toSet) + assert(vertexC.map(_._1).collect().toSet === (0 until 50).toSet) } } @@ -106,7 +106,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val vertexB = VertexRDD(sc.parallelize(8 until 16, 2).map(i => (i.toLong, 1))) assert(vertexA.partitions.size != vertexB.partitions.size) val vertexC = vertexA.diff(vertexB) - assert(vertexC.map(_._1).collect.toSet === (8 until 16).toSet) + assert(vertexC.map(_._1).collect().toSet === (8 until 16).toSet) } } @@ -116,11 +116,11 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val verts = vertices(sc, n).cache() val evens = verts.filter(q => ((q._2 % 2) == 0)).cache() // leftJoin with another VertexRDD - assert(verts.leftJoin(evens) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect.toSet === + assert(verts.leftJoin(evens) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect().toSet === (0 to n by 2).map(x => (x.toLong, 0)).toSet ++ (1 to n by 2).map(x => (x.toLong, x)).toSet) // leftJoin with an RDD val evensRDD = evens.map(identity) - assert(verts.leftJoin(evensRDD) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect.toSet === + assert(verts.leftJoin(evensRDD) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect().toSet === (0 to n by 2).map(x => (x.toLong, 0)).toSet ++ (1 to n by 2).map(x => (x.toLong, x)).toSet) } } @@ -134,7 +134,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val vertexC = vertexA.leftJoin(vertexB) { (vid, old, newOpt) => old - newOpt.getOrElse(0) } - assert(vertexC.filter(v => v._2 != 0).map(_._1).collect.toSet == (1 to 99 by 2).toSet) + assert(vertexC.filter(v => v._2 != 0).map(_._1).collect().toSet == (1 to 99 by 2).toSet) } } @@ -144,11 +144,11 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val verts = vertices(sc, n).cache() val evens = verts.filter(q => ((q._2 % 2) == 0)).cache() // innerJoin with another VertexRDD - assert(verts.innerJoin(evens) { (id, a, b) => a - b }.collect.toSet === + assert(verts.innerJoin(evens) { (id, a, b) => a - b }.collect().toSet === (0 to n by 2).map(x => (x.toLong, 0)).toSet) // innerJoin with an RDD val evensRDD = evens.map(identity) - assert(verts.innerJoin(evensRDD) { (id, a, b) => a - b }.collect.toSet === + assert(verts.innerJoin(evensRDD) { (id, a, b) => a - b }.collect().toSet === (0 to n by 2).map(x => (x.toLong, 0)).toSet) } } @@ -161,7 +161,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val vertexC = vertexA.innerJoin(vertexB) { (vid, old, newVal) => old - newVal } - assert(vertexC.filter(v => v._2 == 0).map(_._1).collect.toSet == (0 to 98 by 2).toSet) + assert(vertexC.filter(v => v._2 == 0).map(_._1).collect().toSet == (0 to 98 by 2).toSet) } } @@ -171,7 +171,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val verts = vertices(sc, n) val messageTargets = (0 to n) ++ (0 to n by 2) val messages = sc.parallelize(messageTargets.map(x => (x.toLong, 1))) - assert(verts.aggregateUsingIndex[Int](messages, _ + _).collect.toSet === + assert(verts.aggregateUsingIndex[Int](messages, _ + _).collect().toSet === (0 to n).map(x => (x.toLong, if (x % 2 == 0) 2 else 1)).toSet) } } @@ -183,7 +183,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val edges = EdgeRDD.fromEdges(sc.parallelize(List.empty[Edge[Int]])) val rdd = VertexRDD(verts, edges, 0, (a: Int, b: Int) => a + b) // test merge function - assert(rdd.collect.toSet == Set((0L, 0), (1L, 3), (2L, 9))) + assert(rdd.collect().toSet == Set((0L, 0), (1L, 3), (2L, 9))) } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala index 3915be15b3434..4cc30a96408f8 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala @@ -32,7 +32,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { withSpark { sc => val gridGraph = GraphGenerators.gridGraph(sc, 10, 10) val ccGraph = gridGraph.connectedComponents() - val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum + val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum() assert(maxCCid === 0) } } // end of Grid connected components @@ -42,7 +42,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { withSpark { sc => val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).reverse val ccGraph = gridGraph.connectedComponents() - val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum + val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum() assert(maxCCid === 0) } } // end of Grid connected components @@ -50,8 +50,8 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { test("Chain Connected Components") { withSpark { sc => - val chain1 = (0 until 9).map(x => (x, x+1) ) - val chain2 = (10 until 20).map(x => (x, x+1) ) + val chain1 = (0 until 9).map(x => (x, x + 1)) + val chain2 = (10 until 20).map(x => (x, x + 1)) val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } val twoChains = Graph.fromEdgeTuples(rawEdges, 1.0) val ccGraph = twoChains.connectedComponents() @@ -73,12 +73,12 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { test("Reverse Chain Connected Components") { withSpark { sc => - val chain1 = (0 until 9).map(x => (x, x+1) ) - val chain2 = (10 until 20).map(x => (x, x+1) ) + val chain1 = (0 until 9).map(x => (x, x + 1)) + val chain2 = (10 until 20).map(x => (x, x + 1)) val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } val twoChains = Graph.fromEdgeTuples(rawEdges, true).reverse val ccGraph = twoChains.connectedComponents() - val vertices = ccGraph.vertices.collect + val vertices = ccGraph.vertices.collect() for ( (id, cc) <- vertices ) { if (id < 10) { assert(cc === 0) @@ -120,9 +120,9 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { // Build the initial Graph val graph = Graph(users, relationships, defaultUser) val ccGraph = graph.connectedComponents() - val vertices = ccGraph.vertices.collect + val vertices = ccGraph.vertices.collect() for ( (id, cc) <- vertices ) { - assert(cc == 0) + assert(cc === 0) } } } // end of toy connected components diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala index fc491ae327c2a..95804b07b1db0 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala @@ -19,15 +19,12 @@ package org.apache.spark.graphx.lib import org.scalatest.FunSuite -import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ -import org.apache.spark.graphx.lib._ import org.apache.spark.graphx.util.GraphGenerators -import org.apache.spark.rdd._ + object GridPageRank { - def apply(nRows: Int, nCols: Int, nIter: Int, resetProb: Double) = { + def apply(nRows: Int, nCols: Int, nIter: Int, resetProb: Double): Seq[(VertexId, Double)] = { val inNbrs = Array.fill(nRows * nCols)(collection.mutable.MutableList.empty[Int]) val outDegree = Array.fill(nRows * nCols)(0) // Convert row column address into vertex ids (row major order) @@ -35,13 +32,13 @@ object GridPageRank { // Make the grid graph for (r <- 0 until nRows; c <- 0 until nCols) { val ind = sub2ind(r,c) - if (r+1 < nRows) { + if (r + 1 < nRows) { outDegree(ind) += 1 - inNbrs(sub2ind(r+1,c)) += ind + inNbrs(sub2ind(r + 1,c)) += ind } - if (c+1 < nCols) { + if (c + 1 < nCols) { outDegree(ind) += 1 - inNbrs(sub2ind(r,c+1)) += ind + inNbrs(sub2ind(r,c + 1)) += ind } } // compute the pagerank @@ -64,7 +61,7 @@ class PageRankSuite extends FunSuite with LocalSparkContext { def compareRanks(a: VertexRDD[Double], b: VertexRDD[Double]): Double = { a.leftJoin(b) { case (id, a, bOpt) => (a - bOpt.getOrElse(0.0)) * (a - bOpt.getOrElse(0.0)) } - .map { case (id, error) => error }.sum + .map { case (id, error) => error }.sum() } test("Star PageRank") { @@ -80,12 +77,12 @@ class PageRankSuite extends FunSuite with LocalSparkContext { // Static PageRank should only take 2 iterations to converge val notMatching = staticRanks1.innerZipJoin(staticRanks2) { (vid, pr1, pr2) => if (pr1 != pr2) 1 else 0 - }.map { case (vid, test) => test }.sum + }.map { case (vid, test) => test }.sum() assert(notMatching === 0) val staticErrors = staticRanks2.map { case (vid, pr) => - val correct = (vid > 0 && pr == resetProb) || - (vid == 0 && math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * (nVertices - 1)) )) < 1.0E-5) + val p = math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * (nVertices - 1)) )) + val correct = (vid > 0 && pr == resetProb) || (vid == 0L && p < 1.0E-5) if (!correct) 1 else 0 } assert(staticErrors.sum === 0) @@ -95,8 +92,6 @@ class PageRankSuite extends FunSuite with LocalSparkContext { } } // end of test Star PageRank - - test("Grid PageRank") { withSpark { sc => val rows = 10 @@ -109,18 +104,18 @@ class PageRankSuite extends FunSuite with LocalSparkContext { val staticRanks = gridGraph.staticPageRank(numIter, resetProb).vertices.cache() val dynamicRanks = gridGraph.pageRank(tol, resetProb).vertices.cache() - val referenceRanks = VertexRDD(sc.parallelize(GridPageRank(rows, cols, numIter, resetProb))).cache() + val referenceRanks = VertexRDD( + sc.parallelize(GridPageRank(rows, cols, numIter, resetProb))).cache() assert(compareRanks(staticRanks, referenceRanks) < errorTol) assert(compareRanks(dynamicRanks, referenceRanks) < errorTol) } } // end of Grid PageRank - test("Chain PageRank") { withSpark { sc => - val chain1 = (0 until 9).map(x => (x, x+1) ) - val rawEdges = sc.parallelize(chain1, 1).map { case (s,d) => (s.toLong, d.toLong) } + val chain1 = (0 until 9).map(x => (x, x + 1)) + val rawEdges = sc.parallelize(chain1, 1).map { case (s, d) => (s.toLong, d.toLong) } val chain = Graph.fromEdgeTuples(rawEdges, 1.0).cache() val resetProb = 0.15 val tol = 0.0001 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala index df54aa37cad68..1f658c371ffcf 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala @@ -34,8 +34,8 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext { val edges = sc.parallelize(Seq.empty[Edge[Int]]) val graph = Graph(vertices, edges) val sccGraph = graph.stronglyConnectedComponents(5) - for ((id, scc) <- sccGraph.vertices.collect) { - assert(id == scc) + for ((id, scc) <- sccGraph.vertices.collect()) { + assert(id === scc) } } } @@ -45,8 +45,8 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext { val rawEdges = sc.parallelize((0L to 6L).map(x => (x, (x + 1) % 7))) val graph = Graph.fromEdgeTuples(rawEdges, -1) val sccGraph = graph.stronglyConnectedComponents(20) - for ((id, scc) <- sccGraph.vertices.collect) { - assert(0L == scc) + for ((id, scc) <- sccGraph.vertices.collect()) { + assert(0L === scc) } } } @@ -60,13 +60,14 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext { val rawEdges = sc.parallelize(edges) val graph = Graph.fromEdgeTuples(rawEdges, -1) val sccGraph = graph.stronglyConnectedComponents(20) - for ((id, scc) <- sccGraph.vertices.collect) { - if (id < 3) - assert(0L == scc) - else if (id < 6) - assert(3L == scc) - else - assert(id == scc) + for ((id, scc) <- sccGraph.vertices.collect()) { + if (id < 3) { + assert(0L === scc) + } else if (id < 6) { + assert(3L === scc) + } else { + assert(id === scc) + } } } } diff --git a/launcher/pom.xml b/launcher/pom.xml index 0fe2814135d88..182e5f60218db 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -52,11 +52,6 @@ mockito-all test - - org.scalatest - scalatest_${scala.binary.version} - test - org.slf4j slf4j-api diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 9b04732afee14..f4ebc25bdd32b 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -274,14 +274,14 @@ static String quoteForBatchScript(String arg) { } /** - * Quotes a string so that it can be used in a command string and be parsed back into a single - * argument by python's "shlex.split()" function. - * + * Quotes a string so that it can be used in a command string. * Basically, just add simple escapes. E.g.: * original single argument : ab "cd" ef * after: "ab \"cd\" ef" + * + * This can be parsed back into a single argument by python's "shlex.split()" function. */ - static String quoteForPython(String s) { + static String quoteForCommandString(String s) { StringBuilder quoted = new StringBuilder().append('"'); for (int i = 0; i < s.length(); i++) { int cp = s.codePointAt(i); diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 91dcf70f105db..a73c9c87e3126 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -17,14 +17,9 @@ package org.apache.spark.launcher; +import java.io.File; import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Properties; +import java.util.*; import static org.apache.spark.launcher.CommandBuilderUtils.*; @@ -53,6 +48,20 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { */ static final String PYSPARK_SHELL_RESOURCE = "pyspark-shell"; + /** + * Name of the app resource used to identify the SparkR shell. The command line parser expects + * the resource name to be the very first argument to spark-submit in this case. + * + * NOTE: this cannot be "sparkr-shell" since that identifies the SparkR shell to SparkSubmit + * (see sparkR.R), and can cause this code to enter into an infinite loop. + */ + static final String SPARKR_SHELL = "sparkr-shell-main"; + + /** + * This is the actual resource name that identifies the SparkR shell to SparkSubmit. + */ + static final String SPARKR_SHELL_RESOURCE = "sparkr-shell"; + /** * This map must match the class names for available special classes, since this modifies the way * command line parsing works. This maps the class name to the resource to use when calling @@ -87,6 +96,10 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { this.allowsMixedArguments = true; appResource = PYSPARK_SHELL_RESOURCE; submitArgs = args.subList(1, args.size()); + } else if (args.size() > 0 && args.get(0).equals(SPARKR_SHELL)) { + this.allowsMixedArguments = true; + appResource = SPARKR_SHELL_RESOURCE; + submitArgs = args.subList(1, args.size()); } else { this.allowsMixedArguments = false; } @@ -98,6 +111,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { public List buildCommand(Map env) throws IOException { if (PYSPARK_SHELL_RESOURCE.equals(appResource)) { return buildPySparkShellCommand(env); + } else if (SPARKR_SHELL_RESOURCE.equals(appResource)) { + return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); } @@ -213,36 +228,62 @@ private List buildPySparkShellCommand(Map env) throws IO return buildCommand(env); } - // When launching the pyspark shell, the spark-submit arguments should be stored in the - // PYSPARK_SUBMIT_ARGS env variable. The executable is the PYSPARK_DRIVER_PYTHON env variable - // set by the pyspark script, followed by PYSPARK_DRIVER_PYTHON_OPTS. checkArgument(appArgs.isEmpty(), "pyspark does not support any application options."); + // When launching the pyspark shell, the spark-submit arguments should be stored in the + // PYSPARK_SUBMIT_ARGS env variable. + constructEnvVarArgs(env, "PYSPARK_SUBMIT_ARGS"); + + // The executable is the PYSPARK_DRIVER_PYTHON env variable set by the pyspark script, + // followed by PYSPARK_DRIVER_PYTHON_OPTS. + List pyargs = new ArrayList(); + pyargs.add(firstNonEmpty(System.getenv("PYSPARK_DRIVER_PYTHON"), "python")); + String pyOpts = System.getenv("PYSPARK_DRIVER_PYTHON_OPTS"); + if (!isEmpty(pyOpts)) { + pyargs.addAll(parseOptionString(pyOpts)); + } + + return pyargs; + } + + private List buildSparkRCommand(Map env) throws IOException { + if (!appArgs.isEmpty() && appArgs.get(0).endsWith(".R")) { + appResource = appArgs.get(0); + appArgs.remove(0); + return buildCommand(env); + } + // When launching the SparkR shell, store the spark-submit arguments in the SPARKR_SUBMIT_ARGS + // env variable. + constructEnvVarArgs(env, "SPARKR_SUBMIT_ARGS"); + + // Set shell.R as R_PROFILE_USER to load the SparkR package when the shell comes up. + String sparkHome = System.getenv("SPARK_HOME"); + env.put("R_PROFILE_USER", + join(File.separator, sparkHome, "R", "lib", "SparkR", "profile", "shell.R")); + + List args = new ArrayList(); + args.add(firstNonEmpty(System.getenv("SPARKR_DRIVER_R"), "R")); + return args; + } + + private void constructEnvVarArgs( + Map env, + String submitArgsEnvVariable) throws IOException { Properties props = loadPropertiesFile(); mergeEnvPathList(env, getLibPathEnvName(), firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, conf, props)); - // Store spark-submit arguments in an environment variable, since there's no way to pass - // them to shell.py on the comand line. StringBuilder submitArgs = new StringBuilder(); for (String arg : buildSparkSubmitArgs()) { if (submitArgs.length() > 0) { submitArgs.append(" "); } - submitArgs.append(quoteForPython(arg)); + submitArgs.append(quoteForCommandString(arg)); } - env.put("PYSPARK_SUBMIT_ARGS", submitArgs.toString()); - - List pyargs = new ArrayList(); - pyargs.add(firstNonEmpty(System.getenv("PYSPARK_DRIVER_PYTHON"), "python")); - String pyOpts = System.getenv("PYSPARK_DRIVER_PYTHON_OPTS"); - if (!isEmpty(pyOpts)) { - pyargs.addAll(parseOptionString(pyOpts)); - } - - return pyargs; + env.put(submitArgsEnvVariable, submitArgs.toString()); } + private boolean isClientMode(Properties userProps) { String userMaster = firstNonEmpty(master, (String) userProps.get(SparkLauncher.SPARK_MASTER)); // Default master is "local[*]", so assume client mode in that case. diff --git a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java index dba0203867372..1ae42eed8a3af 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java @@ -79,9 +79,9 @@ public void testWindowsBatchQuoting() { @Test public void testPythonArgQuoting() { - assertEquals("\"abc\"", quoteForPython("abc")); - assertEquals("\"a b c\"", quoteForPython("a b c")); - assertEquals("\"a \\\"b\\\" c\"", quoteForPython("a \"b\" c")); + assertEquals("\"abc\"", quoteForCommandString("abc")); + assertEquals("\"a b c\"", quoteForCommandString("a b c")); + assertEquals("\"a \\\"b\\\" c\"", quoteForCommandString("a \"b\" c")); } private void testOpt(String opts, List expected) { diff --git a/launcher/src/test/resources/log4j.properties b/launcher/src/test/resources/log4j.properties index 00c20ad69cd4d..67a6a98217118 100644 --- a/launcher/src/test/resources/log4j.properties +++ b/launcher/src/test/resources/log4j.properties @@ -27,5 +27,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index eff7ef925dfbd..d6b3503ebdd9a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -40,7 +40,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { */ @varargs def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = { - val map = new ParamMap().put(paramPairs: _*) + val map = ParamMap(paramPairs: _*) fit(dataset, map) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala index cd84b05bfb496..a50090671ae48 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala @@ -29,5 +29,5 @@ private[ml] trait Identifiable extends Serializable { * random hex chars. */ private[ml] val uid: String = - this.getClass.getSimpleName + "-" + UUID.randomUUID().toString.take(8) + this.getClass.getSimpleName + "_" + UUID.randomUUID().toString.take(8) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index c4a36103303a2..8eddf79cdfe28 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -47,6 +47,9 @@ abstract class PipelineStage extends Serializable with Logging { /** * Derives the output schema from the input schema and parameters, optionally with logging. + * + * This should be optimistic. If it is unclear whether the schema will be valid, then it should + * be assumed valid until proven otherwise. */ protected def transformSchema( schema: StructType, @@ -81,7 +84,7 @@ class Pipeline extends Estimator[PipelineModel] { /** param for pipeline stages */ val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline") def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } - def getStages: Array[PipelineStage] = get(stages) + def getStages: Array[PipelineStage] = getOrDefault(stages) /** * Fits the pipeline to the input dataset with additional parameters. If a stage is an @@ -98,7 +101,7 @@ class Pipeline extends Estimator[PipelineModel] { */ override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val theStages = map(stages) // Search for the last estimator. var indexOfLastEstimator = -1 @@ -135,7 +138,7 @@ class Pipeline extends Estimator[PipelineModel] { } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val theStages = map(stages) require(theStages.toSet.size == theStages.size, "Cannot have duplicate components in a pipeline.") @@ -174,14 +177,14 @@ class PipelineModel private[ml] ( override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap - val map = (fittingParamMap ++ this.paramMap) ++ paramMap + val map = fittingParamMap ++ extractParamMap(paramMap) transformSchema(dataset.schema, map, logging = true) stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map)) } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap - val map = (fittingParamMap ++ this.paramMap) ++ paramMap + val map = fittingParamMap ++ extractParamMap(paramMap) stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 9a5848684b179..7fb87fe452ee6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -22,6 +22,7 @@ import scala.annotation.varargs import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -86,7 +87,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O protected def validateInputType(inputType: DataType): Unit = {} override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val inputType = schema(map(inputCol)).dataType validateInputType(inputType) if (schema.fieldNames.contains(map(outputCol))) { @@ -99,7 +100,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) dataset.withColumn(map(outputCol), callUDF(this.createTransformFunc(map), outputDataType, dataset(map(inputCol)))) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala index 970e6ad5514d1..aa27a668f1695 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala @@ -106,7 +106,7 @@ class AttributeGroup private ( def getAttr(attrIndex: Int): Attribute = this(attrIndex) /** Converts to metadata without name. */ - private[attribute] def toMetadata: Metadata = { + private[attribute] def toMetadataImpl: Metadata = { import AttributeKeys._ val bldr = new MetadataBuilder() if (attributes.isDefined) { @@ -142,17 +142,24 @@ class AttributeGroup private ( bldr.build() } - /** Converts to a StructField with some existing metadata. */ - def toStructField(existingMetadata: Metadata): StructField = { - val newMetadata = new MetadataBuilder() + /** Converts to ML metadata with some existing metadata. */ + def toMetadata(existingMetadata: Metadata): Metadata = { + new MetadataBuilder() .withMetadata(existingMetadata) - .putMetadata(AttributeKeys.ML_ATTR, toMetadata) + .putMetadata(AttributeKeys.ML_ATTR, toMetadataImpl) .build() - StructField(name, new VectorUDT, nullable = false, newMetadata) + } + + /** Converts to ML metadata */ + def toMetadata: Metadata = toMetadata(Metadata.empty) + + /** Converts to a StructField with some existing metadata. */ + def toStructField(existingMetadata: Metadata): StructField = { + StructField(name, new VectorUDT, nullable = false, toMetadata(existingMetadata)) } /** Converts to a StructField. */ - def toStructField(): StructField = toStructField(Metadata.empty) + def toStructField: StructField = toStructField(Metadata.empty) override def equals(other: Any): Boolean = { other match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index c5fc89f935432..29339c98f51cf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -17,12 +17,14 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.{DeveloperApi, AlphaComponent} +import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams} -import org.apache.spark.ml.param.{Params, ParamMap, HasRawPredictionCol} +import org.apache.spark.ml.param.{ParamMap, Params} +import org.apache.spark.ml.param.shared.HasRawPredictionCol +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.sql.functions._ import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} @@ -42,8 +44,8 @@ private[spark] trait ClassifierParams extends PredictorParams fitting: Boolean, featuresDataType: DataType): StructType = { val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType) - val map = this.paramMap ++ paramMap - addOutputColumn(parentSchema, map(rawPredictionCol), new VectorUDT) + val map = extractParamMap(paramMap) + SchemaUtils.appendColumn(parentSchema, map(rawPredictionCol), new VectorUDT) } } @@ -67,8 +69,7 @@ private[spark] abstract class Classifier[ with ClassifierParams { /** @group setParam */ - def setRawPredictionCol(value: String): E = - set(rawPredictionCol, value).asInstanceOf[E] + def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E] // TODO: defaultEvaluator (follow-up PR) } @@ -109,7 +110,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) // Prepare model val tmpModel = if (paramMap.size != 0) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 49c00f77480e8..cc8b0721cf2b6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -19,11 +19,11 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.DoubleType import org.apache.spark.storage.StorageLevel @@ -31,8 +31,10 @@ import org.apache.spark.storage.StorageLevel * Params for logistic regression. */ private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams - with HasRegParam with HasMaxIter with HasThreshold + with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold { + setDefault(regParam -> 0.1, maxIter -> 100, threshold -> 0.5) +} /** * :: AlphaComponent :: @@ -45,16 +47,15 @@ class LogisticRegression extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] with LogisticRegressionParams { - setRegParam(0.1) - setMaxIter(100) - setThreshold(0.5) - /** @group setParam */ def setRegParam(value: Double): this.type = set(regParam, value) /** @group setParam */ def setMaxIter(value: Int): this.type = set(maxIter, value) + /** @group setParam */ + def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) + /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) @@ -67,7 +68,8 @@ class LogisticRegression } // Train model - val lr = new LogisticRegressionWithLBFGS + val lr = new LogisticRegressionWithLBFGS() + .setIntercept(paramMap(fitIntercept)) lr.optimizer .setRegParam(paramMap(regParam)) .setNumIterations(paramMap(maxIter)) @@ -96,8 +98,6 @@ class LogisticRegressionModel private[ml] ( extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] with LogisticRegressionParams { - setThreshold(0.5) - /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) @@ -119,7 +119,7 @@ class LogisticRegressionModel private[ml] ( // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) // Output selected columns only. // This is a bit complicated since it tries to avoid repeated computation. @@ -180,7 +180,7 @@ class LogisticRegressionModel private[ml] ( * The behavior of this can be adjusted using [[threshold]]. */ override protected def predict(features: Vector): Double = { - if (score(features) > paramMap(threshold)) 1 else 0 + if (score(features) > getThreshold) 1 else 0 } override protected def predictProbabilities(features: Vector): Vector = { @@ -195,7 +195,7 @@ class LogisticRegressionModel private[ml] ( override protected def copy(): LogisticRegressionModel = { val m = new LogisticRegressionModel(parent, fittingParamMap, weights, intercept) - Params.inheritValues(this.paramMap, this, m) + Params.inheritValues(this.extractParamMap(), this, m) m } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index bd8caac855981..10404548ccfde 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -18,13 +18,14 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} -import org.apache.spark.ml.param.{HasProbabilityCol, ParamMap, Params} +import org.apache.spark.ml.param.{ParamMap, Params} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, StructType} - /** * Params for probabilistic classification. */ @@ -37,8 +38,8 @@ private[classification] trait ProbabilisticClassifierParams fitting: Boolean, featuresDataType: DataType): StructType = { val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType) - val map = this.paramMap ++ paramMap - addOutputColumn(parentSchema, map(probabilityCol), new VectorUDT) + val map = extractParamMap(paramMap) + SchemaUtils.appendColumn(parentSchema, map(probabilityCol), new VectorUDT) } } @@ -102,7 +103,7 @@ private[spark] abstract class ProbabilisticClassificationModel[ // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) // Prepare model val tmpModel = if (paramMap.size != 0) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 2360f4479f1c2..c865eb9fe092d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -20,12 +20,13 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.Evaluator import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType - /** * :: AlphaComponent :: * @@ -40,10 +41,10 @@ class BinaryClassificationEvaluator extends Evaluator with Params * @group param */ val metricName: Param[String] = new Param(this, "metricName", - "metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC")) + "metric name in evaluation (areaUnderROC|areaUnderPR)") /** @group getParam */ - def getMetricName: String = get(metricName) + def getMetricName: String = getOrDefault(metricName) /** @group setParam */ def setMetricName(value: String): this.type = set(metricName, value) @@ -54,12 +55,14 @@ class BinaryClassificationEvaluator extends Evaluator with Params /** @group setParam */ def setLabelCol(value: String): this.type = set(labelCol, value) + setDefault(metricName -> "areaUnderROC") + override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val schema = dataset.schema - checkInputColumn(schema, map(rawPredictionCol), new VectorUDT) - checkInputColumn(schema, map(labelCol), DoubleType) + SchemaUtils.checkColumnType(schema, map(rawPredictionCol), new VectorUDT) + SchemaUtils.checkColumnType(schema, map(labelCol), DoubleType) // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2. val scoreAndLabels = dataset.select(map(rawPredictionCol), map(labelCol)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index fc4e12773c46d..b20f2fc49a8f6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -35,14 +35,16 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { * number of features * @group param */ - val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18)) + val numFeatures = new IntParam(this, "numFeatures", "number of features") /** @group getParam */ - def getNumFeatures: Int = get(numFeatures) + def getNumFeatures: Int = getOrDefault(numFeatures) /** @group setParam */ def setNumFeatures(value: Int): this.type = set(numFeatures, value) + setDefault(numFeatures -> (1 << 18)) + override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = { val hashingTF = new feature.HashingTF(paramMap(numFeatures)) hashingTF.transform diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index 05f91dc9105fe..decaeb0da6246 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -35,14 +35,16 @@ class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] { * Normalization in L^p^ space, p = 2 by default. * @group param */ - val p = new DoubleParam(this, "p", "the p norm value", Some(2)) + val p = new DoubleParam(this, "p", "the p norm value") /** @group getParam */ - def getP: Double = get(p) + def getP: Double = getOrDefault(p) /** @group setParam */ def setP(value: Double): this.type = set(p, value) + setDefault(p -> 2.0) + override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = { val normalizer = new feature.Normalizer(paramMap(p)) normalizer.transform @@ -50,4 +52,3 @@ class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] { override protected def outputDataType: DataType = new VectorUDT() } - diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 1142aa4f8e73d..1b102619b3524 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ @@ -47,7 +48,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler().fit(input) val model = new StandardScalerModel(this, map, scaler) @@ -56,7 +57,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val inputType = schema(map(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${map(inputCol)} must be a vector column") @@ -86,13 +87,13 @@ class StandardScalerModel private[ml] ( override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val scale = udf((v: Vector) => { scaler.transform(v) } : Vector) dataset.withColumn(map(outputCol), scale(col(map(inputCol)))) } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val inputType = schema(map(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${map(inputCol)} must be a vector column") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala new file mode 100644 index 0000000000000..4d960df357fe9 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkException +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.attribute.NominalAttribute +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.util.collection.OpenHashMap + +/** + * Base trait for [[StringIndexer]] and [[StringIndexerModel]]. + */ +private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { + + /** Validates and transforms the input schema. */ + protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = extractParamMap(paramMap) + SchemaUtils.checkColumnType(schema, map(inputCol), StringType) + val inputFields = schema.fields + val outputColName = map(outputCol) + require(inputFields.forall(_.name != outputColName), + s"Output column $outputColName already exists.") + val attr = NominalAttribute.defaultAttr.withName(map(outputCol)) + val outputFields = inputFields :+ attr.toStructField() + StructType(outputFields) + } +} + +/** + * :: AlphaComponent :: + * A label indexer that maps a string column of labels to an ML column of label indices. + * The indices are in [0, numLabels), ordered by label frequencies. + * So the most frequent label gets index 0. + */ +@AlphaComponent +class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + // TODO: handle unseen labels + + override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = { + val map = extractParamMap(paramMap) + val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue() + val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray + val model = new StringIndexerModel(this, map, labels) + Params.inheritValues(map, this, model) + model + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} + +/** + * :: AlphaComponent :: + * Model fitted by [[StringIndexer]]. + */ +@AlphaComponent +class StringIndexerModel private[ml] ( + override val parent: StringIndexer, + override val fittingParamMap: ParamMap, + labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { + + private val labelToIndex: OpenHashMap[String, Double] = { + val n = labels.length + val map = new OpenHashMap[String, Double](n) + var i = 0 + while (i < n) { + map.update(labels(i), i) + i += 1 + } + map + } + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + val map = extractParamMap(paramMap) + val indexer = udf { label: String => + if (labelToIndex.contains(label)) { + labelToIndex(label) + } else { + // TODO: handle unseen labels + throw new SparkException(s"Unseen label: $label.") + } + } + val outputColName = map(outputCol) + val metadata = NominalAttribute.defaultAttr + .withName(outputColName).withValues(labels).toStructField().metadata + dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata)) + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 68401e36950bd..376a004858b4c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -56,39 +56,39 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize * param for minimum token length, default is one to avoid returning empty strings * @group param */ - val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length", Some(1)) + val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length") /** @group setParam */ def setMinTokenLength(value: Int): this.type = set(minTokenLength, value) /** @group getParam */ - def getMinTokenLength: Int = get(minTokenLength) + def getMinTokenLength: Int = getOrDefault(minTokenLength) /** * param sets regex as splitting on gaps (true) or matching tokens (false) * @group param */ - val gaps: BooleanParam = new BooleanParam( - this, "gaps", "Set regex to match gaps or tokens", Some(false)) + val gaps: BooleanParam = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens") /** @group setParam */ def setGaps(value: Boolean): this.type = set(gaps, value) /** @group getParam */ - def getGaps: Boolean = get(gaps) + def getGaps: Boolean = getOrDefault(gaps) /** * param sets regex pattern used by tokenizer * @group param */ - val pattern: Param[String] = new Param( - this, "pattern", "regex pattern used for tokenizing", Some("\\p{L}+|[^\\p{L}\\s]+")) + val pattern: Param[String] = new Param(this, "pattern", "regex pattern used for tokenizing") /** @group setParam */ def setPattern(value: String): this.type = set(pattern, value) /** @group getParam */ - def getPattern: String = get(pattern) + def getPattern: String = getOrDefault(pattern) + + setDefault(minTokenLength -> 1, gaps -> false, pattern -> "\\p{L}+|[^\\p{L}\\s]+") override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { str => val re = paramMap(pattern).r diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala new file mode 100644 index 0000000000000..e567e069e7c0b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.SparkException +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared._ +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.sql.{Column, DataFrame, Row} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, CreateStruct} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * :: AlphaComponent :: + * A feature transformer than merge multiple columns into a vector column. + */ +@AlphaComponent +class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { + + /** @group setParam */ + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + val map = extractParamMap(paramMap) + val assembleFunc = udf { r: Row => + VectorAssembler.assemble(r.toSeq: _*) + } + val schema = dataset.schema + val inputColNames = map(inputCols) + val args = inputColNames.map { c => + schema(c).dataType match { + case DoubleType => UnresolvedAttribute(c) + case t if t.isInstanceOf[VectorUDT] => UnresolvedAttribute(c) + case _: NativeType => Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")() + } + } + dataset.select(col("*"), assembleFunc(new Column(CreateStruct(args))).as(map(outputCol))) + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = extractParamMap(paramMap) + val inputColNames = map(inputCols) + val outputColName = map(outputCol) + val inputDataTypes = inputColNames.map(name => schema(name).dataType) + inputDataTypes.foreach { + case _: NativeType => + case t if t.isInstanceOf[VectorUDT] => + case other => + throw new IllegalArgumentException(s"Data type $other is not supported.") + } + if (schema.fieldNames.contains(outputColName)) { + throw new IllegalArgumentException(s"Output column $outputColName already exists.") + } + StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false)) + } +} + +@AlphaComponent +object VectorAssembler { + + private[feature] def assemble(vv: Any*): Vector = { + val indices = ArrayBuilder.make[Int] + val values = ArrayBuilder.make[Double] + var cur = 0 + vv.foreach { + case v: Double => + if (v != 0.0) { + indices += cur + values += v + } + cur += 1 + case vec: Vector => + vec.foreachActive { case (i, v) => + if (v != 0.0) { + indices += cur + i + values += v + } + } + cur += vec.size + case null => + // TODO: output Double.NaN? + throw new SparkException("Values to assemble cannot be null.") + case o => + throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") + } + Vectors.sparse(cur, indices.result(), values.result()) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala new file mode 100644 index 0000000000000..452faa06e2021 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -0,0 +1,396 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.attribute.{BinaryAttribute, NumericAttribute, NominalAttribute, + Attribute, AttributeGroup} +import org.apache.spark.ml.param.{IntParam, ParamMap, Params} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT} +import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.functions.callUDF +import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.util.collection.OpenHashSet + + +/** Private trait for params for VectorIndexer and VectorIndexerModel */ +private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol { + + /** + * Threshold for the number of values a categorical feature can take. + * If a feature is found to have > maxCategories values, then it is declared continuous. + * + * (default = 20) + */ + val maxCategories = new IntParam(this, "maxCategories", + "Threshold for the number of values a categorical feature can take." + + " If a feature is found to have > maxCategories values, then it is declared continuous.") + + /** @group getParam */ + def getMaxCategories: Int = getOrDefault(maxCategories) + + setDefault(maxCategories -> 20) +} + +/** + * :: AlphaComponent :: + * + * Class for indexing categorical feature columns in a dataset of [[Vector]]. + * + * This has 2 usage modes: + * - Automatically identify categorical features (default behavior) + * - This helps process a dataset of unknown vectors into a dataset with some continuous + * features and some categorical features. The choice between continuous and categorical + * is based upon a maxCategories parameter. + * - Set maxCategories to the maximum number of categorical any categorical feature should have. + * - E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 values {1.0, 3.0, 5.0}. + * If maxCategories = 2, then feature 0 will be declared categorical and use indices {0, 1}, + * and feature 1 will be declared continuous. + * - Index all features, if all features are categorical + * - If maxCategories is set to be very large, then this will build an index of unique + * values for all features. + * - Warning: This can cause problems if features are continuous since this will collect ALL + * unique values to the driver. + * - E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 values {1.0, 3.0, 5.0}. + * If maxCategories >= 3, then both features will be declared categorical. + * + * This returns a model which can transform categorical features to use 0-based indices. + * + * Index stability: + * - This is not guaranteed to choose the same category index across multiple runs. + * - If a categorical feature includes value 0, then this is guaranteed to map value 0 to index 0. + * This maintains vector sparsity. + * - More stability may be added in the future. + * + * TODO: Future extensions: The following functionality is planned for the future: + * - Preserve metadata in transform; if a feature's metadata is already present, do not recompute. + * - Specify certain features to not index, either via a parameter or via existing metadata. + * - Add warning if a categorical feature has only 1 category. + * - Add option for allowing unknown categories. + */ +@AlphaComponent +class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerParams { + + /** @group setParam */ + def setMaxCategories(value: Int): this.type = { + require(value > 1, + s"DatasetIndexer given maxCategories = value, but requires maxCategories > 1.") + set(maxCategories, value) + } + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def fit(dataset: DataFrame, paramMap: ParamMap): VectorIndexerModel = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = extractParamMap(paramMap) + val firstRow = dataset.select(map(inputCol)).take(1) + require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.") + val numFeatures = firstRow(0).getAs[Vector](0).size + val vectorDataset = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } + val maxCats = map(maxCategories) + val categoryStats: VectorIndexer.CategoryStats = vectorDataset.mapPartitions { iter => + val localCatStats = new VectorIndexer.CategoryStats(numFeatures, maxCats) + iter.foreach(localCatStats.addVector) + Iterator(localCatStats) + }.reduce((stats1, stats2) => stats1.merge(stats2)) + val model = new VectorIndexerModel(this, map, numFeatures, categoryStats.getCategoryMaps) + Params.inheritValues(map, this, model) + model + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + // We do not transfer feature metadata since we do not know what types of features we will + // produce in transform(). + val map = extractParamMap(paramMap) + val dataType = new VectorUDT + require(map.contains(inputCol), s"VectorIndexer requires input column parameter: $inputCol") + require(map.contains(outputCol), s"VectorIndexer requires output column parameter: $outputCol") + SchemaUtils.checkColumnType(schema, map(inputCol), dataType) + SchemaUtils.appendColumn(schema, map(outputCol), dataType) + } +} + +private object VectorIndexer { + + /** + * Helper class for tracking unique values for each feature. + * + * TODO: Track which features are known to be continuous already; do not update counts for them. + * + * @param numFeatures This class fails if it encounters a Vector whose length is not numFeatures. + * @param maxCategories This class caps the number of unique values collected at maxCategories. + */ + class CategoryStats(private val numFeatures: Int, private val maxCategories: Int) + extends Serializable { + + /** featureValueSets[feature index] = set of unique values */ + private val featureValueSets = + Array.fill[OpenHashSet[Double]](numFeatures)(new OpenHashSet[Double]()) + + /** Merge with another instance, modifying this instance. */ + def merge(other: CategoryStats): CategoryStats = { + featureValueSets.zip(other.featureValueSets).foreach { case (thisValSet, otherValSet) => + otherValSet.iterator.foreach { x => + // Once we have found > maxCategories values, we know the feature is continuous + // and do not need to collect more values for it. + if (thisValSet.size <= maxCategories) thisValSet.add(x) + } + } + this + } + + /** Add a new vector to this index, updating sets of unique feature values */ + def addVector(v: Vector): Unit = { + require(v.size == numFeatures, s"VectorIndexer expected $numFeatures features but" + + s" found vector of size ${v.size}.") + v match { + case dv: DenseVector => addDenseVector(dv) + case sv: SparseVector => addSparseVector(sv) + } + } + + /** + * Based on stats collected, decide which features are categorical, + * and choose indices for categories. + * + * Sparsity: This tries to maintain sparsity by treating value 0.0 specially. + * If a categorical feature takes value 0.0, then value 0.0 is given index 0. + * + * @return Feature value index. Keys are categorical feature indices (column indices). + * Values are mappings from original features values to 0-based category indices. + */ + def getCategoryMaps: Map[Int, Map[Double, Int]] = { + // Filter out features which are declared continuous. + featureValueSets.zipWithIndex.filter(_._1.size <= maxCategories).map { + case (featureValues: OpenHashSet[Double], featureIndex: Int) => + var sortedFeatureValues = featureValues.iterator.filter(_ != 0.0).toArray.sorted + val zeroExists = sortedFeatureValues.length + 1 == featureValues.size + if (zeroExists) { + sortedFeatureValues = 0.0 +: sortedFeatureValues + } + val categoryMap: Map[Double, Int] = sortedFeatureValues.zipWithIndex.toMap + (featureIndex, categoryMap) + }.toMap + } + + private def addDenseVector(dv: DenseVector): Unit = { + var i = 0 + while (i < dv.size) { + if (featureValueSets(i).size <= maxCategories) { + featureValueSets(i).add(dv(i)) + } + i += 1 + } + } + + private def addSparseVector(sv: SparseVector): Unit = { + // TODO: This might be able to handle 0's more efficiently. + var vecIndex = 0 // index into vector + var k = 0 // index into non-zero elements + while (vecIndex < sv.size) { + val featureValue = if (k < sv.indices.length && vecIndex == sv.indices(k)) { + k += 1 + sv.values(k - 1) + } else { + 0.0 + } + if (featureValueSets(vecIndex).size <= maxCategories) { + featureValueSets(vecIndex).add(featureValue) + } + vecIndex += 1 + } + } + } +} + +/** + * :: AlphaComponent :: + * + * Transform categorical features to use 0-based indices instead of their original values. + * - Categorical features are mapped to indices. + * - Continuous features (columns) are left unchanged. + * This also appends metadata to the output column, marking features as Numeric (continuous), + * Nominal (categorical), or Binary (either continuous or categorical). + * + * This maintains vector sparsity. + * + * @param numFeatures Number of features, i.e., length of Vectors which this transforms + * @param categoryMaps Feature value index. Keys are categorical feature indices (column indices). + * Values are maps from original features values to 0-based category indices. + * If a feature is not in this map, it is treated as continuous. + */ +@AlphaComponent +class VectorIndexerModel private[ml] ( + override val parent: VectorIndexer, + override val fittingParamMap: ParamMap, + val numFeatures: Int, + val categoryMaps: Map[Int, Map[Double, Int]]) + extends Model[VectorIndexerModel] with VectorIndexerParams { + + /** + * Pre-computed feature attributes, with some missing info. + * In transform(), set attribute name and other info, if available. + */ + private val partialFeatureAttributes: Array[Attribute] = { + val attrs = new Array[Attribute](numFeatures) + var categoricalFeatureCount = 0 // validity check for numFeatures, categoryMaps + var featureIndex = 0 + while (featureIndex < numFeatures) { + if (categoryMaps.contains(featureIndex)) { + // categorical feature + val featureValues: Array[String] = + categoryMaps(featureIndex).toArray.sortBy(_._1).map(_._1).map(_.toString) + if (featureValues.length == 2) { + attrs(featureIndex) = new BinaryAttribute(index = Some(featureIndex), + values = Some(featureValues)) + } else { + attrs(featureIndex) = new NominalAttribute(index = Some(featureIndex), + isOrdinal = Some(false), values = Some(featureValues)) + } + categoricalFeatureCount += 1 + } else { + // continuous feature + attrs(featureIndex) = new NumericAttribute(index = Some(featureIndex)) + } + featureIndex += 1 + } + require(categoricalFeatureCount == categoryMaps.size, "VectorIndexerModel given categoryMaps" + + s" with keys outside expected range [0,...,numFeatures), where numFeatures=$numFeatures") + attrs + } + + // TODO: Check more carefully about whether this whole class will be included in a closure. + + private val transformFunc: Vector => Vector = { + val sortedCategoricalFeatureIndices = categoryMaps.keys.toArray.sorted + val localVectorMap = categoryMaps + val f: Vector => Vector = { + case dv: DenseVector => + val tmpv = dv.copy + localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) => + tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex)) + } + tmpv + case sv: SparseVector => + // We use the fact that categorical value 0 is always mapped to index 0. + val tmpv = sv.copy + var catFeatureIdx = 0 // index into sortedCategoricalFeatureIndices + var k = 0 // index into non-zero elements of sparse vector + while (catFeatureIdx < sortedCategoricalFeatureIndices.length && k < tmpv.indices.length) { + val featureIndex = sortedCategoricalFeatureIndices(catFeatureIdx) + if (featureIndex < tmpv.indices(k)) { + catFeatureIdx += 1 + } else if (featureIndex > tmpv.indices(k)) { + k += 1 + } else { + tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k)) + catFeatureIdx += 1 + k += 1 + } + } + tmpv + } + f + } + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = extractParamMap(paramMap) + val newField = prepOutputField(dataset.schema, map) + val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol))) + // For now, just check the first row of inputCol for vector length. + val firstRow = dataset.select(map(inputCol)).take(1) + if (firstRow.length != 0) { + val actualNumFeatures = firstRow(0).getAs[Vector](0).size + require(numFeatures == actualNumFeatures, "VectorIndexerModel expected vector of length" + + s" $numFeatures but found length $actualNumFeatures") + } + dataset.withColumn(map(outputCol), newCol.as(map(outputCol), newField.metadata)) + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = extractParamMap(paramMap) + val dataType = new VectorUDT + require(map.contains(inputCol), + s"VectorIndexerModel requires input column parameter: $inputCol") + require(map.contains(outputCol), + s"VectorIndexerModel requires output column parameter: $outputCol") + SchemaUtils.checkColumnType(schema, map(inputCol), dataType) + + val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol))) + val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) { + Some(origAttrGroup.attributes.get.length) + } else { + origAttrGroup.numAttributes + } + require(origNumFeatures.forall(_ == numFeatures), "VectorIndexerModel expected" + + s" $numFeatures features, but input column ${map(inputCol)} had metadata specifying" + + s" ${origAttrGroup.numAttributes.get} features.") + + val newField = prepOutputField(schema, map) + val outputFields = schema.fields :+ newField + StructType(outputFields) + } + + /** + * Prepare the output column field, including per-feature metadata. + * @param schema Input schema + * @param map Parameter map (with this class' embedded parameter map folded in) + * @return Output column field + */ + private def prepOutputField(schema: StructType, map: ParamMap): StructField = { + val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol))) + val featureAttributes: Array[Attribute] = if (origAttrGroup.attributes.nonEmpty) { + // Convert original attributes to modified attributes + val origAttrs: Array[Attribute] = origAttrGroup.attributes.get + origAttrs.zip(partialFeatureAttributes).map { + case (origAttr: Attribute, featAttr: BinaryAttribute) => + if (origAttr.name.nonEmpty) { + featAttr.withName(origAttr.name.get) + } else { + featAttr + } + case (origAttr: Attribute, featAttr: NominalAttribute) => + if (origAttr.name.nonEmpty) { + featAttr.withName(origAttr.name.get) + } else { + featAttr + } + case (origAttr: Attribute, featAttr: NumericAttribute) => + origAttr.withIndex(featAttr.index.get) + } + } else { + partialFeatureAttributes + } + val newAttributeGroup = new AttributeGroup(map(outputCol), featureAttributes) + newAttributeGroup.toStructField(schema(map(inputCol)).metadata) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala index dfb89cc8d4af3..195333a5cc47f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala @@ -18,8 +18,10 @@ package org.apache.spark.ml.impl.estimator import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{VectorUDT, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -53,14 +55,14 @@ private[spark] trait PredictorParams extends Params paramMap: ParamMap, fitting: Boolean, featuresDataType: DataType): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector - checkInputColumn(schema, map(featuresCol), featuresDataType) + SchemaUtils.checkColumnType(schema, map(featuresCol), featuresDataType) if (fitting) { // TODO: Allow other numeric types - checkInputColumn(schema, map(labelCol), DoubleType) + SchemaUtils.checkColumnType(schema, map(labelCol), DoubleType) } - addOutputColumn(schema, map(predictionCol), DoubleType) + SchemaUtils.appendColumn(schema, map(predictionCol), DoubleType) } } @@ -98,7 +100,7 @@ private[spark] abstract class Predictor[ // This handles a few items such as schema validation. // Developers only need to implement train(). transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val model = train(dataset, map) Params.inheritValues(map, this, model) // copy params to model model @@ -141,7 +143,7 @@ private[spark] abstract class Predictor[ * and put it in an RDD with strong types. */ protected def extractLabeledPoints(dataset: DataFrame, paramMap: ParamMap): RDD[LabeledPoint] = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) dataset.select(map(labelCol), map(featuresCol)) .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) @@ -201,7 +203,7 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) // Prepare model val tmpModel = if (paramMap.size != 0) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 17ece897a6c55..849c60433c777 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -17,15 +17,14 @@ package org.apache.spark.ml.param +import java.lang.reflect.Modifier +import java.util.NoSuchElementException + import scala.annotation.varargs import scala.collection.mutable -import java.lang.reflect.Modifier - import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.Identifiable -import org.apache.spark.sql.types.{DataType, StructField, StructType} - /** * :: AlphaComponent :: @@ -38,12 +37,7 @@ import org.apache.spark.sql.types.{DataType, StructField, StructType} * @tparam T param value type */ @AlphaComponent -class Param[T] ( - val parent: Params, - val name: String, - val doc: String, - val defaultValue: Option[T] = None) - extends Serializable { +class Param[T] (val parent: Params, val name: String, val doc: String) extends Serializable { /** * Creates a param pair with the given value (for Java). @@ -55,58 +49,55 @@ class Param[T] ( */ def ->(value: T): ParamPair[T] = ParamPair(this, value) + /** + * Converts this param's name, doc, and optionally its default value and the user-supplied + * value in its parent to string. + */ override def toString: String = { - if (defaultValue.isDefined) { - s"$name: $doc (default: ${defaultValue.get})" + val valueStr = if (parent.isDefined(this)) { + val defaultValueStr = parent.getDefault(this).map("default: " + _) + val currentValueStr = parent.get(this).map("current: " + _) + (defaultValueStr ++ currentValueStr).mkString("(", ", ", ")") } else { - s"$name: $doc" + "(undefined)" } + s"$name: $doc $valueStr" } } // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... /** Specialized version of [[Param[Double]]] for Java. */ -class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double]) - extends Param[Double](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class DoubleParam(parent: Params, name: String, doc: String) + extends Param[Double](parent, name, doc) { override def w(value: Double): ParamPair[Double] = super.w(value) } /** Specialized version of [[Param[Int]]] for Java. */ -class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int]) - extends Param[Int](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class IntParam(parent: Params, name: String, doc: String) + extends Param[Int](parent, name, doc) { override def w(value: Int): ParamPair[Int] = super.w(value) } /** Specialized version of [[Param[Float]]] for Java. */ -class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float]) - extends Param[Float](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class FloatParam(parent: Params, name: String, doc: String) + extends Param[Float](parent, name, doc) { override def w(value: Float): ParamPair[Float] = super.w(value) } /** Specialized version of [[Param[Long]]] for Java. */ -class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long]) - extends Param[Long](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class LongParam(parent: Params, name: String, doc: String) + extends Param[Long](parent, name, doc) { override def w(value: Long): ParamPair[Long] = super.w(value) } /** Specialized version of [[Param[Boolean]]] for Java. */ -class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean]) - extends Param[Boolean](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class BooleanParam(parent: Params, name: String, doc: String) + extends Param[Boolean](parent, name, doc) { override def w(value: Boolean): ParamPair[Boolean] = super.w(value) } @@ -124,8 +115,11 @@ case class ParamPair[T](param: Param[T], value: T) @AlphaComponent trait Params extends Identifiable with Serializable { - /** Returns all params. */ - def params: Array[Param[_]] = { + /** + * Returns all params sorted by their names. The default implementation uses Java reflection to + * list all public methods that have no arguments and return [[Param]]. + */ + lazy val params: Array[Param[_]] = { val methods = this.getClass.getMethods methods.filter { m => Modifier.isPublic(m.getModifiers) && @@ -153,25 +147,29 @@ trait Params extends Identifiable with Serializable { def explainParams(): String = params.mkString("\n") /** Checks whether a param is explicitly set. */ - def isSet(param: Param[_]): Boolean = { - require(param.parent.eq(this)) + final def isSet(param: Param[_]): Boolean = { + shouldOwn(param) paramMap.contains(param) } + /** Checks whether a param is explicitly set or has a default value. */ + final def isDefined(param: Param[_]): Boolean = { + shouldOwn(param) + defaultParamMap.contains(param) || paramMap.contains(param) + } + /** Gets a param by its name. */ - private[ml] def getParam(paramName: String): Param[Any] = { - val m = this.getClass.getMethod(paramName) - assert(Modifier.isPublic(m.getModifiers) && - classOf[Param[_]].isAssignableFrom(m.getReturnType) && - m.getParameterTypes.isEmpty) - m.invoke(this).asInstanceOf[Param[Any]] + def getParam(paramName: String): Param[Any] = { + params.find(_.name == paramName).getOrElse { + throw new NoSuchElementException(s"Param $paramName does not exist.") + }.asInstanceOf[Param[Any]] } /** * Sets a parameter in the embedded param map. */ - protected def set[T](param: Param[T], value: T): this.type = { - require(param.parent.eq(this)) + protected final def set[T](param: Param[T], value: T): this.type = { + shouldOwn(param) paramMap.put(param.asInstanceOf[Param[Any]], value) this } @@ -179,44 +177,102 @@ trait Params extends Identifiable with Serializable { /** * Sets a parameter (by name) in the embedded param map. */ - private[ml] def set(param: String, value: Any): this.type = { + protected final def set(param: String, value: Any): this.type = { set(getParam(param), value) } /** - * Gets the value of a parameter in the embedded param map. + * Optionally returns the user-supplied value of a param. */ - protected def get[T](param: Param[T]): T = { - require(param.parent.eq(this)) - paramMap(param) + final def get[T](param: Param[T]): Option[T] = { + shouldOwn(param) + paramMap.get(param) } /** - * Internal param map. + * Clears the user-supplied value for the input param. */ - protected val paramMap: ParamMap = ParamMap.empty + protected final def clear(param: Param[_]): this.type = { + shouldOwn(param) + paramMap.remove(param) + this + } /** - * Check whether the given schema contains an input column. - * @param colName Parameter name for the input column. - * @param dataType SQL DataType of the input column. + * Gets the value of a param in the embedded param map or its default value. Throws an exception + * if neither is set. */ - protected def checkInputColumn(schema: StructType, colName: String, dataType: DataType): Unit = { - val actualDataType = schema(colName).dataType - require(actualDataType.equals(dataType), - s"Input column $colName must be of type $dataType" + - s" but was actually $actualDataType. Column param description: ${getParam(colName)}") + final def getOrDefault[T](param: Param[T]): T = { + shouldOwn(param) + get(param).orElse(getDefault(param)).get } - protected def addOutputColumn( - schema: StructType, - colName: String, - dataType: DataType): StructType = { - if (colName.length == 0) return schema - val fieldNames = schema.fieldNames - require(!fieldNames.contains(colName), s"Prediction column $colName already exists.") - val outputFields = schema.fields ++ Seq(StructField(colName, dataType, nullable = false)) - StructType(outputFields) + /** + * Sets a default value for a param. + * @param param param to set the default value. Make sure that this param is initialized before + * this method gets called. + * @param value the default value + */ + protected final def setDefault[T](param: Param[T], value: T): this.type = { + shouldOwn(param) + defaultParamMap.put(param, value) + this + } + + /** + * Sets default values for a list of params. + * @param paramPairs a list of param pairs that specify params and their default values to set + * respectively. Make sure that the params are initialized before this method + * gets called. + */ + protected final def setDefault(paramPairs: ParamPair[_]*): this.type = { + paramPairs.foreach { p => + setDefault(p.param.asInstanceOf[Param[Any]], p.value) + } + this + } + + /** + * Gets the default value of a parameter. + */ + final def getDefault[T](param: Param[T]): Option[T] = { + shouldOwn(param) + defaultParamMap.get(param) + } + + /** + * Tests whether the input param has a default value set. + */ + final def hasDefault[T](param: Param[T]): Boolean = { + shouldOwn(param) + defaultParamMap.contains(param) + } + + /** + * Extracts the embedded default param values and user-supplied values, and then merges them with + * extra values from input into a flat param map, where the latter value is used if there exist + * conflicts, i.e., with ordering: default param values < user-supplied values < extraParamMap. + */ + protected final def extractParamMap(extraParamMap: ParamMap): ParamMap = { + defaultParamMap ++ paramMap ++ extraParamMap + } + + /** + * [[extractParamMap]] with no extra values. + */ + protected final def extractParamMap(): ParamMap = { + extractParamMap(ParamMap.empty) + } + + /** Internal param map for user-supplied values. */ + private val paramMap: ParamMap = ParamMap.empty + + /** Internal param map for default values. */ + private val defaultParamMap: ParamMap = ParamMap.empty + + /** Validates that the input param belongs to this instance. */ + private def shouldOwn(param: Param[_]): Unit = { + require(param.parent.eq(this), s"Param $param does not belong to $this.") } } @@ -253,12 +309,13 @@ private[spark] object Params { * A param to value map. */ @AlphaComponent -class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable { +final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) + extends Serializable { /** * Creates an empty param map. */ - def this() = this(mutable.Map.empty[Param[Any], Any]) + def this() = this(mutable.Map.empty) /** * Puts a (param, value) pair (overwrites if the input param exists). @@ -280,12 +337,17 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten } /** - * Optionally returns the value associated with a param or its default. + * Optionally returns the value associated with a param. */ def get[T](param: Param[T]): Option[T] = { - map.get(param.asInstanceOf[Param[Any]]) - .orElse(param.defaultValue) - .asInstanceOf[Option[T]] + map.get(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]] + } + + /** + * Returns the value associated with a param or a default value. + */ + def getOrElse[T](param: Param[T], default: T): T = { + get(param).getOrElse(default) } /** @@ -293,10 +355,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten * Raises a NoSuchElementException if there is no value associated with the input param. */ def apply[T](param: Param[T]): T = { - val value = get(param) - if (value.isDefined) { - value.get - } else { + get(param).getOrElse { throw new NoSuchElementException(s"Cannot find param ${param.name}.") } } @@ -308,6 +367,13 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten map.contains(param.asInstanceOf[Param[Any]]) } + /** + * Removes a key from this map and returns its value associated previously as an option. + */ + def remove[T](param: Param[T]): Option[T] = { + map.remove(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]] + } + /** * Filters this param map for the given parent. */ @@ -317,7 +383,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten } /** - * Make a copy of this param map. + * Creates a copy of this param map. */ def copy: ParamMap = new ParamMap(map.clone()) @@ -329,7 +395,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten /** * Returns a new param map that contains parameters in this map and the given map, - * where the latter overwrites this if there exists conflicts. + * where the latter overwrites this if there exist conflicts. */ def ++(other: ParamMap): ParamMap = { // TODO: Provide a better method name for Java users. @@ -355,7 +421,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten } /** - * Number of param pairs in this set. + * Number of param pairs in this map. */ def size: Int = map.size } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala new file mode 100644 index 0000000000000..95d7e64790c79 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param.shared + +import java.io.PrintWriter + +import scala.reflect.ClassTag + +/** + * Code generator for shared params (sharedParams.scala). Run under the Spark folder with + * {{{ + * build/sbt "mllib/runMain org.apache.spark.ml.param.shared.SharedParamsCodeGen" + * }}} + */ +private[shared] object SharedParamsCodeGen { + + def main(args: Array[String]): Unit = { + val params = Seq( + ParamDesc[Double]("regParam", "regularization parameter"), + ParamDesc[Int]("maxIter", "max number of iterations"), + ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")), + ParamDesc[String]("labelCol", "label column name", Some("\"label\"")), + ParamDesc[String]("predictionCol", "prediction column name", Some("\"prediction\"")), + ParamDesc[String]("rawPredictionCol", "raw prediction (a.k.a. confidence) column name", + Some("\"rawPrediction\"")), + ParamDesc[String]("probabilityCol", + "column name for predicted class conditional probabilities", Some("\"probability\"")), + ParamDesc[Double]("threshold", "threshold in binary classification prediction"), + ParamDesc[String]("inputCol", "input column name"), + ParamDesc[Array[String]]("inputCols", "input column names"), + ParamDesc[String]("outputCol", "output column name"), + ParamDesc[Int]("checkpointInterval", "checkpoint interval"), + ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true"))) + + val code = genSharedParams(params) + val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" + val writer = new PrintWriter(file) + writer.write(code) + writer.close() + } + + /** Description of a param. */ + private case class ParamDesc[T: ClassTag]( + name: String, + doc: String, + defaultValueStr: Option[String] = None) { + + require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.") + require(doc.nonEmpty) // TODO: more rigorous on doc + + def paramTypeName: String = { + val c = implicitly[ClassTag[T]].runtimeClass + c match { + case _ if c == classOf[Int] => "IntParam" + case _ if c == classOf[Long] => "LongParam" + case _ if c == classOf[Float] => "FloatParam" + case _ if c == classOf[Double] => "DoubleParam" + case _ if c == classOf[Boolean] => "BooleanParam" + case _ => s"Param[${getTypeString(c)}]" + } + } + + def valueTypeName: String = { + val c = implicitly[ClassTag[T]].runtimeClass + getTypeString(c) + } + + private def getTypeString(c: Class[_]): String = { + c match { + case _ if c == classOf[Int] => "Int" + case _ if c == classOf[Long] => "Long" + case _ if c == classOf[Float] => "Float" + case _ if c == classOf[Double] => "Double" + case _ if c == classOf[Boolean] => "Boolean" + case _ if c == classOf[String] => "String" + case _ if c.isArray => s"Array[${getTypeString(c.getComponentType)}]" + } + } + } + + /** Generates the HasParam trait code for the input param. */ + private def genHasParamTrait(param: ParamDesc[_]): String = { + val name = param.name + val Name = name(0).toUpper +: name.substring(1) + val Param = param.paramTypeName + val T = param.valueTypeName + val doc = param.doc + val defaultValue = param.defaultValueStr + val defaultValueDoc = defaultValue.map { v => + s" (default: $v)" + }.getOrElse("") + val setDefault = defaultValue.map { v => + s""" + | setDefault($name, $v) + |""".stripMargin + }.getOrElse("") + + s""" + |/** + | * :: DeveloperApi :: + | * Trait for shared param $name$defaultValueDoc. + | */ + |@DeveloperApi + |trait Has$Name extends Params { + | + | /** + | * Param for $doc. + | * @group param + | */ + | final val $name: $Param = new $Param(this, "$name", "$doc") + |$setDefault + | /** @group getParam */ + | final def get$Name: $T = getOrDefault($name) + |} + |""".stripMargin + } + + /** Generates Scala source code for the input params with header. */ + private def genSharedParams(params: Seq[ParamDesc[_]]): String = { + val header = + """/* + | * Licensed to the Apache Software Foundation (ASF) under one or more + | * contributor license agreements. See the NOTICE file distributed with + | * this work for additional information regarding copyright ownership. + | * The ASF licenses this file to You under the Apache License, Version 2.0 + | * (the "License"); you may not use this file except in compliance with + | * the License. You may obtain a copy of the License at + | * + | * http://www.apache.org/licenses/LICENSE-2.0 + | * + | * Unless required by applicable law or agreed to in writing, software + | * distributed under the License is distributed on an "AS IS" BASIS, + | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + | * See the License for the specific language governing permissions and + | * limitations under the License. + | */ + | + |package org.apache.spark.ml.param.shared + | + |import org.apache.spark.annotation.DeveloperApi + |import org.apache.spark.ml.param._ + | + |// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. + | + |// scalastyle:off + |""".stripMargin + + val footer = "// scalastyle:on\n" + + val traits = params.map(genHasParamTrait).mkString + + header + traits + footer + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala new file mode 100644 index 0000000000000..72b08bf276483 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param.shared + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.param._ + +// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. + +// scalastyle:off + +/** + * :: DeveloperApi :: + * Trait for shared param regParam. + */ +@DeveloperApi +trait HasRegParam extends Params { + + /** + * Param for regularization parameter. + * @group param + */ + final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") + + /** @group getParam */ + final def getRegParam: Double = getOrDefault(regParam) +} + +/** + * :: DeveloperApi :: + * Trait for shared param maxIter. + */ +@DeveloperApi +trait HasMaxIter extends Params { + + /** + * Param for max number of iterations. + * @group param + */ + final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") + + /** @group getParam */ + final def getMaxIter: Int = getOrDefault(maxIter) +} + +/** + * :: DeveloperApi :: + * Trait for shared param featuresCol (default: "features"). + */ +@DeveloperApi +trait HasFeaturesCol extends Params { + + /** + * Param for features column name. + * @group param + */ + final val featuresCol: Param[String] = new Param[String](this, "featuresCol", "features column name") + + setDefault(featuresCol, "features") + + /** @group getParam */ + final def getFeaturesCol: String = getOrDefault(featuresCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param labelCol (default: "label"). + */ +@DeveloperApi +trait HasLabelCol extends Params { + + /** + * Param for label column name. + * @group param + */ + final val labelCol: Param[String] = new Param[String](this, "labelCol", "label column name") + + setDefault(labelCol, "label") + + /** @group getParam */ + final def getLabelCol: String = getOrDefault(labelCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param predictionCol (default: "prediction"). + */ +@DeveloperApi +trait HasPredictionCol extends Params { + + /** + * Param for prediction column name. + * @group param + */ + final val predictionCol: Param[String] = new Param[String](this, "predictionCol", "prediction column name") + + setDefault(predictionCol, "prediction") + + /** @group getParam */ + final def getPredictionCol: String = getOrDefault(predictionCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param rawPredictionCol (default: "rawPrediction"). + */ +@DeveloperApi +trait HasRawPredictionCol extends Params { + + /** + * Param for raw prediction (a.k.a. confidence) column name. + * @group param + */ + final val rawPredictionCol: Param[String] = new Param[String](this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name") + + setDefault(rawPredictionCol, "rawPrediction") + + /** @group getParam */ + final def getRawPredictionCol: String = getOrDefault(rawPredictionCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param probabilityCol (default: "probability"). + */ +@DeveloperApi +trait HasProbabilityCol extends Params { + + /** + * Param for column name for predicted class conditional probabilities. + * @group param + */ + final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "column name for predicted class conditional probabilities") + + setDefault(probabilityCol, "probability") + + /** @group getParam */ + final def getProbabilityCol: String = getOrDefault(probabilityCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param threshold. + */ +@DeveloperApi +trait HasThreshold extends Params { + + /** + * Param for threshold in binary classification prediction. + * @group param + */ + final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction") + + /** @group getParam */ + final def getThreshold: Double = getOrDefault(threshold) +} + +/** + * :: DeveloperApi :: + * Trait for shared param inputCol. + */ +@DeveloperApi +trait HasInputCol extends Params { + + /** + * Param for input column name. + * @group param + */ + final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name") + + /** @group getParam */ + final def getInputCol: String = getOrDefault(inputCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param inputCols. + */ +@DeveloperApi +trait HasInputCols extends Params { + + /** + * Param for input column names. + * @group param + */ + final val inputCols: Param[Array[String]] = new Param[Array[String]](this, "inputCols", "input column names") + + /** @group getParam */ + final def getInputCols: Array[String] = getOrDefault(inputCols) +} + +/** + * :: DeveloperApi :: + * Trait for shared param outputCol. + */ +@DeveloperApi +trait HasOutputCol extends Params { + + /** + * Param for output column name. + * @group param + */ + final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name") + + /** @group getParam */ + final def getOutputCol: String = getOrDefault(outputCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param checkpointInterval. + */ +@DeveloperApi +trait HasCheckpointInterval extends Params { + + /** + * Param for checkpoint interval. + * @group param + */ + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval") + + /** @group getParam */ + final def getCheckpointInterval: Int = getOrDefault(checkpointInterval) +} + +/** + * :: DeveloperApi :: + * Trait for shared param fitIntercept (default: true). + */ +@DeveloperApi +trait HasFitIntercept extends Params { + + /** + * Param for whether to fit an intercept term. + * @group param + */ + final val fitIntercept: BooleanParam = new BooleanParam(this, "fitIntercept", "whether to fit an intercept term") + + setDefault(fitIntercept, true) + + /** @group getParam */ + final def getFitIntercept: Boolean = getOrDefault(fitIntercept) +} +// scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala deleted file mode 100644 index 5d660d1e151a7..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala +++ /dev/null @@ -1,151 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.param - -/* NOTE TO DEVELOPERS: - * If you mix these parameter traits into your algorithm, please add a setter method as well - * so that users may use a builder pattern: - * val myLearner = new MyLearner().setParam1(x).setParam2(y)... - */ - -private[ml] trait HasRegParam extends Params { - /** - * param for regularization parameter - * @group param - */ - val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") - - /** @group getParam */ - def getRegParam: Double = get(regParam) -} - -private[ml] trait HasMaxIter extends Params { - /** - * param for max number of iterations - * @group param - */ - val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") - - /** @group getParam */ - def getMaxIter: Int = get(maxIter) -} - -private[ml] trait HasFeaturesCol extends Params { - /** - * param for features column name - * @group param - */ - val featuresCol: Param[String] = - new Param(this, "featuresCol", "features column name", Some("features")) - - /** @group getParam */ - def getFeaturesCol: String = get(featuresCol) -} - -private[ml] trait HasLabelCol extends Params { - /** - * param for label column name - * @group param - */ - val labelCol: Param[String] = new Param(this, "labelCol", "label column name", Some("label")) - - /** @group getParam */ - def getLabelCol: String = get(labelCol) -} - -private[ml] trait HasPredictionCol extends Params { - /** - * param for prediction column name - * @group param - */ - val predictionCol: Param[String] = - new Param(this, "predictionCol", "prediction column name", Some("prediction")) - - /** @group getParam */ - def getPredictionCol: String = get(predictionCol) -} - -private[ml] trait HasRawPredictionCol extends Params { - /** - * param for raw prediction column name - * @group param - */ - val rawPredictionCol: Param[String] = - new Param(this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name", - Some("rawPrediction")) - - /** @group getParam */ - def getRawPredictionCol: String = get(rawPredictionCol) -} - -private[ml] trait HasProbabilityCol extends Params { - /** - * param for predicted class conditional probabilities column name - * @group param - */ - val probabilityCol: Param[String] = - new Param(this, "probabilityCol", "column name for predicted class conditional probabilities", - Some("probability")) - - /** @group getParam */ - def getProbabilityCol: String = get(probabilityCol) -} - -private[ml] trait HasThreshold extends Params { - /** - * param for threshold in (binary) prediction - * @group param - */ - val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction") - - /** @group getParam */ - def getThreshold: Double = get(threshold) -} - -private[ml] trait HasInputCol extends Params { - /** - * param for input column name - * @group param - */ - val inputCol: Param[String] = new Param(this, "inputCol", "input column name") - - /** @group getParam */ - def getInputCol: String = get(inputCol) -} - -private[ml] trait HasOutputCol extends Params { - /** - * param for output column name - * @group param - */ - val outputCol: Param[String] = new Param(this, "outputCol", "output column name") - - /** @group getParam */ - def getOutputCol: String = get(outputCol) -} - -private[ml] trait HasCheckpointInterval extends Params { - /** - * param for checkpoint interval - * @group param - */ - val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval") - - /** @group getParam */ - def getCheckpointInterval: Int = get(checkpointInterval) -} diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 52c9e95d6012f..bd793beba35b6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -34,6 +34,7 @@ import org.apache.spark.{Logging, Partitioner} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame @@ -54,86 +55,88 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * Param for rank of the matrix factorization. * @group param */ - val rank = new IntParam(this, "rank", "rank of the factorization", Some(10)) + val rank = new IntParam(this, "rank", "rank of the factorization") /** @group getParam */ - def getRank: Int = get(rank) + def getRank: Int = getOrDefault(rank) /** * Param for number of user blocks. * @group param */ - val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", Some(10)) + val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks") /** @group getParam */ - def getNumUserBlocks: Int = get(numUserBlocks) + def getNumUserBlocks: Int = getOrDefault(numUserBlocks) /** * Param for number of item blocks. * @group param */ val numItemBlocks = - new IntParam(this, "numItemBlocks", "number of item blocks", Some(10)) + new IntParam(this, "numItemBlocks", "number of item blocks") /** @group getParam */ - def getNumItemBlocks: Int = get(numItemBlocks) + def getNumItemBlocks: Int = getOrDefault(numItemBlocks) /** * Param to decide whether to use implicit preference. * @group param */ - val implicitPrefs = - new BooleanParam(this, "implicitPrefs", "whether to use implicit preference", Some(false)) + val implicitPrefs = new BooleanParam(this, "implicitPrefs", "whether to use implicit preference") /** @group getParam */ - def getImplicitPrefs: Boolean = get(implicitPrefs) + def getImplicitPrefs: Boolean = getOrDefault(implicitPrefs) /** * Param for the alpha parameter in the implicit preference formulation. * @group param */ - val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", Some(1.0)) + val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference") /** @group getParam */ - def getAlpha: Double = get(alpha) + def getAlpha: Double = getOrDefault(alpha) /** * Param for the column name for user ids. * @group param */ - val userCol = new Param[String](this, "userCol", "column name for user ids", Some("user")) + val userCol = new Param[String](this, "userCol", "column name for user ids") /** @group getParam */ - def getUserCol: String = get(userCol) + def getUserCol: String = getOrDefault(userCol) /** * Param for the column name for item ids. * @group param */ - val itemCol = - new Param[String](this, "itemCol", "column name for item ids", Some("item")) + val itemCol = new Param[String](this, "itemCol", "column name for item ids") /** @group getParam */ - def getItemCol: String = get(itemCol) + def getItemCol: String = getOrDefault(itemCol) /** * Param for the column name for ratings. * @group param */ - val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating")) + val ratingCol = new Param[String](this, "ratingCol", "column name for ratings") /** @group getParam */ - def getRatingCol: String = get(ratingCol) + def getRatingCol: String = getOrDefault(ratingCol) /** * Param for whether to apply nonnegativity constraints. * @group param */ val nonnegative = new BooleanParam( - this, "nonnegative", "whether to use nonnegative constraint for least squares", Some(false)) + this, "nonnegative", "whether to use nonnegative constraint for least squares") /** @group getParam */ - val getNonnegative: Boolean = get(nonnegative) + def getNonnegative: Boolean = getOrDefault(nonnegative) + + setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10, + implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item", + ratingCol -> "rating", nonnegative -> false) /** * Validates and transforms the input schema. @@ -142,7 +145,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * @return output schema */ protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) assert(schema(map(userCol)).dataType == IntegerType) assert(schema(map(itemCol)).dataType== IntegerType) val ratingType = schema(map(ratingCol)).dataType @@ -171,7 +174,7 @@ class ALSModel private[ml] ( override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { import dataset.sqlContext.implicits._ - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val users = userFactors.toDF("id", "features") val items = itemFactors.toDF("id", "features") @@ -283,7 +286,7 @@ class ALS extends Estimator[ALSModel] with ALSParams { setCheckpointInterval(10) override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val ratings = dataset .select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType)) .map { row => diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 65f6627a0c351..26ca7459c4fdf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -18,7 +18,8 @@ package org.apache.spark.ml.regression import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.param.{Params, ParamMap, HasMaxIter, HasRegParam} +import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{BLAS, Vector} import org.apache.spark.mllib.regression.LinearRegressionWithSGD import org.apache.spark.sql.DataFrame @@ -41,8 +42,7 @@ private[regression] trait LinearRegressionParams extends RegressorParams class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel] with LinearRegressionParams { - setRegParam(0.1) - setMaxIter(100) + setDefault(regParam -> 0.1, maxIter -> 100) /** @group setParam */ def setRegParam(value: Double): this.type = set(regParam, value) @@ -93,7 +93,7 @@ class LinearRegressionModel private[ml] ( override protected def copy(): LinearRegressionModel = { val m = new LinearRegressionModel(parent, fittingParamMap, weights, intercept) - Params.inheritValues(this.paramMap, this, m) + Params.inheritValues(extractParamMap(), this, m) m } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 2eb1dac56f1e9..4bb4ed813c006 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.types.StructType * Params for [[CrossValidator]] and [[CrossValidatorModel]]. */ private[ml] trait CrossValidatorParams extends Params { + /** * param for the estimator to be cross-validated * @group param @@ -38,7 +39,7 @@ private[ml] trait CrossValidatorParams extends Params { val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection") /** @group getParam */ - def getEstimator: Estimator[_] = get(estimator) + def getEstimator: Estimator[_] = getOrDefault(estimator) /** * param for estimator param maps @@ -48,7 +49,7 @@ private[ml] trait CrossValidatorParams extends Params { new Param(this, "estimatorParamMaps", "param maps for the estimator") /** @group getParam */ - def getEstimatorParamMaps: Array[ParamMap] = get(estimatorParamMaps) + def getEstimatorParamMaps: Array[ParamMap] = getOrDefault(estimatorParamMaps) /** * param for the evaluator for selection @@ -57,17 +58,18 @@ private[ml] trait CrossValidatorParams extends Params { val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection") /** @group getParam */ - def getEvaluator: Evaluator = get(evaluator) + def getEvaluator: Evaluator = getOrDefault(evaluator) /** * param for number of folds for cross validation * @group param */ - val numFolds: IntParam = - new IntParam(this, "numFolds", "number of folds for cross validation", Some(3)) + val numFolds: IntParam = new IntParam(this, "numFolds", "number of folds for cross validation") /** @group getParam */ - def getNumFolds: Int = get(numFolds) + def getNumFolds: Int = getOrDefault(numFolds) + + setDefault(numFolds -> 3) } /** @@ -92,7 +94,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP def setNumFolds(value: Int): this.type = set(numFolds, value) override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val schema = dataset.schema transformSchema(dataset.schema, paramMap, logging = true) val sqlCtx = dataset.sqlContext @@ -130,7 +132,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) map(estimator).transformSchema(schema, paramMap) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala new file mode 100644 index 0000000000000..0383bf0b382b7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.types.{DataType, StructField, StructType} + +/** + * :: DeveloperApi :: + * Utils for handling schemas. + */ +@DeveloperApi +object SchemaUtils { + + // TODO: Move the utility methods to SQL. + + /** + * Check whether the given schema contains a column of the required data type. + * @param colName column name + * @param dataType required column data type + */ + def checkColumnType(schema: StructType, colName: String, dataType: DataType): Unit = { + val actualDataType = schema(colName).dataType + require(actualDataType.equals(dataType), + s"Column $colName must be of type $dataType but was actually $actualDataType.") + } + + /** + * Appends a new column to the input schema. This fails if the given output column already exists. + * @param schema input schema + * @param colName new column name. If this column name is an empty string "", this method returns + * the input schema unchanged. This allows users to disable output columns. + * @param dataType new column data type + * @return new schema with the input column appended + */ + def appendColumn( + schema: StructType, + colName: String, + dataType: DataType): StructType = { + if (colName.isEmpty) return schema + val fieldNames = schema.fieldNames + require(!fieldNames.contains(colName), s"Column $colName already exists.") + val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false) + StructType(outputFields) + } +} diff --git a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala similarity index 62% rename from core/src/main/scala/org/apache/spark/TaskContextHelper.scala rename to mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala index 4636c4600a01a..ee933f4cfcafd 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala @@ -15,15 +15,19 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.mllib.api.python + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel} +import org.apache.spark.rdd.RDD /** - * This class exists to restrict the visibility of TaskContext setters. + * A Wrapper of FPGrowthModel to provide helper method for Python */ -private [spark] object TaskContextHelper { - - def setTaskContext(tc: TaskContext): Unit = TaskContext.setTaskContext(tc) +private[python] class FPGrowthModelWrapper(model: FPGrowthModel[Any]) + extends FPGrowthModel(model.freqItemsets) { - def unset(): Unit = TaskContext.unset() - + def getFreqItemsets: RDD[Array[Any]] = { + SerDe.fromTuple2RDD(model.freqItemsets.map(x => (x.javaItems, x.freq))) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 6c386cacfb7ca..ab15f0f36a14b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -34,6 +34,7 @@ import org.apache.spark.api.python.SerDeUtil import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ import org.apache.spark.mllib.feature._ +import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.random.{RandomRDDs => RG} @@ -358,9 +359,7 @@ private[python] class PythonMLLibAPI extends Serializable { val model = new GaussianMixtureModel(weight, gaussians) model.predictSoft(data) } - - - + /** * Java stub for Python mllib ALS.train(). This stub returns a handle * to the Java object instead of the content of the Java object. Extra care @@ -420,6 +419,24 @@ private[python] class PythonMLLibAPI extends Serializable { new MatrixFactorizationModelWrapper(model) } + /** + * Java stub for Python mllib FPGrowth.train(). This stub returns a handle + * to the Java object instead of the content of the Java object. Extra care + * needs to be taken in the Python code to ensure it gets freed on exit; see + * the Py4J documentation. + */ + def trainFPGrowthModel( + data: JavaRDD[java.lang.Iterable[Any]], + minSupport: Double, + numPartitions: Int): FPGrowthModel[Any] = { + val fpg = new FPGrowth() + .setMinSupport(minSupport) + .setNumPartitions(numPartitions) + + val model = fpg.run(data.rdd.map(_.asScala.toArray)) + new FPGrowthModelWrapper(model) + } + /** * Java stub for Normalizer.transform() */ @@ -433,9 +450,9 @@ private[python] class PythonMLLibAPI extends Serializable { def normalizeVector(p: Double, rdd: JavaRDD[Vector]): JavaRDD[Vector] = { new Normalizer(p).transform(rdd) } - + /** - * Java stub for IDF.fit(). This stub returns a + * Java stub for StandardScaler.fit(). This stub returns a * handle to the Java object instead of the content of the Java object. * Extra care needs to be taken in the Python code to ensure it gets freed on * exit; see the Py4J documentation. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 180023922a9b0..aa53e88d59856 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -17,15 +17,20 @@ package org.apache.spark.mllib.clustering -import org.apache.spark.{Logging, SparkException} +import org.json4s.JsonDSL._ +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.{Logging, SparkContext, SparkException} /** * :: Experimental :: @@ -38,7 +43,60 @@ import org.apache.spark.util.random.XORShiftRandom @Experimental class PowerIterationClusteringModel( val k: Int, - val assignments: RDD[PowerIterationClustering.Assignment]) extends Serializable + val assignments: RDD[PowerIterationClustering.Assignment]) extends Saveable with Serializable { + + override def save(sc: SparkContext, path: String): Unit = { + PowerIterationClusteringModel.SaveLoadV1_0.save(sc, this, path) + } + + override protected def formatVersion: String = "1.0" +} + +object PowerIterationClusteringModel extends Loader[PowerIterationClusteringModel] { + override def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { + PowerIterationClusteringModel.SaveLoadV1_0.load(sc, path) + } + + private[clustering] + object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + private[clustering] + val thisClassName = "org.apache.spark.mllib.clustering.PowerIterationClusteringModel" + + def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val dataRDD = model.assignments.toDF() + dataRDD.saveAsParquetFile(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { + implicit val formats = DefaultFormats + val sqlContext = new SQLContext(sc) + + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + + val k = (metadata \ "k").extract[Int] + val assignments = sqlContext.parquetFile(Loader.dataPath(path)) + Loader.checkSchema[PowerIterationClustering.Assignment](assignments.schema) + + val assignmentsRDD = assignments.map { + case Row(id: Long, cluster: Int) => PowerIterationClustering.Assignment(id, cluster) + } + + new PowerIterationClusteringModel(k, assignmentsRDD) + } + } +} /** * :: Experimental :: @@ -135,7 +193,7 @@ class PowerIterationClustering private[clustering] ( val v = powerIter(w, maxIterations) val assignments = kMeans(v, k).mapPartitions({ iter => iter.map { case (id, cluster) => - new Assignment(id, cluster) + Assignment(id, cluster) } }, preservesPartitioning = true) new PowerIterationClusteringModel(k, assignments) @@ -152,7 +210,7 @@ object PowerIterationClustering extends Logging { * @param cluster assigned cluster id */ @Experimental - class Assignment(val id: Long, val cluster: Int) extends Serializable + case class Assignment(id: Long, cluster: Int) /** * Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W). diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index d1a174063caba..3fa5e068d16d4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -87,6 +87,9 @@ sealed trait Matrix extends Serializable { /** A human readable representation of the matrix */ override def toString: String = toBreeze.toString() + /** A human readable representation of the matrix with maximum lines and width */ + def toString(maxLines: Int, maxLineWidth: Int): String = toBreeze.toString(maxLines, maxLineWidth) + /** Map the values of this matrix using a function. Generates a new matrix. Performs the * function on only the backing array. For example, an operation such as addition or * subtraction will only be performed on the non-zero values in a `SparseMatrix`. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index a9c93e181e3ce..c02c79f094b66 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -157,7 +157,6 @@ object GradientBoostedTrees extends Logging { validationInput: RDD[LabeledPoint], boostingStrategy: BoostingStrategy, validate: Boolean): GradientBoostedTreesModel = { - val timer = new TimeTracker() timer.start("total") timer.start("init") @@ -192,20 +191,29 @@ object GradientBoostedTrees extends Logging { // Initialize tree timer.start("building tree 0") val firstTreeModel = new DecisionTree(treeStrategy).run(data) + val firstTreeWeight = 1.0 baseLearners(0) = firstTreeModel - baseLearnerWeights(0) = 1.0 - val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0)) - logDebug("error of gbt = " + loss.computeError(startingModel, input)) + baseLearnerWeights(0) = firstTreeWeight + val startingModel = new GradientBoostedTreesModel( + Regression, Array(firstTreeModel), baseLearnerWeights.slice(0, 1)) + + var predError: RDD[(Double, Double)] = GradientBoostedTreesModel. + computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss) + logDebug("error of gbt = " + predError.values.mean()) // Note: A model of type regression is used since we require raw prediction timer.stop("building tree 0") - var bestValidateError = if (validate) loss.computeError(startingModel, validationInput) else 0.0 + var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel. + computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss) + var bestValidateError = if (validate) validatePredError.values.mean() else 0.0 var bestM = 1 - // psuedo-residual for second iteration - data = input.map(point => LabeledPoint(loss.gradient(startingModel, point), - point.features)) + // pseudo-residual for second iteration + data = predError.zip(input).map { case ((pred, _), point) => + LabeledPoint(-loss.gradient(pred, point.label), point.features) + } + var m = 1 while (m < numIterations) { timer.start(s"building tree $m") @@ -222,15 +230,22 @@ object GradientBoostedTrees extends Logging { baseLearnerWeights(m) = learningRate // Note: A model of type regression is used since we require raw prediction val partialModel = new GradientBoostedTreesModel( - Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1)) - logDebug("error of gbt = " + loss.computeError(partialModel, input)) + Regression, baseLearners.slice(0, m + 1), + baseLearnerWeights.slice(0, m + 1)) + + predError = GradientBoostedTreesModel.updatePredictionError( + input, predError, baseLearnerWeights(m), baseLearners(m), loss) + logDebug("error of gbt = " + predError.values.mean()) if (validate) { // Stop training early if // 1. Reduction in error is less than the validationTol or // 2. If the error increases, that is if the model is overfit. // We want the model returned corresponding to the best validation error. - val currentValidateError = loss.computeError(partialModel, validationInput) + + validatePredError = GradientBoostedTreesModel.updatePredictionError( + validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss) + val currentValidateError = validatePredError.values.mean() if (bestValidateError - currentValidateError < validationTol) { return new GradientBoostedTreesModel( boostingStrategy.treeStrategy.algo, @@ -242,8 +257,9 @@ object GradientBoostedTrees extends Logging { } } // Update data with pseudo-residuals - data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point), - point.features)) + data = predError.zip(input).map { case ((pred, _), point) => + LabeledPoint(-loss.gradient(pred, point.label), point.features) + } m += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala index 793dd664c5d5a..6f570b4e09c79 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala @@ -37,14 +37,12 @@ object AbsoluteError extends Loss { * Method to calculate the gradients for the gradient boosting calculation for least * absolute error calculation. * The gradient with respect to F(x) is: sign(F(x) - y) - * @param model Ensemble model - * @param point Instance of the training dataset + * @param prediction Predicted label. + * @param label True label. * @return Loss gradient */ - override def gradient( - model: TreeEnsembleModel, - point: LabeledPoint): Double = { - if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0 + override def gradient(prediction: Double, label: Double): Double = { + if (label - prediction < 0) 1.0 else -1.0 } override def computeError(prediction: Double, label: Double): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index 51b1aed167b66..24ee9f3d51293 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -39,15 +39,12 @@ object LogLoss extends Loss { * Method to calculate the loss gradients for the gradient boosting calculation for binary * classification * The gradient with respect to F(x) is: - 4 y / (1 + exp(2 y F(x))) - * @param model Ensemble model - * @param point Instance of the training dataset + * @param prediction Predicted label. + * @param label True label. * @return Loss gradient */ - override def gradient( - model: TreeEnsembleModel, - point: LabeledPoint): Double = { - val prediction = model.predict(point.features) - - 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction)) + override def gradient(prediction: Double, label: Double): Double = { + - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction)) } override def computeError(prediction: Double, label: Double): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index 357869ff6b333..d3b82b752fa0d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -31,13 +31,11 @@ trait Loss extends Serializable { /** * Method to calculate the gradients for the gradient boosting calculation. - * @param model Model of the weak learner. - * @param point Instance of the training dataset. + * @param prediction Predicted feature + * @param label true label. * @return Loss gradient. */ - def gradient( - model: TreeEnsembleModel, - point: LabeledPoint): Double + def gradient(prediction: Double, label: Double): Double /** * Method to calculate error of the base learner for the gradient boosting calculation. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala index b990707ca4525..58857ae15e93e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala @@ -37,14 +37,12 @@ object SquaredError extends Loss { * Method to calculate the gradients for the gradient boosting calculation for least * squares error calculation. * The gradient with respect to F(x) is: - 2 (y - F(x)) - * @param model Ensemble model - * @param point Instance of the training dataset + * @param prediction Predicted label. + * @param label True label. * @return Loss gradient */ - override def gradient( - model: TreeEnsembleModel, - point: LabeledPoint): Double = { - 2.0 * (model.predict(point.features) - point.label) + override def gradient(prediction: Double, label: Double): Double = { + 2.0 * (prediction - label) } override def computeError(prediction: Double, label: Double): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 1950254b2aa6d..fef3d2acb202a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -130,35 +130,28 @@ class GradientBoostedTreesModel( val numIterations = trees.length val evaluationArray = Array.fill(numIterations)(0.0) + val localTreeWeights = treeWeights + + var predictionAndError = GradientBoostedTreesModel.computeInitialPredictionAndError( + remappedData, localTreeWeights(0), trees(0), loss) - var predictionAndError: RDD[(Double, Double)] = remappedData.map { i => - val pred = treeWeights(0) * trees(0).predict(i.features) - val error = loss.computeError(pred, i.label) - (pred, error) - } evaluationArray(0) = predictionAndError.values.mean() - // Avoid the model being copied across numIterations. val broadcastTrees = sc.broadcast(trees) - val broadcastWeights = sc.broadcast(treeWeights) - (1 until numIterations).map { nTree => predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter => val currentTree = broadcastTrees.value(nTree) - val currentTreeWeight = broadcastWeights.value(nTree) - iter.map { - case (point, (pred, error)) => { - val newPred = pred + currentTree.predict(point.features) * currentTreeWeight - val newError = loss.computeError(newPred, point.label) - (newPred, newError) - } + val currentTreeWeight = localTreeWeights(nTree) + iter.map { case (point, (pred, error)) => + val newPred = pred + currentTree.predict(point.features) * currentTreeWeight + val newError = loss.computeError(newPred, point.label) + (newPred, newError) } } evaluationArray(nTree) = predictionAndError.values.mean() } broadcastTrees.unpersist() - broadcastWeights.unpersist() evaluationArray } @@ -166,6 +159,58 @@ class GradientBoostedTreesModel( object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { + /** + * Compute the initial predictions and errors for a dataset for the first + * iteration of gradient boosting. + * @param data: training data. + * @param initTreeWeight: learning rate assigned to the first tree. + * @param initTree: first DecisionTreeModel. + * @param loss: evaluation metric. + * @return a RDD with each element being a zip of the prediction and error + * corresponding to every sample. + */ + def computeInitialPredictionAndError( + data: RDD[LabeledPoint], + initTreeWeight: Double, + initTree: DecisionTreeModel, + loss: Loss): RDD[(Double, Double)] = { + data.map { lp => + val pred = initTreeWeight * initTree.predict(lp.features) + val error = loss.computeError(pred, lp.label) + (pred, error) + } + } + + /** + * Update a zipped predictionError RDD + * (as obtained with computeInitialPredictionAndError) + * @param data: training data. + * @param predictionAndError: predictionError RDD + * @param treeWeight: Learning rate. + * @param tree: Tree using which the prediction and error should be updated. + * @param loss: evaluation metric. + * @return a RDD with each element being a zip of the prediction and error + * corresponding to each sample. + */ + def updatePredictionError( + data: RDD[LabeledPoint], + predictionAndError: RDD[(Double, Double)], + treeWeight: Double, + tree: DecisionTreeModel, + loss: Loss): RDD[(Double, Double)] = { + + val newPredError = data.zip(predictionAndError).mapPartitions { iter => + iter.map { + case (lp, (pred, error)) => { + val newPred = pred + tree.predict(lp.features) * treeWeight + val newError = loss.computeError(newPred, lp.label) + (newPred, newError) + } + } + } + newPredError + } + override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = { val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java new file mode 100644 index 0000000000000..161100134c92d --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + + +public class JavaVectorIndexerSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaVectorIndexerSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void vectorIndexerAPI() { + // The tests are to check Java compatibility. + List points = Lists.newArrayList( + new FeatureData(Vectors.dense(0.0, -2.0)), + new FeatureData(Vectors.dense(1.0, 3.0)), + new FeatureData(Vectors.dense(1.0, 4.0)) + ); + SQLContext sqlContext = new SQLContext(sc); + DataFrame data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class); + VectorIndexer indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(2); + VectorIndexerModel model = indexer.fit(data); + Assert.assertEquals(model.numFeatures(), 2); + Assert.assertEquals(model.categoryMaps().size(), 1); + DataFrame indexedData = model.transform(data); + } +} diff --git a/mllib/src/test/resources/log4j.properties b/mllib/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/mllib/src/test/resources/log4j.properties +++ b/mllib/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala index 3fb6e2ec46468..0dcfe5a2002dc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala @@ -43,8 +43,8 @@ class AttributeGroupSuite extends FunSuite { intercept[NoSuchElementException] { group("abc") } - assert(group === AttributeGroup.fromMetadata(group.toMetadata, group.name)) - assert(group === AttributeGroup.fromStructField(group.toStructField())) + assert(group === AttributeGroup.fromMetadata(group.toMetadataImpl, group.name)) + assert(group === AttributeGroup.fromStructField(group.toStructField)) } test("attribute group without attributes") { @@ -53,8 +53,8 @@ class AttributeGroupSuite extends FunSuite { assert(group0.numAttributes === Some(10)) assert(group0.size === 10) assert(group0.attributes.isEmpty) - assert(group0 === AttributeGroup.fromMetadata(group0.toMetadata, group0.name)) - assert(group0 === AttributeGroup.fromStructField(group0.toStructField())) + assert(group0 === AttributeGroup.fromMetadata(group0.toMetadataImpl, group0.name)) + assert(group0 === AttributeGroup.fromStructField(group0.toStructField)) val group1 = new AttributeGroup("item") assert(group1.name === "item") diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index b3d1bfcfbee0f..35d8c2e16c6cd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -46,6 +46,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(lr.getPredictionCol == "prediction") assert(lr.getRawPredictionCol == "rawPrediction") assert(lr.getProbabilityCol == "probability") + assert(lr.getFitIntercept == true) val model = lr.fit(dataset) model.transform(dataset) .select("label", "probability", "prediction", "rawPrediction") @@ -55,6 +56,14 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(model.getPredictionCol == "prediction") assert(model.getRawPredictionCol == "rawPrediction") assert(model.getProbabilityCol == "probability") + assert(model.intercept !== 0.0) + } + + test("logistic regression doesn't fit intercept when fitIntercept is off") { + val lr = new LogisticRegression + lr.setFitIntercept(false) + val model = lr.fit(dataset) + assert(model.intercept === 0.0) } test("logistic regression with setters") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index a18c335952b96..9d09f24709e23 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -24,7 +24,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row, SQLContext} -private case class DataSet(features: Vector) class NormalizerSuite extends FunSuite with MLlibTestSparkContext { @@ -63,7 +62,7 @@ class NormalizerSuite extends FunSuite with MLlibTestSparkContext { ) val sqlContext = new SQLContext(sc) - dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(DataSet)) + dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData)) normalizer = new Normalizer() .setInputCol("features") .setOutputCol("normalized_features") @@ -107,3 +106,7 @@ class NormalizerSuite extends FunSuite with MLlibTestSparkContext { assertValues(result, l1Normalized) } } + +private object NormalizerSuite { + case class FeatureData(features: Vector) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala new file mode 100644 index 0000000000000..00b5d094d82f1 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.SQLContext + +class StringIndexerSuite extends FunSuite with MLlibTestSparkContext { + private var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + test("StringIndexer") { + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + val transformed = indexer.transform(df) + val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("a", "c", "b")) + val output = transformed.select("id", "labelIndex").map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 0, b -> 2, c -> 1 + val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) + assert(output === expected) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index bf862b912d326..d186ead8f542f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -25,10 +25,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row, SQLContext} @BeanInfo -case class TokenizerTestData(rawText: String, wantedTokens: Seq[String]) { - /** Constructor used in [[org.apache.spark.ml.feature.JavaTokenizerSuite]] */ - def this(rawText: String, wantedTokens: Array[String]) = this(rawText, wantedTokens.toSeq) -} +case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext { import org.apache.spark.ml.feature.RegexTokenizerSuite._ @@ -46,14 +43,14 @@ class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext { .setOutputCol("tokens") val dataset0 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization", ".")), - TokenizerTestData("Te,st. punct", Seq("Te", ",", "st", ".", "punct")) + TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization", ".")), + TokenizerTestData("Te,st. punct", Array("Te", ",", "st", ".", "punct")) )) testRegexTokenizer(tokenizer, dataset0) val dataset1 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization")), - TokenizerTestData("Te,st. punct", Seq("punct")) + TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization")), + TokenizerTestData("Te,st. punct", Array("punct")) )) tokenizer.setMinTokenLength(3) @@ -64,8 +61,8 @@ class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext { .setGaps(true) .setMinTokenLength(0) val dataset2 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization.")), - TokenizerTestData("Te,st. punct", Seq("Te,st.", "", "punct")) + TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization.")), + TokenizerTestData("Te,st. punct", Array("Te,st.", "", "punct")) )) testRegexTokenizer(tokenizer, dataset2) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala new file mode 100644 index 0000000000000..57d0278e03639 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.SparkException +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{Row, SQLContext} + +class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext { + + @transient var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + test("assemble") { + import org.apache.spark.ml.feature.VectorAssembler.assemble + assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty)) + assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0))) + val dv = Vectors.dense(2.0, 0.0) + assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0))) + val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0)) + assert(assemble(0.0, dv, 1.0, sv) === + Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0))) + for (v <- Seq(1, "a", null)) { + intercept[SparkException](assemble(v)) + intercept[SparkException](assemble(1.0, v)) + } + } + + test("VectorAssembler") { + val df = sqlContext.createDataFrame(Seq( + (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L) + )).toDF("id", "x", "y", "name", "z", "n") + val assembler = new VectorAssembler() + .setInputCols(Array("x", "y", "z", "n")) + .setOutputCol("features") + assembler.transform(df).select("features").collect().foreach { + case Row(v: Vector) => + assert(v === Vectors.sparse(6, Array(1, 2, 4, 5), Array(1.0, 2.0, 3.0, 10.0))) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala new file mode 100644 index 0000000000000..81ef831c42e55 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.beans.{BeanInfo, BeanProperty} + +import org.scalatest.FunSuite + +import org.apache.spark.SparkException +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.util.TestingUtils +import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, SQLContext} + + +class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext { + + import VectorIndexerSuite.FeatureData + + @transient var sqlContext: SQLContext = _ + + // identical, of length 3 + @transient var densePoints1: DataFrame = _ + @transient var sparsePoints1: DataFrame = _ + @transient var point1maxes: Array[Double] = _ + + // identical, of length 2 + @transient var densePoints2: DataFrame = _ + @transient var sparsePoints2: DataFrame = _ + + // different lengths + @transient var badPoints: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + val densePoints1Seq = Seq( + Vectors.dense(1.0, 2.0, 0.0), + Vectors.dense(0.0, 1.0, 2.0), + Vectors.dense(0.0, 0.0, -1.0), + Vectors.dense(1.0, 3.0, 2.0)) + val sparsePoints1Seq = Seq( + Vectors.sparse(3, Array(0, 1), Array(1.0, 2.0)), + Vectors.sparse(3, Array(1, 2), Array(1.0, 2.0)), + Vectors.sparse(3, Array(2), Array(-1.0)), + Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 3.0, 2.0))) + point1maxes = Array(1.0, 3.0, 2.0) + + val densePoints2Seq = Seq( + Vectors.dense(1.0, 1.0, 0.0, 1.0), + Vectors.dense(0.0, 1.0, 1.0, 1.0), + Vectors.dense(-1.0, 1.0, 2.0, 0.0)) + val sparsePoints2Seq = Seq( + Vectors.sparse(4, Array(0, 1, 3), Array(1.0, 1.0, 1.0)), + Vectors.sparse(4, Array(1, 2, 3), Array(1.0, 1.0, 1.0)), + Vectors.sparse(4, Array(0, 1, 2), Array(-1.0, 1.0, 2.0))) + + val badPointsSeq = Seq( + Vectors.sparse(2, Array(0, 1), Array(1.0, 1.0)), + Vectors.sparse(3, Array(2), Array(-1.0))) + + // Sanity checks for assumptions made in tests + assert(densePoints1Seq.head.size == sparsePoints1Seq.head.size) + assert(densePoints2Seq.head.size == sparsePoints2Seq.head.size) + assert(densePoints1Seq.head.size != densePoints2Seq.head.size) + def checkPair(dvSeq: Seq[Vector], svSeq: Seq[Vector]): Unit = { + assert(dvSeq.zip(svSeq).forall { case (dv, sv) => dv.toArray === sv.toArray }, + "typo in unit test") + } + checkPair(densePoints1Seq, sparsePoints1Seq) + checkPair(densePoints2Seq, sparsePoints2Seq) + + sqlContext = new SQLContext(sc) + densePoints1 = sqlContext.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData)) + sparsePoints1 = sqlContext.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData)) + densePoints2 = sqlContext.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData)) + sparsePoints2 = sqlContext.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData)) + badPoints = sqlContext.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData)) + } + + private def getIndexer: VectorIndexer = + new VectorIndexer().setInputCol("features").setOutputCol("indexed") + + test("Cannot fit an empty DataFrame") { + val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData)) + val vectorIndexer = getIndexer + intercept[IllegalArgumentException] { + vectorIndexer.fit(rdd) + } + } + + test("Throws error when given RDDs with different size vectors") { + val vectorIndexer = getIndexer + val model = vectorIndexer.fit(densePoints1) // vectors of length 3 + model.transform(densePoints1) // should work + model.transform(sparsePoints1) // should work + intercept[IllegalArgumentException] { + model.transform(densePoints2) + println("Did not throw error when fit, transform were called on vectors of different lengths") + } + intercept[SparkException] { + vectorIndexer.fit(badPoints) + println("Did not throw error when fitting vectors of different lengths in same RDD.") + } + } + + test("Same result with dense and sparse vectors") { + def testDenseSparse(densePoints: DataFrame, sparsePoints: DataFrame): Unit = { + val denseVectorIndexer = getIndexer.setMaxCategories(2) + val sparseVectorIndexer = getIndexer.setMaxCategories(2) + val denseModel = denseVectorIndexer.fit(densePoints) + val sparseModel = sparseVectorIndexer.fit(sparsePoints) + val denseMap = denseModel.categoryMaps + val sparseMap = sparseModel.categoryMaps + assert(denseMap.keys.toSet == sparseMap.keys.toSet, + "Categorical features chosen from dense vs. sparse vectors did not match.") + assert(denseMap == sparseMap, + "Categorical feature value indexes chosen from dense vs. sparse vectors did not match.") + } + testDenseSparse(densePoints1, sparsePoints1) + testDenseSparse(densePoints2, sparsePoints2) + } + + test("Builds valid categorical feature value index, transform correctly, check metadata") { + def checkCategoryMaps( + data: DataFrame, + maxCategories: Int, + categoricalFeatures: Set[Int]): Unit = { + val collectedData = data.collect().map(_.getAs[Vector](0)) + val errMsg = s"checkCategoryMaps failed for input with maxCategories=$maxCategories," + + s" categoricalFeatures=${categoricalFeatures.mkString(", ")}" + try { + val vectorIndexer = getIndexer.setMaxCategories(maxCategories) + val model = vectorIndexer.fit(data) + val categoryMaps = model.categoryMaps + // Chose correct categorical features + assert(categoryMaps.keys.toSet === categoricalFeatures) + val transformed = model.transform(data).select("indexed") + val indexedRDD: RDD[Vector] = transformed.map(_.getAs[Vector](0)) + val featureAttrs = AttributeGroup.fromStructField(transformed.schema("indexed")) + assert(featureAttrs.name === "indexed") + assert(featureAttrs.attributes.get.length === model.numFeatures) + categoricalFeatures.foreach { feature: Int => + val origValueSet = collectedData.map(_(feature)).toSet + val targetValueIndexSet = Range(0, origValueSet.size).toSet + val catMap = categoryMaps(feature) + assert(catMap.keys.toSet === origValueSet) // Correct categories + assert(catMap.values.toSet === targetValueIndexSet) // Correct category indices + if (origValueSet.contains(0.0)) { + assert(catMap(0.0) === 0) // value 0 gets index 0 + } + // Check transformed data + assert(indexedRDD.map(_(feature)).collect().toSet === targetValueIndexSet) + // Check metadata + val featureAttr = featureAttrs(feature) + assert(featureAttr.index.get === feature) + featureAttr match { + case attr: BinaryAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + case attr: NominalAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + assert(attr.isOrdinal.get === false) + case _ => + throw new RuntimeException(errMsg + s". Categorical feature $feature failed" + + s" metadata check. Found feature attribute: $featureAttr.") + } + } + // Check numerical feature metadata. + Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature)) + .foreach { feature: Int => + val featureAttr = featureAttrs(feature) + featureAttr match { + case attr: NumericAttribute => + assert(featureAttr.index.get === feature) + case _ => + throw new RuntimeException(errMsg + s". Numerical feature $feature failed" + + s" metadata check. Found feature attribute: $featureAttr.") + } + } + } catch { + case e: org.scalatest.exceptions.TestFailedException => + println(errMsg) + throw e + } + } + checkCategoryMaps(densePoints1, maxCategories = 2, categoricalFeatures = Set(0)) + checkCategoryMaps(densePoints1, maxCategories = 3, categoricalFeatures = Set(0, 2)) + checkCategoryMaps(densePoints2, maxCategories = 2, categoricalFeatures = Set(1, 3)) + } + + test("Maintain sparsity for sparse vectors") { + def checkSparsity(data: DataFrame, maxCategories: Int): Unit = { + val points = data.collect().map(_.getAs[Vector](0)) + val vectorIndexer = getIndexer.setMaxCategories(maxCategories) + val model = vectorIndexer.fit(data) + val indexedPoints = model.transform(data).select("indexed").map(_.getAs[Vector](0)).collect() + points.zip(indexedPoints).foreach { + case (orig: SparseVector, indexed: SparseVector) => + assert(orig.indices.length == indexed.indices.length) + case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen + } + } + checkSparsity(sparsePoints1, maxCategories = 2) + checkSparsity(sparsePoints2, maxCategories = 2) + } + + test("Preserve metadata") { + // For continuous features, preserve name and stats. + val featureAttributes: Array[Attribute] = point1maxes.zipWithIndex.map { case (maxVal, i) => + NumericAttribute.defaultAttr.withName(i.toString).withMax(maxVal) + } + val attrGroup = new AttributeGroup("features", featureAttributes) + val densePoints1WithMeta = + densePoints1.select(densePoints1("features").as("features", attrGroup.toMetadata)) + val vectorIndexer = getIndexer.setMaxCategories(2) + val model = vectorIndexer.fit(densePoints1WithMeta) + // Check that ML metadata are preserved. + val indexedPoints = model.transform(densePoints1WithMeta) + val transAttributes: Array[Attribute] = + AttributeGroup.fromStructField(indexedPoints.schema("indexed")).attributes.get + featureAttributes.zip(transAttributes).foreach { case (orig, trans) => + assert(orig.name === trans.name) + (orig, trans) match { + case (orig: NumericAttribute, trans: NumericAttribute) => + assert(orig.max.nonEmpty && orig.max === trans.max) + case _ => + // do nothing + // TODO: Once input features marked as categorical are handled correctly, check that here. + } + } + // Check that non-ML metadata are preserved. + TestingUtils.testPreserveMetadata(densePoints1WithMeta, model, "features", "indexed") + } +} + +private[feature] object VectorIndexerSuite { + @BeanInfo + case class FeatureData(@BeanProperty features: Vector) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 1ce2987612378..88ea679eeaad5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -21,19 +21,25 @@ import org.scalatest.FunSuite class ParamsSuite extends FunSuite { - val solver = new TestParams() - import solver.{inputCol, maxIter} - test("param") { + val solver = new TestParams() + import solver.{maxIter, inputCol} + assert(maxIter.name === "maxIter") assert(maxIter.doc === "max number of iterations") - assert(maxIter.defaultValue.get === 100) assert(maxIter.parent.eq(solver)) - assert(maxIter.toString === "maxIter: max number of iterations (default: 100)") - assert(inputCol.defaultValue === None) + assert(maxIter.toString === "maxIter: max number of iterations (default: 10)") + + solver.setMaxIter(5) + assert(maxIter.toString === "maxIter: max number of iterations (default: 10, current: 5)") + + assert(inputCol.toString === "inputCol: input column name (undefined)") } test("param pair") { + val solver = new TestParams() + import solver.maxIter + val pair0 = maxIter -> 5 val pair1 = maxIter.w(5) val pair2 = ParamPair(maxIter, 5) @@ -44,10 +50,12 @@ class ParamsSuite extends FunSuite { } test("param map") { + val solver = new TestParams() + import solver.{maxIter, inputCol} + val map0 = ParamMap.empty assert(!map0.contains(maxIter)) - assert(map0(maxIter) === maxIter.defaultValue.get) map0.put(maxIter, 10) assert(map0.contains(maxIter)) assert(map0(maxIter) === 10) @@ -78,23 +86,39 @@ class ParamsSuite extends FunSuite { } test("params") { + val solver = new TestParams() + import solver.{maxIter, inputCol} + val params = solver.params - assert(params.size === 2) + assert(params.length === 2) assert(params(0).eq(inputCol), "params must be ordered by name") assert(params(1).eq(maxIter)) + + assert(!solver.isSet(maxIter)) + assert(solver.isDefined(maxIter)) + assert(solver.getMaxIter === 10) + solver.setMaxIter(100) + assert(solver.isSet(maxIter)) + assert(solver.getMaxIter === 100) + assert(!solver.isSet(inputCol)) + assert(!solver.isDefined(inputCol)) + intercept[NoSuchElementException](solver.getInputCol) + assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n")) + assert(solver.getParam("inputCol").eq(inputCol)) assert(solver.getParam("maxIter").eq(maxIter)) - intercept[NoSuchMethodException] { + intercept[NoSuchElementException] { solver.getParam("abc") } - assert(!solver.isSet(inputCol)) + intercept[IllegalArgumentException] { solver.validate() } solver.validate(ParamMap(inputCol -> "input")) solver.setInputCol("input") assert(solver.isSet(inputCol)) + assert(solver.isDefined(inputCol)) assert(solver.getInputCol === "input") solver.validate() intercept[IllegalArgumentException] { @@ -104,5 +128,8 @@ class ParamsSuite extends FunSuite { intercept[IllegalArgumentException] { solver.validate() } + + solver.clearMaxIter() + assert(!solver.isSet(maxIter)) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index 1a65883d78a71..8f9ab687c05cb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -20,17 +20,21 @@ package org.apache.spark.ml.param /** A subclass of Params for testing. */ class TestParams extends Params { - val maxIter = new IntParam(this, "maxIter", "max number of iterations", Some(100)) + val maxIter = new IntParam(this, "maxIter", "max number of iterations") def setMaxIter(value: Int): this.type = { set(maxIter, value); this } - def getMaxIter: Int = get(maxIter) + def getMaxIter: Int = getOrDefault(maxIter) val inputCol = new Param[String](this, "inputCol", "input column name") def setInputCol(value: String): this.type = { set(inputCol, value); this } - def getInputCol: String = get(inputCol) + def getInputCol: String = getOrDefault(inputCol) - override def validate(paramMap: ParamMap) = { - val m = this.paramMap ++ paramMap + setDefault(maxIter -> 10) + + override def validate(paramMap: ParamMap): Unit = { + val m = extractParamMap(paramMap) require(m(maxIter) >= 0) require(m.contains(inputCol)) } + + def clearMaxIter(): this.type = clear(maxIter) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 29d4ec5f85c1e..fc7349330cf86 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -22,6 +22,7 @@ import java.util.Random import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.scalatest.FunSuite diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala new file mode 100644 index 0000000000000..c44cb61b34171 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.ml.Transformer +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.MetadataBuilder +import org.scalatest.FunSuite + +private[ml] object TestingUtils extends FunSuite { + + /** + * Test whether unrelated metadata are preserved for this transformer. + * This attaches extra metadata to a column, transforms the column, and check to ensure the + * extra metadata have not changed. + * @param data Input dataset + * @param transformer Transformer to test + * @param inputCol Unique input column for Transformer. This must be the ONLY input column. + * @param outputCol Output column to test for metadata presence. + */ + def testPreserveMetadata( + data: DataFrame, + transformer: Transformer, + inputCol: String, + outputCol: String): Unit = { + // Create some fake metadata + val origMetadata = data.schema(inputCol).metadata + val metaKey = "__testPreserveMetadata__fake_key" + val metaValue = 12345 + assert(!origMetadata.contains(metaKey), + s"Unit test with testPreserveMetadata will fail since metadata key was present: $metaKey") + val newMetadata = + new MetadataBuilder().withMetadata(origMetadata).putLong(metaKey, metaValue).build() + // Add metadata to the inputCol + val withMetadata = data.select(data(inputCol).as(inputCol, newMetadata)) + // Transform, and ensure extra metadata was not affected + val transformed = transformer.transform(withMetadata) + val transMetadata = transformed.schema(outputCol).metadata + assert(transMetadata.contains(metaKey), + "Unit test with testPreserveMetadata failed; extra metadata key was not present.") + assert(transMetadata.getLong(metaKey) === metaValue, + "Unit test with testPreserveMetadata failed; extra metadata value was wrong." + + s" Expected $metaValue but found ${transMetadata.getLong(metaKey)}") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index f9fe3e006ccb8..ea89b17b7c08f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -102,7 +102,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { def validateModelFit( piData: Array[Double], thetaData: Array[Array[Double]], - model: NaiveBayesModel) = { + model: NaiveBayesModel): Unit = { def closeFit(d1: Double, d2: Double, precision: Double): Boolean = { (d1 - d2).abs <= precision } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index d50c43d439187..5683b55e8500a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.streaming.TestSuiteBase class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase { // use longer wait time to ensure job completion - override def maxWaitTimeMillis = 30000 + override def maxWaitTimeMillis: Int = 30000 // Test if we can accurately learn B for Y = logistic(BX) on streaming data test("parameter accuracy") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 7bf250eb5a383..0f2b26d462ad2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -199,9 +199,13 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { test("k-means|| initialization") { case class VectorWithCompare(x: Vector) extends Ordered[VectorWithCompare] { - @Override def compare(that: VectorWithCompare): Int = { - if(this.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x) > - that.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x)) -1 else 1 + override def compare(that: VectorWithCompare): Int = { + if (this.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x) > + that.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x)) { + -1 + } else { + 1 + } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 302d751eb8a94..15de10fd13a19 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.clustering import org.scalatest.FunSuite -import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix, Vectors} +import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -141,7 +141,7 @@ private[clustering] object LDASuite { (terms.toArray, termWeights.toArray) } - def tinyCorpus = Array( + def tinyCorpus: Array[(Long, Vector)] = Array( Vectors.dense(1, 3, 0, 2, 8), Vectors.dense(0, 2, 1, 0, 4), Vectors.dense(2, 3, 12, 3, 1), diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala index 6315c03a700f1..6d6fe6fe46bab 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala @@ -18,12 +18,15 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable +import scala.util.Random import org.scalatest.FunSuite +import org.apache.spark.SparkContext import org.apache.spark.graphx.{Edge, Graph} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext { @@ -110,4 +113,35 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext assert(x ~== u1(i.toInt) absTol 1e-14) } } + + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + val model = PowerIterationClusteringSuite.createModel(sc, 3, 10) + try { + model.save(sc, path) + val sameModel = PowerIterationClusteringModel.load(sc, path) + PowerIterationClusteringSuite.checkEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } +} + +object PowerIterationClusteringSuite extends FunSuite { + def createModel(sc: SparkContext, k: Int, nPoints: Int): PowerIterationClusteringModel = { + val assignments = sc.parallelize( + (0 until nPoints).map(p => PowerIterationClustering.Assignment(p, Random.nextInt(k)))) + new PowerIterationClusteringModel(k, assignments) + } + + def checkEqual(a: PowerIterationClusteringModel, b: PowerIterationClusteringModel): Unit = { + assert(a.k === b.k) + + val aAssignments = a.assignments.map(x => (x.id, x.cluster)) + val bAssignments = b.assignments.map(x => (x.id, x.cluster)) + val unequalElements = aAssignments.join(bAssignments).filter { + case (id, (c1, c2)) => c1 != c2 }.count() + assert(unequalElements === 0L) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index 850c9fce507cd..f90025d535e45 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.util.random.XORShiftRandom class StreamingKMeansSuite extends FunSuite with TestSuiteBase { - override def maxWaitTimeMillis = 30000 + override def maxWaitTimeMillis: Int = 30000 test("accuracy for single center and equivalence to grand average") { // set parameters @@ -59,7 +59,7 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase { // estimated center from streaming should exactly match the arithmetic mean of all data points // because the decay factor is set to 1.0 val grandMean = - input.flatten.map(x => x.toBreeze).reduce(_+_) / (numBatches * numPoints).toDouble + input.flatten.map(x => x.toBreeze).reduce(_ + _) / (numBatches * numPoints).toDouble assert(model.latestModel().clusterCenters(0) ~== Vectors.dense(grandMean.toArray) absTol 1E-5) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 0d2cec58e2c03..86119ec38101e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -439,4 +439,20 @@ class MatricesSuite extends FunSuite { assert(mUDT.typeName == "matrix") assert(mUDT.simpleString == "matrix") } + + test("toString") { + val empty = Matrices.ones(0, 0) + empty.toString(0, 0) + + val mat = Matrices.rand(5, 10, new Random()) + mat.toString(-1, -5) + mat.toString(0, 0) + mat.toString(Int.MinValue, Int.MinValue) + mat.toString(Int.MaxValue, Int.MaxValue) + var lines = mat.toString(6, 50).lines.toArray + assert(lines.size == 5 && lines.forall(_.size <= 50)) + + lines = mat.toString(5, 100).lines.toArray + assert(lines.size == 5 && lines.forall(_.size <= 100)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala index 6395188a0842a..63f2ea916d457 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala @@ -181,7 +181,8 @@ class RandomRDDsSuite extends FunSuite with MLlibTestSparkContext with Serializa val poisson = RandomRDDs.poissonVectorRDD(sc, poissonMean, rows, cols, parts, seed) testGeneratedVectorRDD(poisson, rows, cols, parts, poissonMean, math.sqrt(poissonMean), 0.1) - val exponential = RandomRDDs.exponentialVectorRDD(sc, exponentialMean, rows, cols, parts, seed) + val exponential = + RandomRDDs.exponentialVectorRDD(sc, exponentialMean, rows, cols, parts, seed) testGeneratedVectorRDD(exponential, rows, cols, parts, exponentialMean, exponentialMean, 0.1) val gamma = RandomRDDs.gammaVectorRDD(sc, gammaShape, gammaScale, rows, cols, parts, seed) @@ -197,7 +198,7 @@ private[random] class MockDistro extends RandomDataGenerator[Double] { // This allows us to check that each partition has a different seed override def nextValue(): Double = seed.toDouble - override def setSeed(seed: Long) = this.seed = seed + override def setSeed(seed: Long): Unit = this.seed = seed override def copy(): MockDistro = new MockDistro } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index 8775c0ca9df84..b3798940ddc38 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -203,6 +203,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext { * @param numProductBlocks number of product blocks to partition products into * @param negativeFactors whether the generated user/product factors can have negative entries */ + // scalastyle:off def testALS( users: Int, products: Int, @@ -216,6 +217,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext { numUserBlocks: Int = -1, numProductBlocks: Int = -1, negativeFactors: Boolean = true) { + // scalastyle:on + val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products, features, samplingRate, implicitPrefs, negativeWeights, negativeFactors) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index 43d61151e2471..d6c93cc0e49cd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -35,7 +35,7 @@ private object RidgeRegressionSuite { class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { - def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = { + def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]): Double = { predictions.zip(input).map { case (prediction, expected) => (prediction - expected.label) * (prediction - expected.label) }.reduceLeft(_ + _) / predictions.size diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index 24fd8df691817..26604dbe6c1ef 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.streaming.TestSuiteBase class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { // use longer wait time to ensure job completion - override def maxWaitTimeMillis = 20000 + override def maxWaitTimeMillis: Int = 20000 // Assert that two values are equal within tolerance epsilon def assertEqual(v1: Double, v2: Double, epsilon: Double) { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala index e957fa5d25f4c..352193a67860c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala @@ -95,16 +95,16 @@ object TestingUtils { /** * Comparison using absolute tolerance. */ - def absTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(AbsoluteErrorComparison, - x, eps, ABS_TOL_MSG) + def absTol(eps: Double): CompareDoubleRightSide = + CompareDoubleRightSide(AbsoluteErrorComparison, x, eps, ABS_TOL_MSG) /** * Comparison using relative tolerance. */ - def relTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(RelativeErrorComparison, - x, eps, REL_TOL_MSG) + def relTol(eps: Double): CompareDoubleRightSide = + CompareDoubleRightSide(RelativeErrorComparison, x, eps, REL_TOL_MSG) - override def toString = x.toString + override def toString: String = x.toString } case class CompareVectorRightSide( @@ -166,7 +166,7 @@ object TestingUtils { x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps) }, x, eps, REL_TOL_MSG) - override def toString = x.toString + override def toString: String = x.toString } case class CompareMatrixRightSide( @@ -229,7 +229,7 @@ object TestingUtils { x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps) }, x, eps, REL_TOL_MSG) - override def toString = x.toString + override def toString: String = x.toString } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala index b0ecb33c28483..59e6c778806f4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala @@ -88,16 +88,20 @@ class TestingUtilsSuite extends FunSuite { assert(!(17.8 ~= 17.59 absTol 0.2)) // Comparisons of numbers very close to zero, and both side of zeros - assert(Double.MinPositiveValue ~== 4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) - assert(Double.MinPositiveValue !~== 6 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) - - assert(-Double.MinPositiveValue ~== 3 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) - assert(Double.MinPositiveValue !~== -4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + assert( + Double.MinPositiveValue ~== 4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + assert( + Double.MinPositiveValue !~== 6 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + + assert( + -Double.MinPositiveValue ~== 3 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + assert( + Double.MinPositiveValue !~== -4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) } test("Comparing vectors using relative error.") { - //Comparisons of two dense vectors + // Comparisons of two dense vectors assert(Vectors.dense(Array(3.1, 3.5)) ~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01) assert(Vectors.dense(Array(3.1, 3.5)) !~== Vectors.dense(Array(3.135, 3.534)) relTol 0.01) assert(Vectors.dense(Array(3.1, 3.5)) ~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01) @@ -130,7 +134,7 @@ class TestingUtilsSuite extends FunSuite { test("Comparing vectors using absolute error.") { - //Comparisons of two dense vectors + // Comparisons of two dense vectors assert(Vectors.dense(Array(3.1, 3.5, 0.0)) ~== Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6) diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 73da9b7346f4d..b6fbace509a0e 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -21,9 +21,13 @@ import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; +import java.util.concurrent.TimeUnit; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import com.google.common.base.Charsets; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; import io.netty.buffer.Unpooled; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -121,4 +125,66 @@ private static boolean isSymlink(File file) throws IOException { } return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); } + + private static ImmutableMap timeSuffixes = + ImmutableMap.builder() + .put("us", TimeUnit.MICROSECONDS) + .put("ms", TimeUnit.MILLISECONDS) + .put("s", TimeUnit.SECONDS) + .put("m", TimeUnit.MINUTES) + .put("min", TimeUnit.MINUTES) + .put("h", TimeUnit.HOURS) + .put("d", TimeUnit.DAYS) + .build(); + + /** + * Convert a passed time string (e.g. 50s, 100ms, or 250us) to a time count for + * internal use. If no suffix is provided a direct conversion is attempted. + */ + private static long parseTimeString(String str, TimeUnit unit) { + String lower = str.toLowerCase().trim(); + + try { + String suffix; + long val; + Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower); + if (m.matches()) { + val = Long.parseLong(m.group(1)); + suffix = m.group(2); + } else { + throw new NumberFormatException("Failed to parse time string: " + str); + } + + // Check for invalid suffixes + if (suffix != null && !timeSuffixes.containsKey(suffix)) { + throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); + } + + // If suffix is valid use that, otherwise none was provided and use the default passed + return unit.convert(val, suffix != null ? timeSuffixes.get(suffix) : unit); + } catch (NumberFormatException e) { + String timeError = "Time must be specified as seconds (s), " + + "milliseconds (ms), microseconds (us), minutes (m or min) hour (h), or day (d). " + + "E.g. 50s, 100ms, or 250us."; + + throw new NumberFormatException(timeError + "\n" + e.getMessage()); + } + } + + /** + * Convert a time parameter such as (50s, 100ms, or 250us) to milliseconds for internal use. If + * no suffix is provided, the passed number is assumed to be in ms. + */ + public static long timeStringAsMs(String str) { + return parseTimeString(str, TimeUnit.MILLISECONDS); + } + + /** + * Convert a time parameter such as (50s, 100ms, or 250us) to seconds for internal use. If + * no suffix is provided, the passed number is assumed to be in seconds. + */ + public static long timeStringAsSec(String str) { + return parseTimeString(str, TimeUnit.SECONDS); + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 2eaf3b71d9a49..0aef7f1987315 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -37,8 +37,11 @@ public boolean preferDirectBufs() { /** Connect timeout in milliseconds. Default 120 secs. */ public int connectionTimeoutMs() { - int defaultTimeout = conf.getInt("spark.network.timeout", 120); - return conf.getInt("spark.shuffle.io.connectionTimeout", defaultTimeout) * 1000; + long defaultNetworkTimeoutS = JavaUtils.timeStringAsSec( + conf.get("spark.network.timeout", "120s")); + long defaultTimeoutMs = JavaUtils.timeStringAsSec( + conf.get("spark.shuffle.io.connectionTimeout", defaultNetworkTimeoutS + "s")) * 1000; + return (int) defaultTimeoutMs; } /** Number of concurrent connections between two nodes for fetching data. */ @@ -68,7 +71,9 @@ public int numConnectionsPerPeer() { public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ - public int saslRTTimeoutMs() { return conf.getInt("spark.shuffle.sasl.timeout", 30) * 1000; } + public int saslRTTimeoutMs() { + return (int) JavaUtils.timeStringAsSec(conf.get("spark.shuffle.sasl.timeout", "30s")) * 1000; + } /** * Max number of times we will try IO exceptions (such as connection timeouts) per request. @@ -80,7 +85,9 @@ public int numConnectionsPerPeer() { * Time (in milliseconds) that we will wait in order to perform a retry after an IOException. * Only relevant if maxIORetries > 0. */ - public int ioRetryWaitTimeMs() { return conf.getInt("spark.shuffle.io.retryWait", 5) * 1000; } + public int ioRetryWaitTimeMs() { + return (int) JavaUtils.timeStringAsSec(conf.get("spark.shuffle.io.retryWait", "5s")) * 1000; + } /** * Minimum size of a block that we should start using memory map rather than reading in through diff --git a/pom.xml b/pom.xml index 42bd926a2fcb8..d8881c213bf07 100644 --- a/pom.xml +++ b/pom.xml @@ -159,6 +159,8 @@ 1.1.1.6 1.1.2 + ${java.home} + ${test_classpath} + ${test.java.home} true @@ -1224,6 +1227,7 @@ launched by the tests have access to the correct test-time classpath. --> ${test_classpath} + ${test.java.home} true @@ -1716,6 +1720,16 @@ + + test-java-home + + env.JAVA_HOME + + + ${env.JAVA_HOME} + + + scala-2.11 @@ -1749,5 +1763,8 @@ parquet-provided + + sparkr + diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index c2d828f982fe0..1564babefa62f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -64,6 +64,10 @@ object MimaExcludes { // SPARK-6492 Fix deadlock in SparkContext.stop() ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.org$" + "apache$spark$SparkContext$$SPARK_CONTEXT_CONSTRUCTOR_LOCK") + )++ Seq( + // SPARK-6693 add tostring with max lines and width for matrix + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.toString") ) case v if v.startsWith("1.3") => diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index d3faa551a4b14..5f51f4b58f97a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -119,7 +119,9 @@ object SparkBuild extends PomBuild { lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") lazy val sharedSettings = graphSettings ++ genjavadocSettings ++ Seq ( - javaHome := Properties.envOrNone("JAVA_HOME").map(file), + javaHome := sys.env.get("JAVA_HOME") + .orElse(sys.props.get("java.home").map { p => new File(p).getParentFile().getAbsolutePath() }) + .map(file), incOptions := incOptions.value.withNameHashing(true), retrieveManaged := true, retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", @@ -426,8 +428,10 @@ object TestSettings { fork := true, // Setting SPARK_DIST_CLASSPATH is a simple way to make sure any child processes // launched by the tests have access to the correct test-time classpath. - envVars in Test += ("SPARK_DIST_CLASSPATH" -> - (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":")), + envVars in Test ++= Map( + "SPARK_DIST_CLASSPATH" -> + (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":"), + "JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home"))), javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.port.maxRetries=100", diff --git a/project/plugins.sbt b/project/plugins.sbt index ee45b6a51905e..7096b0d3ee7de 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -19,7 +19,7 @@ addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0") addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.4") -addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.6.0") +addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.7.0") addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala index 8863f272da415..471d00bd8223f 100644 --- a/project/project/SparkPluginBuild.scala +++ b/project/project/SparkPluginBuild.scala @@ -24,20 +24,6 @@ import sbt.Keys._ * becomes available for scalastyle sbt plugin. */ object SparkPluginDef extends Build { - lazy val root = Project("plugins", file(".")) dependsOn(sparkStyle, sbtPomReader) - lazy val sparkStyle = Project("spark-style", file("spark-style"), settings = styleSettings) + lazy val root = Project("plugins", file(".")) dependsOn(sbtPomReader) lazy val sbtPomReader = uri("https://github.com/ScrapCodes/sbt-pom-reader.git#ignore_artifact_id") - - // There is actually no need to publish this artifact. - def styleSettings = Defaults.defaultSettings ++ Seq ( - name := "spark-style", - organization := "org.apache.spark", - scalaVersion := "2.10.4", - scalacOptions := Seq("-unchecked", "-deprecation"), - libraryDependencies ++= Dependencies.scalaStyle - ) - - object Dependencies { - val scalaStyle = Seq("org.scalastyle" %% "scalastyle" % "0.4.0") - } } diff --git a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala b/project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala deleted file mode 100644 index 3d43c35299555..0000000000000 --- a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -package org.apache.spark.scalastyle - -import java.util.regex.Pattern - -import org.scalastyle.{PositionError, ScalariformChecker, ScalastyleError} - -import scalariform.lexer.Token -import scalariform.parser.CompilationUnit - -class NonASCIICharacterChecker extends ScalariformChecker { - val errorKey: String = "non.ascii.character.disallowed" - - override def verify(ast: CompilationUnit): List[ScalastyleError] = { - ast.tokens.filter(hasNonAsciiChars).map(x => PositionError(x.offset)).toList - } - - private def hasNonAsciiChars(x: Token) = - x.rawText.trim.nonEmpty && !Pattern.compile( """\p{ASCII}+""", Pattern.DOTALL) - .matcher(x.text.trim).matches() - -} diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst index 15101470afc07..26ece4c2c389a 100644 --- a/python/docs/pyspark.mllib.rst +++ b/python/docs/pyspark.mllib.rst @@ -31,6 +31,13 @@ pyspark.mllib.feature module :undoc-members: :show-inheritance: +pyspark.mllib.fpm module +------------------------ + +.. automodule:: pyspark.mllib.fpm + :members: + :undoc-members: + pyspark.mllib.linalg module --------------------------- diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 0a16cbd8bff62..2a5e84a7dfdb4 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -29,11 +29,10 @@ def launch_gateway(): - SPARK_HOME = os.environ["SPARK_HOME"] - if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) else: + SPARK_HOME = os.environ["SPARK_HOME"] # Launch the Py4j gateway using Spark's run command so that we pick up the # proper classpath and settings from spark-env.sh on_windows = platform.system() == "Windows" diff --git a/python/pyspark/join.py b/python/pyspark/join.py index efc1ef9396412..c3491defb2b29 100644 --- a/python/pyspark/join.py +++ b/python/pyspark/join.py @@ -48,7 +48,7 @@ def dispatch(seq): vbuf.append(v) elif n == 2: wbuf.append(v) - return [(v, w) for v in vbuf for w in wbuf] + return ((v, w) for v in vbuf for w in wbuf) return _do_python_join(rdd, other, numPartitions, dispatch) @@ -62,7 +62,7 @@ def dispatch(seq): wbuf.append(v) if not vbuf: vbuf.append(None) - return [(v, w) for v in vbuf for w in wbuf] + return ((v, w) for v in vbuf for w in wbuf) return _do_python_join(rdd, other, numPartitions, dispatch) @@ -76,7 +76,7 @@ def dispatch(seq): wbuf.append(v) if not wbuf: wbuf.append(None) - return [(v, w) for v in vbuf for w in wbuf] + return ((v, w) for v in vbuf for w in wbuf) return _do_python_join(rdd, other, numPartitions, dispatch) @@ -104,8 +104,9 @@ def make_mapper(i): rdd_len = len(vrdds) def dispatch(seq): - bufs = [[] for i in range(rdd_len)] - for (n, v) in seq: + bufs = [[] for _ in range(rdd_len)] + for n, v in seq: bufs[n].append(v) - return tuple(map(ResultIterable, bufs)) + return tuple(ResultIterable(vs) for vs in bufs) + return union_vrdds.groupByKey(numPartitions).mapValues(dispatch) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 4ff7463498cce..7f42de531f3b4 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -91,9 +91,9 @@ class LogisticRegressionModel(JavaModel): # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.feature tests") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) globs['sc'] = sc - globs['sqlCtx'] = sqlCtx + globs['sqlContext'] = sqlContext (failure_count, test_count) = doctest.testmod( globs=globs, optionflags=doctest.ELLIPSIS) sc.stop() diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 433b4fb5d22bf..1cfcd019dfb18 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -117,9 +117,9 @@ def setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output"): # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.feature tests") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) globs['sc'] = sc - globs['sqlCtx'] = sqlCtx + globs['sqlContext'] = sqlContext (failure_count, test_count) = doctest.testmod( globs=globs, optionflags=doctest.ELLIPSIS) sc.stop() diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index 6449800d9c120..f2ef573fe9f6f 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -25,7 +25,7 @@ if numpy.version.version < '1.4': raise Exception("MLlib requires NumPy 1.4+") -__all__ = ['classification', 'clustering', 'feature', 'linalg', 'random', +__all__ = ['classification', 'clustering', 'feature', 'fpm', 'linalg', 'random', 'recommendation', 'regression', 'stat', 'tree', 'util'] import sys diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 3cda1205e1391..8be819aceec24 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -132,6 +132,22 @@ def transform(self, vector): """ return JavaVectorTransformer.transform(self, vector) + def setWithMean(self, withMean): + """ + Setter of the boolean which decides + whether it uses mean or not + """ + self.call("setWithMean", withMean) + return self + + def setWithStd(self, withStd): + """ + Setter of the boolean which decides + whether it uses std or not + """ + self.call("setWithStd", withStd) + return self + class StandardScaler(object): """ diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py new file mode 100644 index 0000000000000..3aa6d79d7093c --- /dev/null +++ b/python/pyspark/mllib/fpm.py @@ -0,0 +1,81 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import SparkContext +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc + +__all__ = ['FPGrowth', 'FPGrowthModel'] + + +@inherit_doc +class FPGrowthModel(JavaModelWrapper): + + """ + .. note:: Experimental + + A FP-Growth model for mining frequent itemsets + using the Parallel FP-Growth algorithm. + + >>> data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] + >>> rdd = sc.parallelize(data, 2) + >>> model = FPGrowth.train(rdd, 0.6, 2) + >>> sorted(model.freqItemsets().collect()) + [([u'a'], 4), ([u'c'], 3), ([u'c', u'a'], 3)] + """ + + def freqItemsets(self): + """ + Get the frequent itemsets of this model + """ + return self.call("getFreqItemsets") + + +class FPGrowth(object): + """ + .. note:: Experimental + + A Parallel FP-growth algorithm to mine frequent itemsets. + """ + + @classmethod + def train(cls, data, minSupport=0.3, numPartitions=-1): + """ + Computes an FP-Growth model that contains frequent itemsets. + :param data: The input data set, each element + contains a transaction. + :param minSupport: The minimal support level + (default: `0.3`). + :param numPartitions: The number of partitions used by parallel + FP-growth (default: same as input data). + """ + model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartitions)) + return FPGrowthModel(model) + + +def _test(): + import doctest + import pyspark.mllib.fpm + globs = pyspark.mllib.fpm.__dict__.copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest') + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 51c1490b1618d..a80320c52d1d0 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -640,6 +640,15 @@ def toArray(self): """ raise NotImplementedError + @staticmethod + def _convert_to_array(array_like, dtype): + """ + Convert Matrix attributes which are array-like or buffer to array. + """ + if isinstance(array_like, basestring): + return np.frombuffer(array_like, dtype=dtype) + return np.asarray(array_like, dtype=dtype) + class DenseMatrix(Matrix): """ @@ -647,13 +656,8 @@ class DenseMatrix(Matrix): """ def __init__(self, numRows, numCols, values): Matrix.__init__(self, numRows, numCols) - if isinstance(values, basestring): - values = np.frombuffer(values, dtype=np.float64) - elif not isinstance(values, np.ndarray): - values = np.array(values, dtype=np.float64) + values = self._convert_to_array(values, np.float64) assert len(values) == numRows * numCols - if values.dtype != np.float64: - values.astype(np.float64) self.values = values def __reduce__(self): @@ -670,6 +674,17 @@ def toArray(self): """ return self.values.reshape((self.numRows, self.numCols), order='F') + def toSparse(self): + """Convert to SparseMatrix""" + indices = np.nonzero(self.values)[0] + colCounts = np.bincount(indices / self.numRows) + colPtrs = np.cumsum(np.hstack( + (0, colCounts, np.zeros(self.numCols - colCounts.size)))) + values = self.values[indices] + rowIndices = indices % self.numRows + + return SparseMatrix(self.numRows, self.numCols, colPtrs, rowIndices, values) + def __getitem__(self, indices): i, j = indices if i < 0 or i >= self.numRows: @@ -687,6 +702,82 @@ def __eq__(self, other): all(self.values == other.values)) +class SparseMatrix(Matrix): + """Sparse Matrix stored in CSC format.""" + def __init__(self, numRows, numCols, colPtrs, rowIndices, values, + isTransposed=False): + Matrix.__init__(self, numRows, numCols) + self.isTransposed = isTransposed + self.colPtrs = self._convert_to_array(colPtrs, np.int32) + self.rowIndices = self._convert_to_array(rowIndices, np.int32) + self.values = self._convert_to_array(values, np.float64) + + if self.isTransposed: + if self.colPtrs.size != numRows + 1: + raise ValueError("Expected colPtrs of size %d, got %d." + % (numRows + 1, self.colPtrs.size)) + else: + if self.colPtrs.size != numCols + 1: + raise ValueError("Expected colPtrs of size %d, got %d." + % (numCols + 1, self.colPtrs.size)) + if self.rowIndices.size != self.values.size: + raise ValueError("Expected rowIndices of length %d, got %d." + % (self.rowIndices.size, self.values.size)) + + def __reduce__(self): + return SparseMatrix, ( + self.numRows, self.numCols, self.colPtrs.tostring(), + self.rowIndices.tostring(), self.values.tostring(), + self.isTransposed) + + def __getitem__(self, indices): + i, j = indices + if i < 0 or i >= self.numRows: + raise ValueError("Row index %d is out of range [0, %d)" + % (i, self.numRows)) + if j < 0 or j >= self.numCols: + raise ValueError("Column index %d is out of range [0, %d)" + % (j, self.numCols)) + + # If a CSR matrix is given, then the row index should be searched + # for in ColPtrs, and the column index should be searched for in the + # corresponding slice obtained from rowIndices. + if self.isTransposed: + j, i = i, j + + colStart = self.colPtrs[j] + colEnd = self.colPtrs[j + 1] + nz = self.rowIndices[colStart: colEnd] + ind = np.searchsorted(nz, i) + colStart + if ind < colEnd and self.rowIndices[ind] == i: + return self.values[ind] + else: + return 0.0 + + def toArray(self): + """ + Return an numpy.ndarray + """ + A = np.zeros((self.numRows, self.numCols), dtype=np.float64, order='F') + for k in xrange(self.colPtrs.size - 1): + startptr = self.colPtrs[k] + endptr = self.colPtrs[k + 1] + if self.isTransposed: + A[k, self.rowIndices[startptr:endptr]] = self.values[startptr:endptr] + else: + A[self.rowIndices[startptr:endptr], k] = self.values[startptr:endptr] + return A + + def toDense(self): + densevals = np.reshape( + self.toArray(), (self.numRows * self.numCols), order='F') + return DenseMatrix(self.numRows, self.numCols, densevals) + + # TODO: More efficient implementation: + def __eq__(self, other): + return np.all(self.toArray == other.toArray) + + class Matrices(object): @staticmethod def dense(numRows, numCols, values): @@ -695,6 +786,13 @@ def dense(numRows, numCols, values): """ return DenseMatrix(numRows, numCols, values) + @staticmethod + def sparse(numRows, numCols, colPtrs, rowIndices, values): + """ + Create a SparseMatrix + """ + return SparseMatrix(numRows, numCols, colPtrs, rowIndices, values) + def _test(): import doctest diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 47dad7d12e4e4..8eaddcf8b9b5e 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -24,7 +24,7 @@ import tempfile import array as pyarray -from numpy import array, array_equal +from numpy import array, array_equal, zeros from py4j.protocol import Py4JJavaError if sys.version_info[:2] <= (2, 6): @@ -38,12 +38,13 @@ from pyspark.mllib.common import _to_java_object_rdd from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ - DenseMatrix, Vectors, Matrices + DenseMatrix, SparseMatrix, Vectors, Matrices from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics from pyspark.mllib.feature import Word2Vec from pyspark.mllib.feature import IDF +from pyspark.mllib.feature import StandardScaler from pyspark.serializers import PickleSerializer from pyspark.sql import SQLContext from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase @@ -144,6 +145,54 @@ def test_matrix_indexing(self): for j in range(2): self.assertEquals(mat[i, j], expected[i][j]) + def test_sparse_matrix(self): + # Test sparse matrix creation. + sm1 = SparseMatrix( + 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) + self.assertEquals(sm1.numRows, 3) + self.assertEquals(sm1.numCols, 4) + self.assertEquals(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) + self.assertEquals(sm1.rowIndices.tolist(), [1, 2, 1, 2]) + self.assertEquals(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) + + # Test indexing + expected = [ + [0, 0, 0, 0], + [1, 0, 4, 0], + [2, 0, 5, 0]] + + for i in range(3): + for j in range(4): + self.assertEquals(expected[i][j], sm1[i, j]) + self.assertTrue(array_equal(sm1.toArray(), expected)) + + # Test conversion to dense and sparse. + smnew = sm1.toDense().toSparse() + self.assertEquals(sm1.numRows, smnew.numRows) + self.assertEquals(sm1.numCols, smnew.numCols) + self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) + self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) + self.assertTrue(array_equal(sm1.values, smnew.values)) + + sm1t = SparseMatrix( + 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], + isTransposed=True) + self.assertEquals(sm1t.numRows, 3) + self.assertEquals(sm1t.numCols, 4) + self.assertEquals(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) + self.assertEquals(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) + self.assertEquals(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) + + expected = [ + [3, 2, 0, 0], + [0, 0, 4, 0], + [9, 0, 8, 0]] + + for i in range(3): + for j in range(4): + self.assertEquals(expected[i][j], sm1t[i, j]) + self.assertTrue(array_equal(sm1t.toArray(), expected)) + class ListTests(PySparkTestCase): @@ -363,6 +412,13 @@ def test_col_norms(self): self.assertEqual(10, len(summary.normL1())) self.assertEqual(10, len(summary.normL2())) + data2 = self.sc.parallelize(xrange(10)).map(lambda x: Vectors.dense(x)) + summary2 = Statistics.colStats(data2) + self.assertEqual(array([45.0]), summary2.normL1()) + import math + expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, xrange(10)))) + self.assertTrue(math.fabs(summary2.normL2()[0] - expectedNormL2) < 1e-14) + class VectorUDTTests(PySparkTestCase): @@ -690,6 +746,29 @@ def test_word2vec_get_vectors(self): model = Word2Vec().fit(self.sc.parallelize(data)) self.assertEquals(len(model.getVectors()), 3) + +class StandardScalerTests(PySparkTestCase): + def test_model_setters(self): + data = [ + [1.0, 2.0, 3.0], + [2.0, 3.0, 4.0], + [3.0, 4.0, 5.0] + ] + model = StandardScaler().fit(self.sc.parallelize(data)) + self.assertIsNotNone(model.setWithMean(True)) + self.assertIsNotNone(model.setWithStd(True)) + self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([-1.0, -1.0, -1.0])) + + def test_model_transform(self): + data = [ + [1.0, 2.0, 3.0], + [2.0, 3.0, 4.0], + [3.0, 4.0, 5.0] + ] + model = StandardScaler().fit(self.sc.parallelize(data)) + self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0])) + + if __name__ == "__main__": if not _have_scipy: print "NOTE: Skipping SciPy tests as it does not seem to be installed" diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 2d05611321ed6..c9ac95d117574 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -41,7 +41,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ - get_used_memory, ExternalSorter + get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync from py4j.java_collections import ListConverter, MapConverter @@ -573,8 +573,8 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): if numPartitions is None: numPartitions = self._defaultReducePartitions() - spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true') - memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) + spill = self._can_spill() + memory = self._memory_limit() serializer = self._jrdd_deserializer def sortPartition(iterator): @@ -595,7 +595,7 @@ def sortPartition(iterator): maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner fraction = min(maxSampleSize / max(rddSize, 1), 1.0) samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect() - samples = sorted(samples, reverse=(not ascending), key=keyfunc) + samples = sorted(samples, key=keyfunc) # we have numPartitions many parts but one of the them has # an implicit boundary @@ -1699,10 +1699,8 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, numPartitions = self._defaultReducePartitions() serializer = self.ctx.serializer - spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() - == 'true') - memory = _parse_memory(self.ctx._conf.get( - "spark.python.worker.memory", "512m")) + spill = self._can_spill() + memory = self._memory_limit() agg = Aggregator(createCombiner, mergeValue, mergeCombiners) def combineLocally(iterator): @@ -1755,21 +1753,28 @@ def createZero(): return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions) + def _can_spill(self): + return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true" + + def _memory_limit(self): + return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) + # TODO: support variant with custom partitioner def groupByKey(self, numPartitions=None): """ Group the values for each key in the RDD into a single sequence. - Hash-partitions the resulting RDD with into numPartitions partitions. + Hash-partitions the resulting RDD with numPartitions partitions. Note: If you are grouping in order to perform an aggregation (such as a sum or average) over each key, using reduceByKey or aggregateByKey will provide much better performance. >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect())) + >>> sorted(x.groupByKey().mapValues(len).collect()) + [('a', 2), ('b', 1)] + >>> sorted(x.groupByKey().mapValues(list).collect()) [('a', [1, 1]), ('b', [1])] """ - def createCombiner(x): return [x] @@ -1781,8 +1786,27 @@ def mergeCombiners(a, b): a.extend(b) return a - return self.combineByKey(createCombiner, mergeValue, mergeCombiners, - numPartitions).mapValues(lambda x: ResultIterable(x)) + spill = self._can_spill() + memory = self._memory_limit() + serializer = self._jrdd_deserializer + agg = Aggregator(createCombiner, mergeValue, mergeCombiners) + + def combine(iterator): + merger = ExternalMerger(agg, memory * 0.9, serializer) \ + if spill else InMemoryMerger(agg) + merger.mergeValues(iterator) + return merger.iteritems() + + locally_combined = self.mapPartitions(combine, preservesPartitioning=True) + shuffled = locally_combined.partitionBy(numPartitions) + + def groupByKey(it): + merger = ExternalGroupBy(agg, memory, serializer)\ + if spill else InMemoryMerger(agg) + merger.mergeCombiners(it) + return merger.iteritems() + + return shuffled.mapPartitions(groupByKey, True).mapValues(ResultIterable) def flatMapValues(self, f): """ @@ -2209,7 +2233,7 @@ def toLocalIterator(self): def _prepare_for_python_RDD(sc, command, obj=None): # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() - pickled_command = ser.dumps(command) + pickled_command = ser.dumps((command, sys.version_info[:2])) if len(pickled_command) > (1 << 20): # 1M broadcast = sc.broadcast(pickled_command) pickled_command = ser.dumps(broadcast) diff --git a/python/pyspark/resultiterable.py b/python/pyspark/resultiterable.py index ef04c82866e6c..1ab5ce14c3531 100644 --- a/python/pyspark/resultiterable.py +++ b/python/pyspark/resultiterable.py @@ -15,15 +15,16 @@ # limitations under the License. # -__all__ = ["ResultIterable"] - import collections +__all__ = ["ResultIterable"] + class ResultIterable(collections.Iterable): """ - A special result iterable. This is used because the standard iterator can not be pickled + A special result iterable. This is used because the standard + iterator can not be pickled """ def __init__(self, data): diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 0ffb41d02f6f6..4afa82f4b2973 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -220,6 +220,29 @@ def __repr__(self): return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize) +class FlattenedValuesSerializer(BatchedSerializer): + + """ + Serializes a stream of list of pairs, split the list of values + which contain more than a certain number of objects to make them + have similar sizes. + """ + def __init__(self, serializer, batchSize=10): + BatchedSerializer.__init__(self, serializer, batchSize) + + def _batched(self, iterator): + n = self.batchSize + for key, values in iterator: + for i in xrange(0, len(values), n): + yield key, values[i:i + n] + + def load_stream(self, stream): + return self.serializer.load_stream(stream) + + def __repr__(self): + return "FlattenedValuesSerializer(%d)" % self.batchSize + + class AutoBatchedSerializer(BatchedSerializer): """ Choose the size of batch automatically based on the size of object @@ -251,7 +274,7 @@ def __eq__(self, other): return (isinstance(other, AutoBatchedSerializer) and other.serializer == self.serializer and other.bestSize == self.bestSize) - def __str__(self): + def __repr__(self): return "AutoBatchedSerializer(%s)" % str(self.serializer) diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 1a02fece9c5a5..81aa970a32f76 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -53,9 +53,9 @@ try: # Try to access HiveConf, it will raise exception if Hive is not added sc._jvm.org.apache.hadoop.hive.conf.HiveConf() - sqlCtx = HiveContext(sc) + sqlCtx = sqlContext = HiveContext(sc) except py4j.protocol.Py4JError: - sqlCtx = SQLContext(sc) + sqlCtx = sqlContext = SQLContext(sc) print("""Welcome to ____ __ @@ -68,7 +68,7 @@ platform.python_version(), platform.python_build()[0], platform.python_build()[1])) -print("SparkContext available as sc, %s available as sqlCtx." % sqlCtx.__class__.__name__) +print("SparkContext available as sc, %s available as sqlContext." % sqlContext.__class__.__name__) if add_files is not None: print("Warning: ADD_FILES environment variable is deprecated, use --py-files argument instead") diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 10a7ccd502000..8a6fc627eb383 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -16,28 +16,35 @@ # import os -import sys import platform import shutil import warnings import gc import itertools +import operator import random import pyspark.heapq3 as heapq -from pyspark.serializers import AutoBatchedSerializer, PickleSerializer +from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ + CompressedSerializer, AutoBatchedSerializer + try: import psutil + process = None + def get_used_memory(): """ Return the used memory in MB """ - process = psutil.Process(os.getpid()) + global process + if process is None or process._pid != os.getpid(): + process = psutil.Process(os.getpid()) if hasattr(process, "memory_info"): info = process.memory_info() else: info = process.get_memory_info() return info.rss >> 20 + except ImportError: def get_used_memory(): @@ -46,6 +53,7 @@ def get_used_memory(): for line in open('/proc/self/status'): if line.startswith('VmRSS:'): return int(line.split()[1]) >> 10 + else: warnings.warn("Please install psutil to have better " "support with spilling") @@ -54,6 +62,7 @@ def get_used_memory(): rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss return rss >> 20 # TODO: support windows + return 0 @@ -148,10 +157,16 @@ def mergeCombiners(self, iterator): d[k] = comb(d[k], v) if k in d else v def iteritems(self): - """ Return the merged items ad iterator """ + """ Return the merged items as iterator """ return self.data.iteritems() +def _compressed_serializer(self, serializer=None): + # always use PickleSerializer to simplify implementation + ser = PickleSerializer() + return AutoBatchedSerializer(CompressedSerializer(ser)) + + class ExternalMerger(Merger): """ @@ -173,7 +188,7 @@ class ExternalMerger(Merger): dict. Repeat this again until combine all the items. - Before return any items, it will load each partition and - combine them seperately. Yield them before loading next + combine them separately. Yield them before loading next partition. - During loading a partition, if the memory goes over limit, @@ -182,7 +197,7 @@ class ExternalMerger(Merger): `data` and `pdata` are used to hold the merged items in memory. At first, all the data are merged into `data`. Once the used - memory goes over limit, the items in `data` are dumped indo + memory goes over limit, the items in `data` are dumped into disks, `data` will be cleared, all rest of items will be merged into `pdata` and then dumped into disks. Before returning, all the items in `pdata` will be dumped into disks. @@ -193,16 +208,16 @@ class ExternalMerger(Merger): >>> agg = SimpleAggregator(lambda x, y: x + y) >>> merger = ExternalMerger(agg, 10) >>> N = 10000 - >>> merger.mergeValues(zip(xrange(N), xrange(N)) * 10) + >>> merger.mergeValues(zip(xrange(N), xrange(N))) >>> assert merger.spills > 0 >>> sum(v for k,v in merger.iteritems()) - 499950000 + 49995000 >>> merger = ExternalMerger(agg, 10) - >>> merger.mergeCombiners(zip(xrange(N), xrange(N)) * 10) + >>> merger.mergeCombiners(zip(xrange(N), xrange(N))) >>> assert merger.spills > 0 >>> sum(v for k,v in merger.iteritems()) - 499950000 + 49995000 """ # the max total partitions created recursively @@ -212,8 +227,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, localdirs=None, scale=1, partitions=59, batch=1000): Merger.__init__(self, aggregator) self.memory_limit = memory_limit - # default serializer is only used for tests - self.serializer = serializer or AutoBatchedSerializer(PickleSerializer()) + self.serializer = _compressed_serializer(serializer) self.localdirs = localdirs or _get_local_dirs(str(id(self))) # number of partitions when spill data into disks self.partitions = partitions @@ -221,7 +235,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, self.batch = batch # scale is used to scale down the hash of key for recursive hash map self.scale = scale - # unpartitioned merged data + # un-partitioned merged data self.data = {} # partitioned merged data, list of dicts self.pdata = [] @@ -244,72 +258,63 @@ def _next_limit(self): def mergeValues(self, iterator): """ Combine the items by creator and combiner """ - iterator = iter(iterator) # speedup attribute lookup creator, comb = self.agg.createCombiner, self.agg.mergeValue - d, c, batch = self.data, 0, self.batch + c, data, pdata, hfun, batch = 0, self.data, self.pdata, self._partition, self.batch + limit = self.memory_limit for k, v in iterator: + d = pdata[hfun(k)] if pdata else data d[k] = comb(d[k], v) if k in d else creator(v) c += 1 - if c % batch == 0 and get_used_memory() > self.memory_limit: - self._spill() - self._partitioned_mergeValues(iterator, self._next_limit()) - break + if c >= batch: + if get_used_memory() >= limit: + self._spill() + limit = self._next_limit() + batch /= 2 + c = 0 + else: + batch *= 1.5 + + if get_used_memory() >= limit: + self._spill() def _partition(self, key): """ Return the partition for key """ return hash((key, self._seed)) % self.partitions - def _partitioned_mergeValues(self, iterator, limit=0): - """ Partition the items by key, then combine them """ - # speedup attribute lookup - creator, comb = self.agg.createCombiner, self.agg.mergeValue - c, pdata, hfun, batch = 0, self.pdata, self._partition, self.batch - - for k, v in iterator: - d = pdata[hfun(k)] - d[k] = comb(d[k], v) if k in d else creator(v) - if not limit: - continue - - c += 1 - if c % batch == 0 and get_used_memory() > limit: - self._spill() - limit = self._next_limit() + def _object_size(self, obj): + """ How much of memory for this obj, assume that all the objects + consume similar bytes of memory + """ + return 1 - def mergeCombiners(self, iterator, check=True): + def mergeCombiners(self, iterator, limit=None): """ Merge (K,V) pair by mergeCombiner """ - iterator = iter(iterator) + if limit is None: + limit = self.memory_limit # speedup attribute lookup - d, comb, batch = self.data, self.agg.mergeCombiners, self.batch - c = 0 - for k, v in iterator: - d[k] = comb(d[k], v) if k in d else v - if not check: - continue - - c += 1 - if c % batch == 0 and get_used_memory() > self.memory_limit: - self._spill() - self._partitioned_mergeCombiners(iterator, self._next_limit()) - break - - def _partitioned_mergeCombiners(self, iterator, limit=0): - """ Partition the items by key, then merge them """ - comb, pdata = self.agg.mergeCombiners, self.pdata - c, hfun = 0, self._partition + comb, hfun, objsize = self.agg.mergeCombiners, self._partition, self._object_size + c, data, pdata, batch = 0, self.data, self.pdata, self.batch for k, v in iterator: - d = pdata[hfun(k)] + d = pdata[hfun(k)] if pdata else data d[k] = comb(d[k], v) if k in d else v if not limit: continue - c += 1 - if c % self.batch == 0 and get_used_memory() > limit: - self._spill() - limit = self._next_limit() + c += objsize(v) + if c > batch: + if get_used_memory() > limit: + self._spill() + limit = self._next_limit() + batch /= 2 + c = 0 + else: + batch *= 1.5 + + if limit and get_used_memory() >= limit: + self._spill() def _spill(self): """ @@ -335,7 +340,7 @@ def _spill(self): for k, v in self.data.iteritems(): h = self._partition(k) - # put one item in batch, make it compatitable with load_stream + # put one item in batch, make it compatible with load_stream # it will increase the memory if dump them in batch self.serializer.dump_stream([(k, v)], streams[h]) @@ -344,7 +349,7 @@ def _spill(self): s.close() self.data.clear() - self.pdata = [{} for i in range(self.partitions)] + self.pdata.extend([{} for i in range(self.partitions)]) else: for i in range(self.partitions): @@ -370,29 +375,12 @@ def _external_items(self): assert not self.data if any(self.pdata): self._spill() - hard_limit = self._next_limit() + # disable partitioning and spilling when merge combiners from disk + self.pdata = [] try: for i in range(self.partitions): - self.data = {} - for j in range(self.spills): - path = self._get_spill_dir(j) - p = os.path.join(path, str(i)) - # do not check memory during merging - self.mergeCombiners(self.serializer.load_stream(open(p)), - False) - - # limit the total partitions - if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS - and j < self.spills - 1 - and get_used_memory() > hard_limit): - self.data.clear() # will read from disk again - gc.collect() # release the memory as much as possible - for v in self._recursive_merged_items(i): - yield v - return - - for v in self.data.iteritems(): + for v in self._merged_items(i): yield v self.data.clear() @@ -400,53 +388,56 @@ def _external_items(self): for j in range(self.spills): path = self._get_spill_dir(j) os.remove(os.path.join(path, str(i))) - finally: self._cleanup() - def _cleanup(self): - """ Clean up all the files in disks """ - for d in self.localdirs: - shutil.rmtree(d, True) + def _merged_items(self, index): + self.data = {} + limit = self._next_limit() + for j in range(self.spills): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + # do not check memory during merging + self.mergeCombiners(self.serializer.load_stream(open(p)), 0) + + # limit the total partitions + if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS + and j < self.spills - 1 + and get_used_memory() > limit): + self.data.clear() # will read from disk again + gc.collect() # release the memory as much as possible + return self._recursive_merged_items(index) - def _recursive_merged_items(self, start): + return self.data.iteritems() + + def _recursive_merged_items(self, index): """ merge the partitioned items and return the as iterator If one partition can not be fit in memory, then them will be partitioned and merged recursively. """ - # make sure all the data are dumps into disks. - assert not self.data - if any(self.pdata): - self._spill() - assert self.spills > 0 - - for i in range(start, self.partitions): - subdirs = [os.path.join(d, "parts", str(i)) - for d in self.localdirs] - m = ExternalMerger(self.agg, self.memory_limit, self.serializer, - subdirs, self.scale * self.partitions, self.partitions) - m.pdata = [{} for _ in range(self.partitions)] - limit = self._next_limit() - - for j in range(self.spills): - path = self._get_spill_dir(j) - p = os.path.join(path, str(i)) - m._partitioned_mergeCombiners( - self.serializer.load_stream(open(p))) - - if get_used_memory() > limit: - m._spill() - limit = self._next_limit() + subdirs = [os.path.join(d, "parts", str(index)) for d in self.localdirs] + m = ExternalMerger(self.agg, self.memory_limit, self.serializer, subdirs, + self.scale * self.partitions, self.partitions, self.batch) + m.pdata = [{} for _ in range(self.partitions)] + limit = self._next_limit() + + for j in range(self.spills): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + m.mergeCombiners(self.serializer.load_stream(open(p)), 0) + + if get_used_memory() > limit: + m._spill() + limit = self._next_limit() - for v in m._external_items(): - yield v + return m._external_items() - # remove the merged partition - for j in range(self.spills): - path = self._get_spill_dir(j) - os.remove(os.path.join(path, str(i))) + def _cleanup(self): + """ Clean up all the files in disks """ + for d in self.localdirs: + shutil.rmtree(d, True) class ExternalSorter(object): @@ -457,6 +448,7 @@ class ExternalSorter(object): The spilling will only happen when the used memory goes above the limit. + >>> sorter = ExternalSorter(1) # 1M >>> import random >>> l = range(1024) @@ -469,7 +461,7 @@ class ExternalSorter(object): def __init__(self, memory_limit, serializer=None): self.memory_limit = memory_limit self.local_dirs = _get_local_dirs("sort") - self.serializer = serializer or AutoBatchedSerializer(PickleSerializer()) + self.serializer = _compressed_serializer(serializer) def _get_path(self, n): """ Choose one directory for spill by number n """ @@ -515,6 +507,7 @@ def sorted(self, iterator, key=None, reverse=False): limit = self._next_limit() MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 DiskBytesSpilled += os.path.getsize(path) + os.unlink(path) # data will be deleted after close elif not chunks: batch = min(batch * 2, 10000) @@ -529,6 +522,310 @@ def sorted(self, iterator, key=None, reverse=False): return heapq.merge(chunks, key=key, reverse=reverse) +class ExternalList(object): + """ + ExternalList can have many items which cannot be hold in memory in + the same time. + + >>> l = ExternalList(range(100)) + >>> len(l) + 100 + >>> l.append(10) + >>> len(l) + 101 + >>> for i in range(20240): + ... l.append(i) + >>> len(l) + 20341 + >>> import pickle + >>> l2 = pickle.loads(pickle.dumps(l)) + >>> len(l2) + 20341 + >>> list(l2)[100] + 10 + """ + LIMIT = 10240 + + def __init__(self, values): + self.values = values + self.count = len(values) + self._file = None + self._ser = None + + def __getstate__(self): + if self._file is not None: + self._file.flush() + f = os.fdopen(os.dup(self._file.fileno())) + f.seek(0) + serialized = f.read() + else: + serialized = '' + return self.values, self.count, serialized + + def __setstate__(self, item): + self.values, self.count, serialized = item + if serialized: + self._open_file() + self._file.write(serialized) + else: + self._file = None + self._ser = None + + def __iter__(self): + if self._file is not None: + self._file.flush() + # read all items from disks first + with os.fdopen(os.dup(self._file.fileno()), 'r') as f: + f.seek(0) + for v in self._ser.load_stream(f): + yield v + + for v in self.values: + yield v + + def __len__(self): + return self.count + + def append(self, value): + self.values.append(value) + self.count += 1 + # dump them into disk if the key is huge + if len(self.values) >= self.LIMIT: + self._spill() + + def _open_file(self): + dirs = _get_local_dirs("objects") + d = dirs[id(self) % len(dirs)] + if not os.path.exists(d): + os.makedirs(d) + p = os.path.join(d, str(id)) + self._file = open(p, "w+", 65536) + self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024) + os.unlink(p) + + def _spill(self): + """ dump the values into disk """ + global MemoryBytesSpilled, DiskBytesSpilled + if self._file is None: + self._open_file() + + used_memory = get_used_memory() + pos = self._file.tell() + self._ser.dump_stream(self.values, self._file) + self.values = [] + gc.collect() + DiskBytesSpilled += self._file.tell() - pos + MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + + +class ExternalListOfList(ExternalList): + """ + An external list for list. + + >>> l = ExternalListOfList([[i, i] for i in range(100)]) + >>> len(l) + 200 + >>> l.append(range(10)) + >>> len(l) + 210 + >>> len(list(l)) + 210 + """ + + def __init__(self, values): + ExternalList.__init__(self, values) + self.count = sum(len(i) for i in values) + + def append(self, value): + ExternalList.append(self, value) + # already counted 1 in ExternalList.append + self.count += len(value) - 1 + + def __iter__(self): + for values in ExternalList.__iter__(self): + for v in values: + yield v + + +class GroupByKey(object): + """ + Group a sorted iterator as [(k1, it1), (k2, it2), ...] + + >>> k = [i/3 for i in range(6)] + >>> v = [[i] for i in range(6)] + >>> g = GroupByKey(iter(zip(k, v))) + >>> [(k, list(it)) for k, it in g] + [(0, [0, 1, 2]), (1, [3, 4, 5])] + """ + + def __init__(self, iterator): + self.iterator = iter(iterator) + self.next_item = None + + def __iter__(self): + return self + + def next(self): + key, value = self.next_item if self.next_item else next(self.iterator) + values = ExternalListOfList([value]) + try: + while True: + k, v = next(self.iterator) + if k != key: + self.next_item = (k, v) + break + values.append(v) + except StopIteration: + self.next_item = None + return key, values + + +class ExternalGroupBy(ExternalMerger): + + """ + Group by the items by key. If any partition of them can not been + hold in memory, it will do sort based group by. + + This class works as follows: + + - It repeatedly group the items by key and save them in one dict in + memory. + + - When the used memory goes above memory limit, it will split + the combined data into partitions by hash code, dump them + into disk, one file per partition. If the number of keys + in one partitions is smaller than 1000, it will sort them + by key before dumping into disk. + + - Then it goes through the rest of the iterator, group items + by key into different dict by hash. Until the used memory goes over + memory limit, it dump all the dicts into disks, one file per + dict. Repeat this again until combine all the items. It + also will try to sort the items by key in each partition + before dumping into disks. + + - It will yield the grouped items partitions by partitions. + If the data in one partitions can be hold in memory, then it + will load and combine them in memory and yield. + + - If the dataset in one partition cannot be hold in memory, + it will sort them first. If all the files are already sorted, + it merge them by heap.merge(), so it will do external sort + for all the files. + + - After sorting, `GroupByKey` class will put all the continuous + items with the same key as a group, yield the values as + an iterator. + """ + SORT_KEY_LIMIT = 1000 + + def flattened_serializer(self): + assert isinstance(self.serializer, BatchedSerializer) + ser = self.serializer + return FlattenedValuesSerializer(ser, 20) + + def _object_size(self, obj): + return len(obj) + + def _spill(self): + """ + dump already partitioned data into disks. + """ + global MemoryBytesSpilled, DiskBytesSpilled + path = self._get_spill_dir(self.spills) + if not os.path.exists(path): + os.makedirs(path) + + used_memory = get_used_memory() + if not self.pdata: + # The data has not been partitioned, it will iterator the + # data once, write them into different files, has no + # additional memory. It only called when the memory goes + # above limit at the first time. + + # open all the files for writing + streams = [open(os.path.join(path, str(i)), 'w') + for i in range(self.partitions)] + + # If the number of keys is small, then the overhead of sort is small + # sort them before dumping into disks + self._sorted = len(self.data) < self.SORT_KEY_LIMIT + if self._sorted: + self.serializer = self.flattened_serializer() + for k in sorted(self.data.keys()): + h = self._partition(k) + self.serializer.dump_stream([(k, self.data[k])], streams[h]) + else: + for k, v in self.data.iteritems(): + h = self._partition(k) + self.serializer.dump_stream([(k, v)], streams[h]) + + for s in streams: + DiskBytesSpilled += s.tell() + s.close() + + self.data.clear() + # self.pdata is cached in `mergeValues` and `mergeCombiners` + self.pdata.extend([{} for i in range(self.partitions)]) + + else: + for i in range(self.partitions): + p = os.path.join(path, str(i)) + with open(p, "w") as f: + # dump items in batch + if self._sorted: + # sort by key only (stable) + sorted_items = sorted(self.pdata[i].iteritems(), key=operator.itemgetter(0)) + self.serializer.dump_stream(sorted_items, f) + else: + self.serializer.dump_stream(self.pdata[i].iteritems(), f) + self.pdata[i].clear() + DiskBytesSpilled += os.path.getsize(p) + + self.spills += 1 + gc.collect() # release the memory as much as possible + MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + + def _merged_items(self, index): + size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index))) + for j in range(self.spills)) + # if the memory can not hold all the partition, + # then use sort based merge. Because of compression, + # the data on disks will be much smaller than needed memory + if (size >> 20) >= self.memory_limit / 10: + return self._merge_sorted_items(index) + + self.data = {} + for j in range(self.spills): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + # do not check memory during merging + self.mergeCombiners(self.serializer.load_stream(open(p)), 0) + return self.data.iteritems() + + def _merge_sorted_items(self, index): + """ load a partition from disk, then sort and group by key """ + def load_partition(j): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + return self.serializer.load_stream(open(p, 'r', 65536)) + + disk_items = [load_partition(j) for j in range(self.spills)] + + if self._sorted: + # all the partitions are already sorted + sorted_items = heapq.merge(disk_items, key=operator.itemgetter(0)) + + else: + # Flatten the combined values, so it will not consume huge + # memory during merging sort. + ser = self.flattened_serializer() + sorter = ExternalSorter(self.memory_limit, ser) + sorted_items = sorter.sorted(itertools.chain(*disk_items), + key=operator.itemgetter(0)) + return ((k, vs) for k, vs in GroupByKey(sorted_items)) + + if __name__ == "__main__": import doctest doctest.testmod() diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index c2d81ba804110..e8529a8f8e3a4 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -37,12 +37,12 @@ __all__ = ["SQLContext", "HiveContext", "UDFRegistration"] -def _monkey_patch_RDD(sqlCtx): +def _monkey_patch_RDD(sqlContext): def toDF(self, schema=None, sampleRatio=None): """ Converts current :class:`RDD` into a :class:`DataFrame` - This is a shorthand for ``sqlCtx.createDataFrame(rdd, schema, sampleRatio)`` + This is a shorthand for ``sqlContext.createDataFrame(rdd, schema, sampleRatio)`` :param schema: a StructType or list of names of columns :param samplingRatio: the sample ratio of rows used for inferring @@ -51,7 +51,7 @@ def toDF(self, schema=None, sampleRatio=None): >>> rdd.toDF().collect() [Row(name=u'Alice', age=1)] """ - return sqlCtx.createDataFrame(self, schema, sampleRatio) + return sqlContext.createDataFrame(self, schema, sampleRatio) RDD.toDF = toDF @@ -75,13 +75,13 @@ def __init__(self, sparkContext, sqlContext=None): """Creates a new SQLContext. >>> from datetime import datetime - >>> sqlCtx = SQLContext(sc) + >>> sqlContext = SQLContext(sc) >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), ... time=datetime(2014, 8, 1, 14, 1, 5))]) >>> df = allTypes.toDF() >>> df.registerTempTable("allTypes") - >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' + >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 'from allTypes where b and i > 0').collect() [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, @@ -133,18 +133,18 @@ def registerFunction(self, name, f, returnType=StringType()): :param samplingRatio: lambda function :param returnType: a :class:`DataType` object - >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) - >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() + >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) + >>> sqlContext.sql("SELECT stringLengthString('test')").collect() [Row(c0=u'4')] >>> from pyspark.sql.types import IntegerType - >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() + >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(c0=4)] >>> from pyspark.sql.types import IntegerType - >>> sqlCtx.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() + >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(c0=4)] """ func = lambda _, it: imap(lambda x: f(*x), it) @@ -229,26 +229,26 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): :param samplingRatio: the sample ratio of rows used for inferring >>> l = [('Alice', 1)] - >>> sqlCtx.createDataFrame(l).collect() + >>> sqlContext.createDataFrame(l).collect() [Row(_1=u'Alice', _2=1)] - >>> sqlCtx.createDataFrame(l, ['name', 'age']).collect() + >>> sqlContext.createDataFrame(l, ['name', 'age']).collect() [Row(name=u'Alice', age=1)] >>> d = [{'name': 'Alice', 'age': 1}] - >>> sqlCtx.createDataFrame(d).collect() + >>> sqlContext.createDataFrame(d).collect() [Row(age=1, name=u'Alice')] >>> rdd = sc.parallelize(l) - >>> sqlCtx.createDataFrame(rdd).collect() + >>> sqlContext.createDataFrame(rdd).collect() [Row(_1=u'Alice', _2=1)] - >>> df = sqlCtx.createDataFrame(rdd, ['name', 'age']) + >>> df = sqlContext.createDataFrame(rdd, ['name', 'age']) >>> df.collect() [Row(name=u'Alice', age=1)] >>> from pyspark.sql import Row >>> Person = Row('name', 'age') >>> person = rdd.map(lambda r: Person(*r)) - >>> df2 = sqlCtx.createDataFrame(person) + >>> df2 = sqlContext.createDataFrame(person) >>> df2.collect() [Row(name=u'Alice', age=1)] @@ -256,11 +256,11 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): >>> schema = StructType([ ... StructField("name", StringType(), True), ... StructField("age", IntegerType(), True)]) - >>> df3 = sqlCtx.createDataFrame(rdd, schema) + >>> df3 = sqlContext.createDataFrame(rdd, schema) >>> df3.collect() [Row(name=u'Alice', age=1)] - >>> sqlCtx.createDataFrame(df.toPandas()).collect() # doctest: +SKIP + >>> sqlContext.createDataFrame(df.toPandas()).collect() # doctest: +SKIP [Row(name=u'Alice', age=1)] """ if isinstance(data, DataFrame): @@ -316,7 +316,7 @@ def registerDataFrameAsTable(self, df, tableName): Temporary tables exist only during the lifetime of this instance of :class:`SQLContext`. - >>> sqlCtx.registerDataFrameAsTable(df, "table1") + >>> sqlContext.registerDataFrameAsTable(df, "table1") """ if (df.__class__ is DataFrame): self._ssql_ctx.registerDataFrameAsTable(df._jdf, tableName) @@ -330,7 +330,7 @@ def parquetFile(self, *paths): >>> parquetFile = tempfile.mkdtemp() >>> shutil.rmtree(parquetFile) >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> df2 = sqlContext.parquetFile(parquetFile) >>> sorted(df.collect()) == sorted(df2.collect()) True """ @@ -352,7 +352,7 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): >>> shutil.rmtree(jsonFile) >>> with open(jsonFile, 'w') as f: ... f.writelines(jsonStrings) - >>> df1 = sqlCtx.jsonFile(jsonFile) + >>> df1 = sqlContext.jsonFile(jsonFile) >>> df1.printSchema() root |-- field1: long (nullable = true) @@ -365,7 +365,7 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): ... StructField("field2", StringType()), ... StructField("field3", ... StructType([StructField("field5", ArrayType(IntegerType()))]))]) - >>> df2 = sqlCtx.jsonFile(jsonFile, schema) + >>> df2 = sqlContext.jsonFile(jsonFile, schema) >>> df2.printSchema() root |-- field2: string (nullable = true) @@ -386,11 +386,11 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): If the schema is provided, applies the given schema to this JSON dataset. Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema. - >>> df1 = sqlCtx.jsonRDD(json) + >>> df1 = sqlContext.jsonRDD(json) >>> df1.first() Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) - >>> df2 = sqlCtx.jsonRDD(json, df1.schema) + >>> df2 = sqlContext.jsonRDD(json, df1.schema) >>> df2.first() Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) @@ -400,7 +400,7 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): ... StructField("field3", ... StructType([StructField("field5", ArrayType(IntegerType()))])) ... ]) - >>> df3 = sqlCtx.jsonRDD(json, schema) + >>> df3 = sqlContext.jsonRDD(json, schema) >>> df3.first() Row(field2=u'row1', field3=Row(field5=None)) """ @@ -480,8 +480,8 @@ def createExternalTable(self, tableName, path=None, source=None, def sql(self, sqlQuery): """Returns a :class:`DataFrame` representing the result of the given query. - >>> sqlCtx.registerDataFrameAsTable(df, "table1") - >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1") >>> df2.collect() [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] """ @@ -490,8 +490,8 @@ def sql(self, sqlQuery): def table(self, tableName): """Returns the specified table as a :class:`DataFrame`. - >>> sqlCtx.registerDataFrameAsTable(df, "table1") - >>> df2 = sqlCtx.table("table1") + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlContext.table("table1") >>> sorted(df.collect()) == sorted(df2.collect()) True """ @@ -505,8 +505,8 @@ def tables(self, dbName=None): The returned DataFrame has two columns: ``tableName`` and ``isTemporary`` (a column with :class:`BooleanType` indicating if a table is a temporary one or not). - >>> sqlCtx.registerDataFrameAsTable(df, "table1") - >>> df2 = sqlCtx.tables() + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlContext.tables() >>> df2.filter("tableName = 'table1'").first() Row(tableName=u'table1', isTemporary=True) """ @@ -520,10 +520,10 @@ def tableNames(self, dbName=None): If ``dbName`` is not specified, the current database will be used. - >>> sqlCtx.registerDataFrameAsTable(df, "table1") - >>> "table1" in sqlCtx.tableNames() + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> "table1" in sqlContext.tableNames() True - >>> "table1" in sqlCtx.tableNames("db") + >>> "table1" in sqlContext.tableNames("db") True """ if dbName is None: @@ -574,15 +574,24 @@ def _ssql_ctx(self): def _get_hive_ctx(self): return self._jvm.HiveContext(self._jsc.sc()) + def refreshTable(self, tableName): + """Invalidate and refresh all the cached the metadata of the given + table. For performance reasons, Spark SQL or the external data source + library it uses might cache certain metadata about a table, such as the + location of blocks. When those change outside of Spark SQL, users should + call this function to invalidate the cache. + """ + self._ssql_ctx.refreshTable(tableName) + class UDFRegistration(object): """Wrapper for user-defined function registration.""" - def __init__(self, sqlCtx): - self.sqlCtx = sqlCtx + def __init__(self, sqlContext): + self.sqlContext = sqlContext def register(self, name, f, returnType=StringType()): - return self.sqlCtx.registerFunction(name, f, returnType) + return self.sqlContext.registerFunction(name, f, returnType) register.__doc__ = SQLContext.registerFunction.__doc__ @@ -595,13 +604,12 @@ def _test(): globs = pyspark.sql.context.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc - globs['sqlCtx'] = sqlCtx = SQLContext(sc) + globs['sqlContext'] = SQLContext(sc) globs['rdd'] = rdd = sc.parallelize( [Row(field1=1, field2="row1"), Row(field1=2, field2="row2"), Row(field1=3, field2="row3")] ) - _monkey_patch_RDD(sqlCtx) globs['df'] = rdd.toDF() jsonStrings = [ '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index c30326ebd133e..ef91a9c4f522d 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -110,7 +110,7 @@ def saveAsParquetFile(self, path): >>> parquetFile = tempfile.mkdtemp() >>> shutil.rmtree(parquetFile) >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> df2 = sqlContext.parquetFile(parquetFile) >>> sorted(df2.collect()) == sorted(df.collect()) True """ @@ -123,7 +123,7 @@ def registerTempTable(self, name): that was used to create this :class:`DataFrame`. >>> df.registerTempTable("people") - >>> df2 = sqlCtx.sql("select * from people") + >>> df2 = sqlContext.sql("select * from people") >>> sorted(df.collect()) == sorted(df2.collect()) True """ @@ -1180,7 +1180,7 @@ def _test(): globs = pyspark.sql.dataframe.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc - globs['sqlCtx'] = SQLContext(sc) + globs['sqlContext'] = SQLContext(sc) globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')])\ .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 146ba6f3e0d98..daeb6916b58bc 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -161,7 +161,7 @@ def _test(): globs = pyspark.sql.functions.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc - globs['sqlCtx'] = SQLContext(sc) + globs['sqlContext'] = SQLContext(sc) globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.functions, globs=globs, diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 45eb8b945dcb0..ef76d84c00481 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -434,7 +434,7 @@ def _parse_datatype_json_string(json_string): >>> def check_datatype(datatype): ... pickled = pickle.loads(pickle.dumps(datatype)) ... assert datatype == pickled - ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) + ... scala_datatype = sqlContext._ssql_ctx.parseDataType(datatype.json()) ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) ... assert datatype == python_datatype >>> for cls in _all_primitive_types.values(): @@ -567,8 +567,8 @@ def _infer_schema(row): elif isinstance(row, (tuple, list)): if hasattr(row, "_fields"): # namedtuple items = zip(row._fields, tuple(row)) - elif hasattr(row, "__FIELDS__"): # Row - items = zip(row.__FIELDS__, tuple(row)) + elif hasattr(row, "__fields__"): # Row + items = zip(row.__fields__, tuple(row)) else: names = ['_%d' % i for i in range(1, len(row) + 1)] items = zip(names, row) @@ -647,7 +647,7 @@ def converter(obj): if isinstance(obj, dict): return tuple(c(obj.get(n)) for n, c in zip(names, converters)) elif isinstance(obj, tuple): - if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"): + if hasattr(obj, "_fields") or hasattr(obj, "__fields__"): return tuple(c(v) for c, v in zip(converters, obj)) elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs d = dict(obj) @@ -997,12 +997,13 @@ def _restore_object(dataType, obj): # same object in most cases. k = id(dataType) cls = _cached_cls.get(k) - if cls is None: + if cls is None or cls.__datatype is not dataType: # use dataType as key to avoid create multiple class cls = _cached_cls.get(dataType) if cls is None: cls = _create_cls(dataType) _cached_cls[dataType] = cls + cls.__datatype = dataType _cached_cls[k] = cls return cls(obj) @@ -1119,8 +1120,8 @@ def Dict(d): class Row(tuple): """ Row in DataFrame """ - __DATATYPE__ = dataType - __FIELDS__ = tuple(f.name for f in dataType.fields) + __datatype = dataType + __fields__ = tuple(f.name for f in dataType.fields) __slots__ = () # create property for fast access @@ -1128,22 +1129,22 @@ class Row(tuple): def asDict(self): """ Return as a dict """ - return dict((n, getattr(self, n)) for n in self.__FIELDS__) + return dict((n, getattr(self, n)) for n in self.__fields__) def __repr__(self): # call collect __repr__ for nested objects return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) - for n in self.__FIELDS__)) + for n in self.__fields__)) def __reduce__(self): - return (_restore_object, (self.__DATATYPE__, tuple(self))) + return (_restore_object, (self.__datatype, tuple(self))) return Row def _create_row(fields, values): row = Row(*values) - row.__FIELDS__ = fields + row.__fields__ = fields return row @@ -1183,7 +1184,7 @@ def __new__(self, *args, **kwargs): # create row objects names = sorted(kwargs.keys()) row = tuple.__new__(self, [kwargs[n] for n in names]) - row.__FIELDS__ = names + row.__fields__ = names return row else: @@ -1193,11 +1194,11 @@ def asDict(self): """ Return as an dict """ - if not hasattr(self, "__FIELDS__"): + if not hasattr(self, "__fields__"): raise TypeError("Cannot convert a Row class into dict") - return dict(zip(self.__FIELDS__, self)) + return dict(zip(self.__fields__, self)) - # let obect acs like class + # let object acts like class def __call__(self, *args): """create new Row object""" return _create_row(self, args) @@ -1208,21 +1209,21 @@ def __getattr__(self, item): try: # it will be slow when it has many fields, # but this will not be used in normal cases - idx = self.__FIELDS__.index(item) + idx = self.__fields__.index(item) return self[idx] except IndexError: raise AttributeError(item) def __reduce__(self): - if hasattr(self, "__FIELDS__"): - return (_create_row, (self.__FIELDS__, tuple(self))) + if hasattr(self, "__fields__"): + return (_create_row, (self.__fields__, tuple(self))) else: return tuple.__reduce__(self) def __repr__(self): - if hasattr(self, "__FIELDS__"): + if hasattr(self, "__fields__"): return "Row(%s)" % ", ".join("%s=%r" % (k, v) - for k, v in zip(self.__FIELDS__, self)) + for k, v in zip(self.__fields__, tuple(self))) else: return "" % ", ".join(self) @@ -1237,7 +1238,7 @@ def _test(): globs = pyspark.sql.types.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc - globs['sqlCtx'] = sqlCtx = SQLContext(sc) + globs['sqlContext'] = SQLContext(sc) globs['ExamplePoint'] = ExamplePoint globs['ExamplePointUDT'] = ExamplePointUDT (failure_count, test_count) = doctest.testmod( diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 608f8e26473a6..9b4635e49020b 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -23,13 +23,16 @@ import tempfile import struct +from py4j.java_collections import MapConverter + from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext +from pyspark.streaming.kafka import KafkaUtils class PySparkStreamingTestCase(unittest.TestCase): - timeout = 10 # seconds + timeout = 20 # seconds duration = 1 def setUp(self): @@ -556,5 +559,43 @@ def check_output(n): check_output(3) +class KafkaStreamTests(PySparkStreamingTestCase): + + def setUp(self): + super(KafkaStreamTests, self).setUp() + + kafkaTestUtilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ + .loadClass("org.apache.spark.streaming.kafka.KafkaTestUtils") + self._kafkaTestUtils = kafkaTestUtilsClz.newInstance() + self._kafkaTestUtils.setup() + + def tearDown(self): + if self._kafkaTestUtils is not None: + self._kafkaTestUtils.teardown() + self._kafkaTestUtils = None + + super(KafkaStreamTests, self).tearDown() + + def test_kafka_stream(self): + """Test the Python Kafka stream API.""" + topic = "topic1" + sendData = {"a": 3, "b": 5, "c": 10} + jSendData = MapConverter().convert(sendData, + self.ssc.sparkContext._gateway._gateway_client) + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, jSendData) + + stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(), + "test-streaming-consumer", {topic: 1}, + {"auto.offset.reset": "smallest"}) + + result = {} + for i in chain.from_iterable(self._collect(stream.map(lambda x: x[1]), + sum(sendData.values()))): + result[i] = result.get(i, 0) + 1 + + self.assertEqual(sendData, result) + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index dd8d3b1c53733..b938b9ce12395 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -31,9 +31,12 @@ import time import zipfile import random +import itertools import threading import hashlib +from py4j.protocol import Py4JJavaError + if sys.version_info[:2] <= (2, 6): try: import unittest2 as unittest @@ -76,7 +79,7 @@ class MergerTests(unittest.TestCase): def setUp(self): - self.N = 1 << 14 + self.N = 1 << 12 self.l = [i for i in xrange(self.N)] self.data = zip(self.l, self.l) self.agg = Aggregator(lambda x: [x], @@ -108,7 +111,7 @@ def test_small_dataset(self): sum(xrange(self.N))) def test_medium_dataset(self): - m = ExternalMerger(self.agg, 10) + m = ExternalMerger(self.agg, 30) m.mergeValues(self.data) self.assertTrue(m.spills >= 1) self.assertEqual(sum(sum(v) for k, v in m.iteritems()), @@ -124,10 +127,36 @@ def test_huge_dataset(self): m = ExternalMerger(self.agg, 10, partitions=3) m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10)) self.assertTrue(m.spills >= 1) - self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)), + self.assertEqual(sum(len(v) for k, v in m.iteritems()), self.N * 10) m._cleanup() + def test_group_by_key(self): + + def gen_data(N, step): + for i in range(1, N + 1, step): + for j in range(i): + yield (i, [j]) + + def gen_gs(N, step=1): + return shuffle.GroupByKey(gen_data(N, step)) + + self.assertEqual(1, len(list(gen_gs(1)))) + self.assertEqual(2, len(list(gen_gs(2)))) + self.assertEqual(100, len(list(gen_gs(100)))) + self.assertEqual(range(1, 101), [k for k, _ in gen_gs(100)]) + self.assertTrue(all(range(k) == list(vs) for k, vs in gen_gs(100))) + + for k, vs in gen_gs(50002, 10000): + self.assertEqual(k, len(vs)) + self.assertEqual(range(k), list(vs)) + + ser = PickleSerializer() + l = ser.loads(ser.dumps(list(gen_gs(50002, 30000)))) + for k, vs in l: + self.assertEqual(k, len(vs)) + self.assertEqual(range(k), list(vs)) + class SorterTests(unittest.TestCase): def test_in_memory_sort(self): @@ -702,6 +731,21 @@ def test_distinct(self): self.assertEquals(result.getNumPartitions(), 5) self.assertEquals(result.count(), 3) + def test_external_group_by_key(self): + self.sc._conf.set("spark.python.worker.memory", "5m") + N = 200001 + kv = self.sc.parallelize(range(N)).map(lambda x: (x % 3, x)) + gkv = kv.groupByKey().cache() + self.assertEqual(3, gkv.count()) + filtered = gkv.filter(lambda (k, vs): k == 1) + self.assertEqual(1, filtered.count()) + self.assertEqual([(1, N/3)], filtered.mapValues(len).collect()) + self.assertEqual([(N/3, N/3)], + filtered.values().map(lambda x: (len(x), len(list(x)))).collect()) + result = filtered.collect()[0][1] + self.assertEqual(N/3, len(result)) + self.assertTrue(isinstance(result.data, shuffle.ExternalList)) + def test_sort_on_empty_rdd(self): self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect()) @@ -752,9 +796,9 @@ def test_narrow_dependency_in_join(self): self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions()) self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions()) - self.sc.setJobGroup("test1", "test", True) tracker = self.sc.statusTracker() + self.sc.setJobGroup("test1", "test", True) d = sorted(parted.join(parted).collect()) self.assertEqual(10, len(d)) self.assertEqual((0, (0, 0)), d[0]) @@ -787,6 +831,17 @@ def test_take_on_jrdd(self): rdd = self.sc.parallelize(range(1 << 20)).map(lambda x: str(x)) rdd._jrdd.first() + def test_sortByKey_uses_all_partitions_not_only_first_and_last(self): + # Regression test for SPARK-5969 + seq = [(i * 59 % 101, i) for i in range(101)] # unsorted sequence + rdd = self.sc.parallelize(seq) + for ascending in [True, False]: + sort = rdd.sortByKey(ascending=ascending, numPartitions=5) + self.assertEqual(sort.collect(), sorted(seq, reverse=not ascending)) + sizes = sort.glom().map(len).collect() + for size in sizes: + self.assertGreater(size, 0) + class ProfilerTests(PySparkTestCase): @@ -1441,6 +1496,20 @@ def count(): self.assertTrue(not t.isAlive()) self.assertEqual(100000, rdd.count()) + def test_with_different_versions_of_python(self): + rdd = self.sc.parallelize(range(10)) + rdd.count() + version = sys.version_info + sys.version_info = (2, 0, 0) + log4j = self.sc._jvm.org.apache.log4j + old_level = log4j.LogManager.getRootLogger().getLevel() + log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL) + try: + self.assertRaises(Py4JJavaError, lambda: rdd.count()) + finally: + sys.version_info = version + log4j.LogManager.getRootLogger().setLevel(old_level) + class SparkSubmitTests(unittest.TestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 8a93c320ec5d3..452d6fabdcc17 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -88,7 +88,11 @@ def main(infile, outfile): command = pickleSer._read_with_length(infile) if isinstance(command, Broadcast): command = pickleSer.loads(command.value) - (func, profiler, deserializer, serializer) = command + (func, profiler, deserializer, serializer), version = command + if version != sys.version_info[:2]: + raise Exception(("Python in worker has different version %s than that in " + + "driver %s, PySpark cannot run with different minor versions") % + (sys.version_info[:2], version)) init_time = time.time() def process(): diff --git a/python/run-tests b/python/run-tests index b7630c356cfae..f3a07d8aba562 100755 --- a/python/run-tests +++ b/python/run-tests @@ -21,6 +21,8 @@ # Figure out where the Spark framework is installed FWDIR="$(cd "`dirname "$0"`"; cd ../; pwd)" +. "$FWDIR"/bin/load-spark-env.sh + # CD into the python directory to find things on the right path cd "$FWDIR/python" @@ -57,7 +59,7 @@ function run_core_tests() { PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py" PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py" run_test "pyspark/serializers.py" - run_test "pyspark/profiler.py" + run_test "pyspark/profiler.py" run_test "pyspark/shuffle.py" run_test "pyspark/tests.py" } @@ -77,6 +79,7 @@ function run_mllib_tests() { run_test "pyspark/mllib/clustering.py" run_test "pyspark/mllib/evaluation.py" run_test "pyspark/mllib/feature.py" + run_test "pyspark/mllib/fpm.py" run_test "pyspark/mllib/linalg.py" run_test "pyspark/mllib/rand.py" run_test "pyspark/mllib/recommendation.py" @@ -96,6 +99,21 @@ function run_ml_tests() { function run_streaming_tests() { echo "Run streaming tests ..." + + KAFKA_ASSEMBLY_DIR="$FWDIR"/external/kafka-assembly + JAR_PATH="${KAFKA_ASSEMBLY_DIR}/target/scala-${SPARK_SCALA_VERSION}" + for f in "${JAR_PATH}"/spark-streaming-kafka-assembly-*.jar; do + if [[ ! -e "$f" ]]; then + echo "Failed to find Spark Streaming Kafka assembly jar in $KAFKA_ASSEMBLY_DIR" 1>&2 + echo "You need to build Spark with " \ + "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or" \ + "'build/mvn package' before running this program" 1>&2 + exit 1 + fi + KAFKA_ASSEMBLY_JAR="$f" + done + + export PYSPARK_SUBMIT_ARGS="--jars ${KAFKA_ASSEMBLY_JAR} pyspark-shell" run_test "pyspark/streaming/util.py" run_test "pyspark/streaming/tests.py" } diff --git a/repl/src/test/resources/log4j.properties b/repl/src/test/resources/log4j.properties index e7e4a4113174a..e2ee9c963a4da 100644 --- a/repl/src/test/resources/log4j.properties +++ b/repl/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index 92e76a3fe6ca2..d8e0facb81169 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -29,7 +29,7 @@ # SPARK_NICENESS The scheduling priority for daemons. Defaults to 0. ## -usage="Usage: spark-daemon.sh [--config ] (start|stop) " +usage="Usage: spark-daemon.sh [--config ] (start|stop|status) " # if no args specified, show usage if [ $# -le 1 ]; then @@ -195,6 +195,23 @@ case $option in fi ;; + (status) + + if [ -f $pid ]; then + TARGET_ID="$(cat "$pid")" + if [[ $(ps -p "$TARGET_ID" -o comm=) =~ "java" ]]; then + echo $command is running. + exit 0 + else + echo $pid file is present but $command not running + exit 1 + fi + else + echo $command not running. + exit 2 + fi + ;; + (*) echo $usage exit 1 diff --git a/sbin/start-slave.sh b/sbin/start-slave.sh index 5a6de11afdd3d..4c919ff76a8f5 100755 --- a/sbin/start-slave.sh +++ b/sbin/start-slave.sh @@ -18,15 +18,68 @@ # # Starts a slave on the machine this script is executed on. +# +# Environment Variables +# +# SPARK_WORKER_INSTANCES The number of worker instances to run on this +# slave. Default is 1. +# SPARK_WORKER_PORT The base port number for the first worker. If set, +# subsequent workers will increment this number. If +# unset, Spark will find a valid port number, but +# with no guarantee of a predictable pattern. +# SPARK_WORKER_WEBUI_PORT The base port for the web interface of the first +# worker. Subsequent workers will increment this +# number. Default is 8081. -usage="Usage: start-slave.sh where is like spark://localhost:7077" +usage="Usage: start-slave.sh where is like spark://localhost:7077" -if [ $# -lt 2 ]; then +if [ $# -lt 1 ]; then echo $usage + echo Called as start-slave.sh $* exit 1 fi sbin="`dirname "$0"`" sbin="`cd "$sbin"; pwd`" -"$sbin"/spark-daemon.sh start org.apache.spark.deploy.worker.Worker "$@" +. "$sbin/spark-config.sh" + +. "$SPARK_PREFIX/bin/load-spark-env.sh" + +# First argument should be the master; we need to store it aside because we may +# need to insert arguments between it and the other arguments +MASTER=$1 +shift + +# Determine desired worker port +if [ "$SPARK_WORKER_WEBUI_PORT" = "" ]; then + SPARK_WORKER_WEBUI_PORT=8081 +fi + +# Start up the appropriate number of workers on this machine. +# quick local function to start a worker +function start_instance { + WORKER_NUM=$1 + shift + + if [ "$SPARK_WORKER_PORT" = "" ]; then + PORT_FLAG= + PORT_NUM= + else + PORT_FLAG="--port" + PORT_NUM=$(( $SPARK_WORKER_PORT + $WORKER_NUM - 1 )) + fi + WEBUI_PORT=$(( $SPARK_WORKER_WEBUI_PORT + $WORKER_NUM - 1 )) + + "$sbin"/spark-daemon.sh start org.apache.spark.deploy.worker.Worker $WORKER_NUM \ + --webui-port "$WEBUI_PORT" $PORT_FLAG $PORT_NUM $MASTER "$@" +} + +if [ "$SPARK_WORKER_INSTANCES" = "" ]; then + start_instance 1 "$@" +else + for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do + start_instance $(( 1 + $i )) "$@" + done +fi + diff --git a/sbin/start-slaves.sh b/sbin/start-slaves.sh index 4356c03657109..24d6268815ed3 100755 --- a/sbin/start-slaves.sh +++ b/sbin/start-slaves.sh @@ -59,13 +59,4 @@ if [ "$START_TACHYON" == "true" ]; then fi # Launch the slaves -if [ "$SPARK_WORKER_INSTANCES" = "" ]; then - exec "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" 1 "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" -else - if [ "$SPARK_WORKER_WEBUI_PORT" = "" ]; then - SPARK_WORKER_WEBUI_PORT=8081 - fi - for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do - "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" $(( $i + 1 )) --webui-port $(( $SPARK_WORKER_WEBUI_PORT + $i )) "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" - done -fi +"$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" diff --git a/sbin/stop-slave.sh b/sbin/stop-slave.sh new file mode 100755 index 0000000000000..3d1da5b254f2a --- /dev/null +++ b/sbin/stop-slave.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# A shell script to stop all workers on a single slave +# +# Environment variables +# +# SPARK_WORKER_INSTANCES The number of worker instances that should be +# running on this slave. Default is 1. + +# Usage: stop-slave.sh +# Stops all slaves on this worker machine + +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +. "$sbin/spark-config.sh" + +. "$SPARK_PREFIX/bin/load-spark-env.sh" + +if [ "$SPARK_WORKER_INSTANCES" = "" ]; then + "$sbin"/spark-daemon.sh stop org.apache.spark.deploy.worker.Worker 1 +else + for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do + "$sbin"/spark-daemon.sh stop org.apache.spark.deploy.worker.Worker $(( $i + 1 )) + done +fi diff --git a/sbin/stop-slaves.sh b/sbin/stop-slaves.sh index 7c2201100ef97..54c9bd46803a9 100755 --- a/sbin/stop-slaves.sh +++ b/sbin/stop-slaves.sh @@ -17,8 +17,8 @@ # limitations under the License. # -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" . "$sbin/spark-config.sh" @@ -29,10 +29,4 @@ if [ -e "$sbin"/../tachyon/bin/tachyon ]; then "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin"/../tachyon/bin/tachyon killAll tachyon.worker.Worker fi -if [ "$SPARK_WORKER_INSTANCES" = "" ]; then - "$sbin"/spark-daemons.sh stop org.apache.spark.deploy.worker.Worker 1 -else - for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do - "$sbin"/spark-daemons.sh stop org.apache.spark.deploy.worker.Worker $(( $i + 1 )) - done -fi +"$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin"/stop-slave.sh diff --git a/sql/README.md b/sql/README.md index fbb3200a3a4b4..237620e3fa808 100644 --- a/sql/README.md +++ b/sql/README.md @@ -56,6 +56,6 @@ res2: Array[org.apache.spark.sql.Row] = Array([238,val_238], [86,val_86], [311,v You can also build further queries on top of these `DataFrames` using the query DSL. ``` -scala> query.where('key > 30).select(avg('key)).collect() +scala> query.where(query("key") > 30).select(avg(query("key"))).collect() res3: Array[org.apache.spark.sql.Row] = Array([274.79025423728814]) ``` diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala new file mode 100644 index 0000000000000..91976fef6dc0d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import java.util.{Map => JavaMap} + +import scala.collection.mutable.HashMap + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +/** + * Functions to convert Scala types to Catalyst types and vice versa. + */ +object CatalystTypeConverters { + // The Predef.Map is scala.collection.immutable.Map. + // Since the map values can be mutable, we explicitly import scala.collection.Map at here. + import scala.collection.Map + + /** + * Converts Scala objects to catalyst rows / types. This method is slow, and for batch + * conversion you should be using converter produced by createToCatalystConverter. + * Note: This is always called after schemaFor has been called. + * This ordering is important for UDT registration. + */ + def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { + // Check UDT first since UDTs can override other types + case (obj, udt: UserDefinedType[_]) => + udt.serialize(obj) + + case (o: Option[_], _) => + o.map(convertToCatalyst(_, dataType)).orNull + + case (s: Seq[_], arrayType: ArrayType) => + s.map(convertToCatalyst(_, arrayType.elementType)) + + case (s: Array[_], arrayType: ArrayType) => + s.toSeq.map(convertToCatalyst(_, arrayType.elementType)) + + case (m: Map[_, _], mapType: MapType) => + m.map { case (k, v) => + convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) + } + + case (jmap: JavaMap[_, _], mapType: MapType) => + val iter = jmap.entrySet.iterator + var listOfEntries: List[(Any, Any)] = List() + while (iter.hasNext) { + val entry = iter.next() + listOfEntries :+= (convertToCatalyst(entry.getKey, mapType.keyType), + convertToCatalyst(entry.getValue, mapType.valueType)) + } + listOfEntries.toMap + + case (p: Product, structType: StructType) => + val ar = new Array[Any](structType.size) + val iter = p.productIterator + var idx = 0 + while (idx < structType.size) { + ar(idx) = convertToCatalyst(iter.next(), structType.fields(idx).dataType) + idx += 1 + } + new GenericRowWithSchema(ar, structType) + + case (d: BigDecimal, _) => + Decimal(d) + + case (d: java.math.BigDecimal, _) => + Decimal(d) + + case (d: java.sql.Date, _) => + DateUtils.fromJavaDate(d) + + case (r: Row, structType: StructType) => + val converters = structType.fields.map { + f => (item: Any) => convertToCatalyst(item, f.dataType) + } + convertRowWithConverters(r, structType, converters) + + case (other, _) => + other + } + + /** + * Creates a converter function that will convert Scala objects to the specified catalyst type. + * Typical use case would be converting a collection of rows that have the same schema. You will + * call this function once to get a converter, and apply it to every row. + */ + private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = { + def extractOption(item: Any): Any = item match { + case opt: Option[_] => opt.orNull + case other => other + } + + dataType match { + // Check UDT first since UDTs can override other types + case udt: UserDefinedType[_] => + (item) => extractOption(item) match { + case null => null + case other => udt.serialize(other) + } + + case arrayType: ArrayType => + val elementConverter = createToCatalystConverter(arrayType.elementType) + (item: Any) => { + extractOption(item) match { + case a: Array[_] => a.toSeq.map(elementConverter) + case s: Seq[_] => s.map(elementConverter) + case null => null + } + } + + case mapType: MapType => + val keyConverter = createToCatalystConverter(mapType.keyType) + val valueConverter = createToCatalystConverter(mapType.valueType) + (item: Any) => { + extractOption(item) match { + case m: Map[_, _] => + m.map { case (k, v) => + keyConverter(k) -> valueConverter(v) + } + + case jmap: JavaMap[_, _] => + val iter = jmap.entrySet.iterator + val convertedMap: HashMap[Any, Any] = HashMap() + while (iter.hasNext) { + val entry = iter.next() + convertedMap(keyConverter(entry.getKey)) = valueConverter(entry.getValue) + } + convertedMap + + case null => null + } + } + + case structType: StructType => + val converters = structType.fields.map(f => createToCatalystConverter(f.dataType)) + (item: Any) => { + extractOption(item) match { + case r: Row => + convertRowWithConverters(r, structType, converters) + + case p: Product => + val ar = new Array[Any](structType.size) + val iter = p.productIterator + var idx = 0 + while (idx < structType.size) { + ar(idx) = converters(idx)(iter.next()) + idx += 1 + } + new GenericRowWithSchema(ar, structType) + + case null => + null + } + } + + case dateType: DateType => (item: Any) => extractOption(item) match { + case d: java.sql.Date => DateUtils.fromJavaDate(d) + case other => other + } + + case _ => + (item: Any) => extractOption(item) match { + case d: BigDecimal => Decimal(d) + case d: java.math.BigDecimal => Decimal(d) + case other => other + } + } + } + + /** + * Converts Catalyst types used internally in rows to standard Scala types + * This method is slow, and for batch conversion you should be using converter + * produced by createToScalaConverter. + */ + def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { + // Check UDT first since UDTs can override other types + case (d, udt: UserDefinedType[_]) => + udt.deserialize(d) + + case (s: Seq[_], arrayType: ArrayType) => + s.map(convertToScala(_, arrayType.elementType)) + + case (m: Map[_, _], mapType: MapType) => + m.map { case (k, v) => + convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) + } + + case (r: Row, s: StructType) => + convertRowToScala(r, s) + + case (d: Decimal, _: DecimalType) => + d.toJavaBigDecimal + + case (i: Int, DateType) => + DateUtils.toJavaDate(i) + + case (other, _) => + other + } + + /** + * Creates a converter function that will convert Catalyst types to Scala type. + * Typical use case would be converting a collection of rows that have the same schema. You will + * call this function once to get a converter, and apply it to every row. + */ + private[sql] def createToScalaConverter(dataType: DataType): Any => Any = dataType match { + // Check UDT first since UDTs can override other types + case udt: UserDefinedType[_] => + (item: Any) => if (item == null) null else udt.deserialize(item) + + case arrayType: ArrayType => + val elementConverter = createToScalaConverter(arrayType.elementType) + (item: Any) => if (item == null) null else item.asInstanceOf[Seq[_]].map(elementConverter) + + case mapType: MapType => + val keyConverter = createToScalaConverter(mapType.keyType) + val valueConverter = createToScalaConverter(mapType.valueType) + (item: Any) => if (item == null) { + null + } else { + item.asInstanceOf[Map[_, _]].map { case (k, v) => + keyConverter(k) -> valueConverter(v) + } + } + + case s: StructType => + val converters = s.fields.map(f => createToScalaConverter(f.dataType)) + (item: Any) => { + if (item == null) { + null + } else { + convertRowWithConverters(item.asInstanceOf[Row], s, converters) + } + } + + case _: DecimalType => + (item: Any) => item match { + case d: Decimal => d.toJavaBigDecimal + case other => other + } + + case DateType => + (item: Any) => item match { + case i: Int => DateUtils.toJavaDate(i) + case other => other + } + + case other => + (item: Any) => item + } + + def convertRowToScala(r: Row, schema: StructType): Row = { + val ar = new Array[Any](r.size) + var idx = 0 + while (idx < r.size) { + ar(idx) = convertToScala(r(idx), schema.fields(idx).dataType) + idx += 1 + } + new GenericRowWithSchema(ar, schema) + } + + /** + * Converts a row by applying the provided set of converter functions. It is used for both + * toScala and toCatalyst conversions. + */ + private[sql] def convertRowWithConverters( + row: Row, + schema: StructType, + converters: Array[Any => Any]): Row = { + val ar = new Array[Any](row.size) + var idx = 0 + while (idx < row.size) { + ar(idx) = converters(idx)(row(idx)) + idx += 1 + } + new GenericRowWithSchema(ar, schema) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 8bfd0471d9c7a..01d5c1512201a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -46,61 +46,6 @@ trait ScalaReflection { case class Schema(dataType: DataType, nullable: Boolean) - /** - * Converts Scala objects to catalyst rows / types. - * Note: This is always called after schemaFor has been called. - * This ordering is important for UDT registration. - */ - def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { - // Check UDT first since UDTs can override other types - case (obj, udt: UserDefinedType[_]) => udt.serialize(obj) - case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull - case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType)) - case (s: Array[_], arrayType: ArrayType) => if (arrayType.elementType.isPrimitive) { - s.toSeq - } else { - s.toSeq.map(convertToCatalyst(_, arrayType.elementType)) - } - case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => - convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) - } - case (p: Product, structType: StructType) => - new GenericRow( - p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) => - convertToCatalyst(elem, field.dataType) - }.toArray) - case (d: BigDecimal, _) => Decimal(d) - case (d: java.math.BigDecimal, _) => Decimal(d) - case (d: java.sql.Date, _) => DateUtils.fromJavaDate(d) - case (r: Row, structType: StructType) => - new GenericRow( - r.toSeq.zip(structType.fields).map { case (elem, field) => - convertToCatalyst(elem, field.dataType) - }.toArray) - case (other, _) => other - } - - /** Converts Catalyst types used internally in rows to standard Scala types */ - def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { - // Check UDT first since UDTs can override other types - case (d, udt: UserDefinedType[_]) => udt.deserialize(d) - case (s: Seq[_], arrayType: ArrayType) => s.map(convertToScala(_, arrayType.elementType)) - case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => - convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) - } - case (r: Row, s: StructType) => convertRowToScala(r, s) - case (d: Decimal, _: DecimalType) => d.toJavaBigDecimal - case (i: Int, DateType) => DateUtils.toJavaDate(i) - case (other, _) => other - } - - def convertRowToScala(r: Row, schema: StructType): Row = { - // TODO: This is very slow!!! - new GenericRowWithSchema( - r.toSeq.zip(schema.fields.map(_.dataType)) - .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray, schema) - } - /** Returns a Sequence of attributes for the given case class type. */ def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { case Schema(s: StructType, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 89f4a19add1c6..bc8d3751f6616 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -111,6 +111,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected val UPPER = Keyword("UPPER") protected val WHEN = Keyword("WHEN") protected val WHERE = Keyword("WHERE") + protected val WITH = Keyword("WITH") protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = { exprs.zipWithIndex.map { @@ -127,6 +128,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { | UNION ~ DISTINCT.? ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } ) | insert + | cte ) protected lazy val select: Parser[LogicalPlan] = @@ -153,7 +155,12 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val insert: Parser[LogicalPlan] = INSERT ~> (OVERWRITE ^^^ true | INTO ^^^ false) ~ (TABLE ~> relation) ~ select ^^ { - case o ~ r ~ s => InsertIntoTable(r, Map.empty[String, Option[String]], s, o) + case o ~ r ~ s => InsertIntoTable(r, Map.empty[String, Option[String]], s, o, false) + } + + protected lazy val cte: Parser[LogicalPlan] = + WITH ~> rep1sep(ident ~ ( AS ~ "(" ~> start <~ ")"), ",") ~ start ^^ { + case r ~ s => With(s, r.map({case n ~ s => (n, Subquery(n, s))}).toMap) } protected lazy val projection: Parser[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 119cb9c3a4400..8b68b0df35f48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types._ /** * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing - * when all relations are already filled in and the analyser needs only to resolve attribute + * when all relations are already filled in and the analyzer needs only to resolve attribute * references. */ object SimpleAnalyzer extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, true) @@ -169,21 +169,36 @@ class Analyzer( * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. */ object ResolveRelations extends Rule[LogicalPlan] { - def getTable(u: UnresolvedRelation): LogicalPlan = { + def getTable(u: UnresolvedRelation, cteRelations: Map[String, LogicalPlan]): LogicalPlan = { try { - catalog.lookupRelation(u.tableIdentifier, u.alias) + // In hive, if there is same table name in database and CTE definition, + // hive will use the table in database, not the CTE one. + // Taking into account the reasonableness and the implementation complexity, + // here use the CTE definition first, check table name only and ignore database name + cteRelations.get(u.tableIdentifier.last) + .map(relation => u.alias.map(Subquery(_, relation)).getOrElse(relation)) + .getOrElse(catalog.lookupRelation(u.tableIdentifier, u.alias)) } catch { case _: NoSuchTableException => u.failAnalysis(s"no such table ${u.tableName}") } } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _) => - i.copy( - table = EliminateSubQueries(getTable(u))) - case u: UnresolvedRelation => - getTable(u) + def apply(plan: LogicalPlan): LogicalPlan = { + val (realPlan, cteRelations) = plan match { + // TODO allow subquery to define CTE + // Add cte table to a temp relation map,drop `with` plan and keep its child + case With(child, relations) => (child, relations) + case other => (other, Map.empty[String, LogicalPlan]) + } + + realPlan transform { + case i@InsertIntoTable(u: UnresolvedRelation, _, _, _, _) => + i.copy( + table = EliminateSubQueries(getTable(u, cteRelations))) + case u: UnresolvedRelation => + getTable(u, cteRelations) + } } } @@ -293,7 +308,7 @@ class Analyzer( logDebug(s"Resolving $u to $result") result case UnresolvedGetField(child, fieldName) if child.resolved => - resolveGetField(child, fieldName) + GetField(child, fieldName, resolver) } } @@ -313,36 +328,6 @@ class Analyzer( */ protected def containsStar(exprs: Seq[Expression]): Boolean = exprs.exists(_.collect { case _: Star => true }.nonEmpty) - - /** - * Returns the resolved `GetField`, and report error if no desired field or over one - * desired fields are found. - */ - protected def resolveGetField(expr: Expression, fieldName: String): Expression = { - def findField(fields: Array[StructField]): Int = { - val checkField = (f: StructField) => resolver(f.name, fieldName) - val ordinal = fields.indexWhere(checkField) - if (ordinal == -1) { - throw new AnalysisException( - s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}") - } else if (fields.indexWhere(checkField, ordinal + 1) != -1) { - throw new AnalysisException( - s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}") - } else { - ordinal - } - } - expr.dataType match { - case StructType(fields) => - val ordinal = findField(fields) - StructGetField(expr, fields(ordinal), ordinal) - case ArrayType(StructType(fields), containsNull) => - val ordinal = findField(fields) - ArrayGetField(expr, fields(ordinal), ordinal, containsNull) - case otherType => - throw new AnalysisException(s"GetField is not valid on fields of type $otherType") - } - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 5eb7dff0cede8..b2f8157a1a61f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery} /** - * Thrown by a catalog when a table cannot be found. The analzyer will rethrow the exception + * Thrown by a catalog when a table cannot be found. The analyzer will rethrow the exception * as an AnalysisException with the correct position information. */ class NoSuchTableException extends Exception @@ -201,7 +201,7 @@ trait OverrideCatalog extends Catalog { /** * A trivial catalog that returns an error when a relation is requested. Used for testing when all - * relations are already filled in and the analyser needs only to resolve attribute references. + * relations are already filled in and the analyzer needs only to resolve attribute references. */ object EmptyCatalog extends Catalog { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index c43ea55899695..16ca5bcd57a72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -57,8 +57,8 @@ class SimpleFunctionRegistry(val caseSensitive: Boolean) extends FunctionRegistr } /** - * A trivial catalog that returns an error when a function is requested. Used for testing when all - * functions are already filled in and the analyser needs only to resolve attribute references. + * A trivial catalog that returns an error when a function is requested. Used for testing when all + * functions are already filled in and the analyzer needs only to resolve attribute references. */ object EmptyFunctionRegistry extends FunctionRegistry { override def registerFunction(name: String, builder: FunctionBuilder): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 145f062dd6817..21c15ad14fd19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -293,7 +293,7 @@ package object dsl { def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( - analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite) + analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite, false) def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer(logicalPlan)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 389dc4f745723..9a77ca624ebe2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.types.DataType /** @@ -39,12 +39,14 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi (1 to 22).map { x => val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _) - val childs = (0 to x - 1).map(x => s"val child$x = children($x)").reduce(_ + "\n " + _) - val evals = (0 to x - 1).map(x => s"ScalaReflection.convertToScala(child$x.eval(input), child$x.dataType)").reduce(_ + ",\n " + _) + val childs = (0 to x - 1).map(x => s"val child$x = children($x)").reduce(_ + "\n " + _) + lazy val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n " + _) + val evals = (0 to x - 1).map(x => s"converter$x(child$x.eval(input))").reduce(_ + ",\n " + _) - s""" case $x => + s"""case $x => val func = function.asInstanceOf[($anys) => Any] $childs + $converters (input: Row) => { func( $evals) @@ -60,51 +62,61 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi (input: Row) => { func() } - + case 1 => val func = function.asInstanceOf[(Any) => Any] val child0 = children(0) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType)) + converter0(child0.eval(input))) } - + case 2 => val func = function.asInstanceOf[(Any, Any) => Any] val child0 = children(0) val child1 = children(1) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input))) } - + case 3 => val func = function.asInstanceOf[(Any, Any, Any) => Any] val child0 = children(0) val child1 = children(1) val child2 = children(2) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input))) } - + case 4 => val func = function.asInstanceOf[(Any, Any, Any, Any) => Any] val child0 = children(0) val child1 = children(1) val child2 = children(2) val child3 = children(3) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input))) } - + case 5 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -112,15 +124,20 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child2 = children(2) val child3 = children(3) val child4 = children(4) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input))) } - + case 6 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -129,16 +146,22 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child3 = children(3) val child4 = children(4) val child5 = children(5) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input))) } - + case 7 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -148,17 +171,24 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child4 = children(4) val child5 = children(5) val child6 = children(6) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input))) } - + case 8 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -169,18 +199,26 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child5 = children(5) val child6 = children(6) val child7 = children(7) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input))) } - + case 9 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -192,19 +230,28 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child6 = children(6) val child7 = children(7) val child8 = children(8) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input))) } - + case 10 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -217,20 +264,30 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child7 = children(7) val child8 = children(8) val child9 = children(9) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input))) } - + case 11 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -244,21 +301,32 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child8 = children(8) val child9 = children(9) val child10 = children(10) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input))) } - + case 12 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -273,22 +341,34 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child9 = children(9) val child10 = children(10) val child11 = children(11) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input))) } - + case 13 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -304,23 +384,36 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child10 = children(10) val child11 = children(11) val child12 = children(12) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input))) } - + case 14 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -337,24 +430,38 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child11 = children(11) val child12 = children(12) val child13 = children(13) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input))) } - + case 15 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -372,25 +479,40 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child12 = children(12) val child13 = children(13) val child14 = children(14) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - ScalaReflection.convertToScala(child14.eval(input), child14.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input))) } - + case 16 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -409,26 +531,42 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child13 = children(13) val child14 = children(14) val child15 = children(15) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - ScalaReflection.convertToScala(child14.eval(input), child14.dataType), - ScalaReflection.convertToScala(child15.eval(input), child15.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input))) } - + case 17 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -448,27 +586,44 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child14 = children(14) val child15 = children(15) val child16 = children(16) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - ScalaReflection.convertToScala(child14.eval(input), child14.dataType), - ScalaReflection.convertToScala(child15.eval(input), child15.dataType), - ScalaReflection.convertToScala(child16.eval(input), child16.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input))) } - + case 18 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -489,28 +644,46 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child15 = children(15) val child16 = children(16) val child17 = children(17) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) + lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - ScalaReflection.convertToScala(child14.eval(input), child14.dataType), - ScalaReflection.convertToScala(child15.eval(input), child15.dataType), - ScalaReflection.convertToScala(child16.eval(input), child16.dataType), - ScalaReflection.convertToScala(child17.eval(input), child17.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input)), + converter17(child17.eval(input))) } - + case 19 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -532,29 +705,48 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child16 = children(16) val child17 = children(17) val child18 = children(18) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) + lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) + lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - ScalaReflection.convertToScala(child14.eval(input), child14.dataType), - ScalaReflection.convertToScala(child15.eval(input), child15.dataType), - ScalaReflection.convertToScala(child16.eval(input), child16.dataType), - ScalaReflection.convertToScala(child17.eval(input), child17.dataType), - ScalaReflection.convertToScala(child18.eval(input), child18.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input)), + converter17(child17.eval(input)), + converter18(child18.eval(input))) } - + case 20 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -577,30 +769,50 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child17 = children(17) val child18 = children(18) val child19 = children(19) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) + lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) + lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) + lazy val converter19 = CatalystTypeConverters.createToScalaConverter(child19.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - ScalaReflection.convertToScala(child14.eval(input), child14.dataType), - ScalaReflection.convertToScala(child15.eval(input), child15.dataType), - ScalaReflection.convertToScala(child16.eval(input), child16.dataType), - ScalaReflection.convertToScala(child17.eval(input), child17.dataType), - ScalaReflection.convertToScala(child18.eval(input), child18.dataType), - ScalaReflection.convertToScala(child19.eval(input), child19.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input)), + converter17(child17.eval(input)), + converter18(child18.eval(input)), + converter19(child19.eval(input))) } - + case 21 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -624,31 +836,52 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child18 = children(18) val child19 = children(19) val child20 = children(20) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) + lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) + lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) + lazy val converter19 = CatalystTypeConverters.createToScalaConverter(child19.dataType) + lazy val converter20 = CatalystTypeConverters.createToScalaConverter(child20.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - ScalaReflection.convertToScala(child14.eval(input), child14.dataType), - ScalaReflection.convertToScala(child15.eval(input), child15.dataType), - ScalaReflection.convertToScala(child16.eval(input), child16.dataType), - ScalaReflection.convertToScala(child17.eval(input), child17.dataType), - ScalaReflection.convertToScala(child18.eval(input), child18.dataType), - ScalaReflection.convertToScala(child19.eval(input), child19.dataType), - ScalaReflection.convertToScala(child20.eval(input), child20.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input)), + converter17(child17.eval(input)), + converter18(child18.eval(input)), + converter19(child19.eval(input)), + converter20(child20.eval(input))) } - + case 22 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -673,35 +906,57 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child19 = children(19) val child20 = children(20) val child21 = children(21) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) + lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) + lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) + lazy val converter19 = CatalystTypeConverters.createToScalaConverter(child19.dataType) + lazy val converter20 = CatalystTypeConverters.createToScalaConverter(child20.dataType) + lazy val converter21 = CatalystTypeConverters.createToScalaConverter(child21.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - ScalaReflection.convertToScala(child14.eval(input), child14.dataType), - ScalaReflection.convertToScala(child15.eval(input), child15.dataType), - ScalaReflection.convertToScala(child16.eval(input), child16.dataType), - ScalaReflection.convertToScala(child17.eval(input), child17.dataType), - ScalaReflection.convertToScala(child18.eval(input), child18.dataType), - ScalaReflection.convertToScala(child19.eval(input), child19.dataType), - ScalaReflection.convertToScala(child20.eval(input), child20.dataType), - ScalaReflection.convertToScala(child21.eval(input), child21.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input)), + converter17(child17.eval(input)), + converter18(child18.eval(input)), + converter19(child19.eval(input)), + converter20(child20.eval(input)), + converter21(child21.eval(input))) } } - + // scalastyle:on - - override def eval(input: Row): Any = ScalaReflection.convertToCatalyst(f(input), dataType) + + override def eval(input: Row): Any = CatalystTypeConverters.convertToCatalyst(f(input), dataType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 406de38d1c483..14a855054b94d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -189,9 +189,10 @@ case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpress override def children: Seq[Expression] = expressions override def nullable: Boolean = false - override def dataType: ArrayType = ArrayType(expressions.head.dataType) + override def dataType: OpenHashSetUDT = new OpenHashSetUDT(expressions.head.dataType) override def toString: String = s"AddToHashSet(${expressions.mkString(",")})" - override def newInstance(): CollectHashSetFunction = new CollectHashSetFunction(expressions, this) + override def newInstance(): CollectHashSetFunction = + new CollectHashSetFunction(expressions, this) } case class CollectHashSetFunction( @@ -250,11 +251,28 @@ case class CombineSetsAndCountFunction( override def eval(input: Row): Any = seen.size.toLong } +/** The data type of ApproxCountDistinctPartition since its output is a HyperLogLog object. */ +private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] { + + override def sqlType: DataType = BinaryType + + /** Since we are using HyperLogLog internally, usually it will not be called. */ + override def serialize(obj: Any): Array[Byte] = + obj.asInstanceOf[HyperLogLog].getBytes + + + /** Since we are using HyperLogLog internally, usually it will not be called. */ + override def deserialize(datum: Any): HyperLogLog = + HyperLogLog.Builder.build(datum.asInstanceOf[Array[Byte]]) + + override def userClass: Class[HyperLogLog] = classOf[HyperLogLog] +} + case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { override def nullable: Boolean = false - override def dataType: DataType = child.dataType + override def dataType: DataType = HyperLogLogUDT override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" override def newInstance(): ApproxCountDistinctPartitionFunction = { new ApproxCountDistinctPartitionFunction(child, this, relativeSD) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 1f6526ef66c56..566b34f7c3a6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -369,6 +369,51 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { override def toString: String = s"MaxOf($left, $right)" } +case class MinOf(left: Expression, right: Expression) extends Expression { + type EvaluatedType = Any + + override def foldable: Boolean = left.foldable && right.foldable + + override def nullable: Boolean = left.nullable && right.nullable + + override def children: Seq[Expression] = left :: right :: Nil + + override lazy val resolved = + left.resolved && right.resolved && + left.dataType == right.dataType + + override def dataType: DataType = { + if (!resolved) { + throw new UnresolvedException(this, + s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + } + left.dataType + } + + lazy val ordering = left.dataType match { + case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]] + case other => sys.error(s"Type $other does not support ordered operations") + } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + val evalE2 = right.eval(input) + if (evalE1 == null) { + evalE2 + } else if (evalE2 == null) { + evalE1 + } else { + if (ordering.compare(evalE1, evalE2) < 0) { + evalE1 + } else { + evalE2 + } + } + } + + override def toString: String = s"MinOf($left, $right)" +} + /** * A function that get the absolute value of the numeric value. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d1abf3c0b64a5..d141354a0f427 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -464,7 +464,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val itemEval = expressionEvaluator(item) val setEval = expressionEvaluator(set) - val ArrayType(elementType, _) = set.dataType + val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType itemEval.code ++ setEval.code ++ q""" @@ -482,7 +482,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val leftEval = expressionEvaluator(left) val rightEval = expressionEvaluator(right) - val ArrayType(elementType, _) = left.dataType + val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType leftEval.code ++ rightEval.code ++ q""" @@ -524,6 +524,30 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } """.children + case MinOf(e1, e2) => + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) + + eval1.code ++ eval2.code ++ + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)} + + if (${eval1.nullTerm}) { + $nullTerm = ${eval2.nullTerm} + $primitiveTerm = ${eval2.primitiveTerm} + } else if (${eval2.nullTerm}) { + $nullTerm = ${eval1.nullTerm} + $primitiveTerm = ${eval1.primitiveTerm} + } else { + if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) { + $primitiveTerm = ${eval1.primitiveTerm} + } else { + $primitiveTerm = ${eval2.primitiveTerm} + } + } + """.children + case UnscaledValue(child) => val childEval = expressionEvaluator(child) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 3b2b9211268a9..fc1f69655963d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.types._ /** @@ -81,6 +83,41 @@ trait GetField extends UnaryExpression { def field: StructField } +object GetField { + /** + * Returns the resolved `GetField`, and report error if no desired field or over one + * desired fields are found. + */ + def apply( + expr: Expression, + fieldName: String, + resolver: Resolver): GetField = { + def findField(fields: Array[StructField]): Int = { + val checkField = (f: StructField) => resolver(f.name, fieldName) + val ordinal = fields.indexWhere(checkField) + if (ordinal == -1) { + throw new AnalysisException( + s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}") + } else if (fields.indexWhere(checkField, ordinal + 1) != -1) { + throw new AnalysisException( + s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}") + } else { + ordinal + } + } + expr.dataType match { + case StructType(fields) => + val ordinal = findField(fields) + StructGetField(expr, fields(ordinal), ordinal) + case ArrayType(StructType(fields), containsNull) => + val ordinal = findField(fields) + ArrayGetField(expr, fields(ordinal), ordinal, containsNull) + case otherType => + throw new AnalysisException(s"GetField is not valid on fields of type $otherType") + } + } +} + /** * Returns the value of fields in the Struct `child`. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 35faa00782e80..4c44182278207 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -20,6 +20,33 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet +/** The data type for expressions returning an OpenHashSet as the result. */ +private[sql] class OpenHashSetUDT( + val elementType: DataType) extends UserDefinedType[OpenHashSet[Any]] { + + override def sqlType: DataType = ArrayType(elementType) + + /** Since we are using OpenHashSet internally, usually it will not be called. */ + override def serialize(obj: Any): Seq[Any] = { + obj.asInstanceOf[OpenHashSet[Any]].iterator.toSeq + } + + /** Since we are using OpenHashSet internally, usually it will not be called. */ + override def deserialize(datum: Any): OpenHashSet[Any] = { + val iterator = datum.asInstanceOf[Seq[Any]].iterator + val set = new OpenHashSet[Any] + while(iterator.hasNext) { + set.add(iterator.next()) + } + + set + } + + override def userClass: Class[OpenHashSet[Any]] = classOf[OpenHashSet[Any]] + + private[spark] override def asNullable: OpenHashSetUDT = this +} + /** * Creates a new set of the specified type */ @@ -28,9 +55,7 @@ case class NewSet(elementType: DataType) extends LeafExpression { override def nullable: Boolean = false - // We are currently only using these Expressions internally for aggregation. However, if we ever - // expose these to users we'll want to create a proper type instead of hijacking ArrayType. - override def dataType: DataType = ArrayType(elementType) + override def dataType: OpenHashSetUDT = new OpenHashSetUDT(elementType) override def eval(input: Row): Any = { new OpenHashSet[Any]() @@ -50,7 +75,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { override def nullable: Boolean = set.nullable - override def dataType: DataType = set.dataType + override def dataType: OpenHashSetUDT = set.dataType.asInstanceOf[OpenHashSetUDT] override def eval(input: Row): Any = { val itemEval = item.eval(input) @@ -80,7 +105,7 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres override def nullable: Boolean = left.nullable || right.nullable - override def dataType: DataType = left.dataType + override def dataType: OpenHashSetUDT = left.dataType.asInstanceOf[OpenHashSetUDT] override def symbol: String = "++=" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 02f7c26a8ab6e..7967189cacb24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -150,7 +150,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy }.toSeq } - def schema: StructType = StructType.fromAttributes(output) + lazy val schema: StructType = StructType.fromAttributes(output) /** Returns the output schema in the tree format. */ def schemaString: String = schema.treeString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index bb79dc340553b..e3e070f0ff307 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, analysis} import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.types.{DataTypeConversions, StructType, StructField} +import org.apache.spark.sql.types.{StructType, StructField} object LocalRelation { def apply(output: Attribute*): LocalRelation = new LocalRelation(output) @@ -31,7 +31,8 @@ object LocalRelation { def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = { val schema = StructType.fromAttributes(output) - LocalRelation(output, data.map(row => DataTypeConversions.productToRow(row, schema))) + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + LocalRelation(output, data.map(converter(_).asInstanceOf[Row])) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 2e9f3aa4ec4ad..579a0fb8d3f93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -205,13 +205,12 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // One match, but we also need to extract the requested nested field. case Seq((a, nestedFields)) => try { - - // The foldLeft adds UnresolvedGetField for every remaining parts of the name, - // and aliased it with the last part of the name. - // For example, consider name "a.b.c", where "a" is resolved to an existing attribute. - // Then this will add UnresolvedGetField("b") and UnresolvedGetField("c"), and alias + // The foldLeft adds GetFields for every remaining parts of the identifier, + // and aliases it with the last part of the identifier. + // For example, consider "a.b.c", where "a" is resolved to an existing attribute. + // Then this will add GetField("c", GetField("b", a)), and alias // the final expression as "c". - val fieldExprs = nestedFields.foldLeft(a: Expression)(resolveGetField(_, _, resolver)) + val fieldExprs = nestedFields.foldLeft(a: Expression)(GetField(_, _, resolver)) val aliasName = nestedFields.last Some(Alias(fieldExprs, aliasName)()) } catch { @@ -230,41 +229,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { s"Reference '$name' is ambiguous, could be: $referenceNames.") } } - - /** - * Returns the resolved `GetField`, and report error if no desired field or over one - * desired fields are found. - * - * TODO: this code is duplicated from Analyzer and should be refactored to avoid this. - */ - protected def resolveGetField( - expr: Expression, - fieldName: String, - resolver: Resolver): Expression = { - def findField(fields: Array[StructField]): Int = { - val checkField = (f: StructField) => resolver(f.name, fieldName) - val ordinal = fields.indexWhere(checkField) - if (ordinal == -1) { - throw new AnalysisException( - s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}") - } else if (fields.indexWhere(checkField, ordinal + 1) != -1) { - throw new AnalysisException( - s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}") - } else { - ordinal - } - } - expr.dataType match { - case StructType(fields) => - val ordinal = findField(fields) - StructGetField(expr, fields(ordinal), ordinal) - case ArrayType(StructType(fields), containsNull) => - val ordinal = findField(fields) - ArrayGetField(expr, fields(ordinal), ordinal, containsNull) - case otherType => - throw new AnalysisException(s"GetField is not valid on fields of type $otherType") - } - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 8633e06093cf3..17522976dc2c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -125,12 +125,14 @@ case class InsertIntoTable( table: LogicalPlan, partition: Map[String, Option[String]], child: LogicalPlan, - overwrite: Boolean) + overwrite: Boolean, + ifNotExists: Boolean) extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil override def output: Seq[Attribute] = child.output + assert(overwrite || !ifNotExists) override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall { case (childAttr, tableAttr) => DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType) @@ -147,6 +149,18 @@ case class CreateTableAsSelect[T]( override lazy val resolved: Boolean = databaseName != None && childrenResolved } +/** + * A container for holding named common table expressions (CTEs) and a query plan. + * This operator will be removed during analysis and the relations will be substituted into child. + * @param child The final query of this CTE. + * @param cteRelations Queries that this CTE defined, + * key is the alias of the CTE definition, + * value is the CTE definition. + */ +case class With(child: LogicalPlan, cteRelations: Map[String, Subquery]) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + case class WriteToFile( path: String, child: LogicalPlan) extends UnaryNode { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala deleted file mode 100644 index a9d63e784963d..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.types - -import java.text.SimpleDateFormat - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow - - -private[sql] object DataTypeConversions { - - def productToRow(product: Product, schema: StructType): Row = { - val mutableRow = new GenericMutableRow(product.productArity) - val schemaFields = schema.fields.toArray - - var i = 0 - while (i < mutableRow.length) { - mutableRow(i) = - ScalaReflection.convertToCatalyst(product.productElement(i), schemaFields(i).dataType) - i += 1 - } - - mutableRow - } - - def stringToTime(s: String): java.util.Date = { - if (!s.contains('T')) { - // JDBC escape string - if (s.contains(' ')) { - java.sql.Timestamp.valueOf(s) - } else { - java.sql.Date.valueOf(s) - } - } else if (s.endsWith("Z")) { - // this is zero timezone of ISO8601 - stringToTime(s.substring(0, s.length - 1) + "GMT-00:00") - } else if (s.indexOf("GMT") == -1) { - // timezone with ISO8601 - val inset = "+00.00".length - val s0 = s.substring(0, s.length - inset) - val s1 = s.substring(s.length - inset, s.length) - if (s0.substring(s0.lastIndexOf(':')).contains('.')) { - stringToTime(s0 + "GMT" + s1) - } else { - stringToTime(s0 + ".0GMT" + s1) - } - } else { - // ISO8601 with GMT insert - val ISO8601GMT: SimpleDateFormat = new SimpleDateFormat( "yyyy-MM-dd'T'HH:mm:ss.SSSz" ) - ISO8601GMT.parse(s) - } - } - - /** Converts Java objects to catalyst rows / types */ - def convertJavaToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { - case (obj, udt: UserDefinedType[_]) => ScalaReflection.convertToCatalyst(obj, udt) // Scala type - case (d: java.math.BigDecimal, _) => Decimal(d) - case (other, _) => other - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala index 34270d0ca7cd7..5163f05879e42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala @@ -40,7 +40,7 @@ private[sql] trait DataTypeParser extends StandardTokenParsers { protected lazy val primitiveType: Parser[DataType] = "(?i)string".r ^^^ StringType | "(?i)float".r ^^^ FloatType | - "(?i)int".r ^^^ IntegerType | + "(?i)(?:int|integer)".r ^^^ IntegerType | "(?i)tinyint".r ^^^ ByteType | "(?i)smallint".r ^^^ ShortType | "(?i)double".r ^^^ DoubleType | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala index 8a1a3b81b3d2c..504fb05842505 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.types import java.sql.Date +import java.text.SimpleDateFormat import java.util.{Calendar, TimeZone} import org.apache.spark.sql.catalyst.expressions.Cast @@ -57,4 +58,32 @@ object DateUtils { } def toString(days: Int): String = Cast.threadLocalDateFormat.get.format(toJavaDate(days)) + + def stringToTime(s: String): java.util.Date = { + if (!s.contains('T')) { + // JDBC escape string + if (s.contains(' ')) { + java.sql.Timestamp.valueOf(s) + } else { + java.sql.Date.valueOf(s) + } + } else if (s.endsWith("Z")) { + // this is zero timezone of ISO8601 + stringToTime(s.substring(0, s.length - 1) + "GMT-00:00") + } else if (s.indexOf("GMT") == -1) { + // timezone with ISO8601 + val inset = "+00.00".length + val s0 = s.substring(0, s.length - inset) + val s1 = s.substring(s.length - inset, s.length) + if (s0.substring(s0.lastIndexOf(':')).contains('.')) { + stringToTime(s0 + "GMT" + s1) + } else { + stringToTime(s0 + ".0GMT" + s1) + } + } else { + // ISO8601 with GMT insert + val ISO8601GMT: SimpleDateFormat = new SimpleDateFormat( "yyyy-MM-dd'T'HH:mm:ss.SSSz" ) + ISO8601GMT.parse(s) + } + } } diff --git a/sql/catalyst/src/test/resources/log4j.properties b/sql/catalyst/src/test/resources/log4j.properties index 287c8e3563503..eb3b1999eb996 100644 --- a/sql/catalyst/src/test/resources/log4j.properties +++ b/sql/catalyst/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index 46b2250aab231..ea82cd2622de9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -30,7 +30,7 @@ class DistributionSuite extends FunSuite { inputPartitioning: Partitioning, requiredDistribution: Distribution, satisfied: Boolean) { - if (inputPartitioning.satisfies(requiredDistribution) != satisfied) + if (inputPartitioning.satisfies(requiredDistribution) != satisfied) { fail( s""" |== Input Partitioning == @@ -40,6 +40,7 @@ class DistributionSuite extends FunSuite { |== Does input partitioning satisfy required distribution? == |Expected $satisfied got ${inputPartitioning.satisfies(requiredDistribution)} """.stripMargin) + } } test("HashPartitioning is the output partitioning") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index eee00e3f7ea76..bbc0b661a0c0c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -260,7 +260,7 @@ class ScalaReflectionSuite extends FunSuite { val data = PrimitiveData(1, 1, 1, 1, 1, 1, true) val convertedData = Row(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) val dataType = schemaFor[PrimitiveData].dataType - assert(convertToCatalyst(data, dataType) === convertedData) + assert(CatalystTypeConverters.convertToCatalyst(data, dataType) === convertedData) } test("convert Option[Product] to catalyst") { @@ -270,7 +270,7 @@ class ScalaReflectionSuite extends FunSuite { val dataType = schemaFor[OptionalData].dataType val convertedData = Row(2, 2.toLong, 2.toDouble, 2.toFloat, 2.toShort, 2.toByte, true, Row(1, 1, 1, 1, 1, 1, true)) - assert(convertToCatalyst(data, dataType) === convertedData) + assert(CatalystTypeConverters.convertToCatalyst(data, dataType) === convertedData) } test("infer schema from case class with multiple constructors") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index ee7b14c7a157c..6e3d6b9263e86 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import scala.collection.immutable + class AnalysisSuite extends FunSuite with BeforeAndAfter { val caseSensitiveCatalog = new SimpleCatalog(true) val caseInsensitiveCatalog = new SimpleCatalog(false) @@ -41,10 +43,10 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { } - def caseSensitiveAnalyze(plan: LogicalPlan) = + def caseSensitiveAnalyze(plan: LogicalPlan): Unit = caseSensitiveAnalyzer.checkAnalysis(caseSensitiveAnalyzer(plan)) - def caseInsensitiveAnalyze(plan: LogicalPlan) = + def caseInsensitiveAnalyze(plan: LogicalPlan): Unit = caseInsensitiveAnalyzer.checkAnalysis(caseInsensitiveAnalyzer(plan)) val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) @@ -147,7 +149,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { name: String, plan: LogicalPlan, errorMessages: Seq[String], - caseSensitive: Boolean = true) = { + caseSensitive: Boolean = true): Unit = { test(name) { val error = intercept[AnalysisException] { if(caseSensitive) { @@ -202,7 +204,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { case class UnresolvedTestPlan() extends LeafNode { override lazy val resolved = false - override def output = Nil + override def output: Seq[Attribute] = Nil } errorTest( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 70aef1cac421a..fcd745f43cfbf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -96,7 +96,9 @@ class HiveTypeCoercionSuite extends PlanTest { widenTest(StringType, TimestampType, None) // ComplexType - widenTest(NullType, MapType(IntegerType, StringType, false), Some(MapType(IntegerType, StringType, false))) + widenTest(NullType, + MapType(IntegerType, StringType, false), + Some(MapType(IntegerType, StringType, false))) widenTest(NullType, StructType(Seq()), Some(StructType(Seq()))) widenTest(StringType, MapType(IntegerType, StringType, true), None) widenTest(ArrayType(IntegerType), StructType(Seq()), None) @@ -113,7 +115,9 @@ class HiveTypeCoercionSuite extends PlanTest { // Remove superflous boolean -> boolean casts. ruleTest(Cast(Literal(true), BooleanType), Literal(true)) // Stringify boolean when casting to string. - ruleTest(Cast(Literal(false), StringType), If(Literal(false), Literal("true"), Literal("false"))) + ruleTest( + Cast(Literal(false), StringType), + If(Literal(false), Literal("true"), Literal("false"))) } test("coalesce casts") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 3dbefa40d2808..d4362a91d992c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -82,10 +82,13 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { assert(BitwiseNot(1.toByte).eval(EmptyRow).isInstanceOf[Byte]) } + // scalastyle:off /** * Checks for three-valued-logic. Based on: * http://en.wikipedia.org/wiki/Null_(SQL)#Comparisons_with_NULL_and_the_three-valued_logic_.283VL.29 - * I.e. in flat cpo "False -> Unknown -> True", OR is lowest upper bound, AND is greatest lower bound. + * I.e. in flat cpo "False -> Unknown -> True", + * OR is lowest upper bound, + * AND is greatest lower bound. * p q p OR q p AND q p = q * True True True True True * True False True False False @@ -102,7 +105,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { * False True * Unknown Unknown */ - + // scalastyle:on val notTrueTable = (true, false) :: (false, true) :: @@ -165,7 +168,9 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) - checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true) + checkEvaluation( + In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), + true) } test("Divide") { @@ -180,7 +185,8 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(0)), null) checkEvaluation(Divide(Literal.create(null, DoubleType), Literal(0.0)), null) checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(1)), null) - checkEvaluation(Divide(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) + checkEvaluation(Divide(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), + null) } test("Remainder") { @@ -195,7 +201,8 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(0)), null) checkEvaluation(Remainder(Literal.create(null, DoubleType), Literal(0.0)), null) checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(1)), null) - checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) + checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), + null) } test("INSET") { @@ -226,6 +233,16 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(MaxOf(2, Literal.create(null, IntegerType)), 2) } + test("MinOf") { + checkEvaluation(MinOf(1, 2), 1) + checkEvaluation(MinOf(2, 1), 1) + checkEvaluation(MinOf(1L, 2L), 1L) + checkEvaluation(MinOf(2L, 1L), 1L) + + checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1) + checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1) + } + test("LIKE literal Regular Expression") { checkEvaluation(Literal.create(null, StringType).like("a"), null) checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) @@ -264,7 +281,8 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation("ab" like regEx, true, new GenericRow(Array[Any]("a%b"))) checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a%b"))) - checkEvaluation(Literal.create(null, StringType) like regEx, null, new GenericRow(Array[Any]("bc%"))) + checkEvaluation(Literal.create(null, StringType) like regEx, null, + new GenericRow(Array[Any]("bc%"))) } test("RLIKE literal Regular Expression") { @@ -507,8 +525,10 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("array casting") { - val array = Literal.create(Seq("123", "abc", "", null), ArrayType(StringType, containsNull = true)) - val array_notNull = Literal.create(Seq("123", "abc", ""), ArrayType(StringType, containsNull = false)) + val array = Literal.create(Seq("123", "abc", "", null), + ArrayType(StringType, containsNull = true)) + val array_notNull = Literal.create(Seq("123", "abc", ""), + ArrayType(StringType, containsNull = false)) { val cast = Cast(array, ArrayType(IntegerType, containsNull = true)) @@ -765,7 +785,8 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(Coalesce(Literal.create(null, StringType) :: Nil), null, row) checkEvaluation(Coalesce(Literal.create(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row) - checkEvaluation(If(c3, Literal.create("a", StringType), Literal.create("b", StringType)), "a", row) + checkEvaluation( + If(c3, Literal.create("a", StringType), Literal.create("b", StringType)), "a", row) checkEvaluation(If(c3, c1, c2), "^Ba*n", row) checkEvaluation(If(c4, c2, c1), "^Ba*n", row) checkEvaluation(If(Literal.create(null, BooleanType), c2, c1), "^Ba*n", row) @@ -842,18 +863,20 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(GetItem(BoundReference(3, typeMap, true), Literal("aa")), "bb", row) checkEvaluation(GetItem(Literal.create(null, typeMap), Literal("aa")), null, row) - checkEvaluation(GetItem(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row) + checkEvaluation( + GetItem(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row) checkEvaluation(GetItem(BoundReference(3, typeMap, true), Literal.create(null, StringType)), null, row) checkEvaluation(GetItem(BoundReference(4, typeArray, true), Literal(1)), "bb", row) checkEvaluation(GetItem(Literal.create(null, typeArray), Literal(1)), null, row) - checkEvaluation(GetItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row) + checkEvaluation( + GetItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row) checkEvaluation(GetItem(BoundReference(4, typeArray, true), Literal.create(null, IntegerType)), null, row) - def quickBuildGetField(expr: Expression, fieldName: String) = { + def quickBuildGetField(expr: Expression, fieldName: String): StructGetField = { expr.dataType match { case StructType(fields) => val field = fields.find(_.name == fieldName).get @@ -861,7 +884,9 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } } - def quickResolve(u: UnresolvedGetField) = quickBuildGetField(u.child, u.fieldName) + def quickResolve(u: UnresolvedGetField): StructGetField = { + quickBuildGetField(u.child, u.fieldName) + } checkEvaluation(quickBuildGetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row) checkEvaluation(quickBuildGetField(Literal.create(null, typeS), "a"), null, row) @@ -872,7 +897,8 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { ) assert(quickBuildGetField(BoundReference(2,typeS, nullable = true), "a").nullable === true) - assert(quickBuildGetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false) + assert(quickBuildGetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable + === false) assert(quickBuildGetField(Literal.create(null, typeS), "a").nullable === true) assert(quickBuildGetField(Literal.create(null, typeS_notNullable), "a").nullable === true) @@ -896,7 +922,8 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(Add(c1, c2), 3, row) checkEvaluation(Add(c1, Literal.create(null, IntegerType)), null, row) checkEvaluation(Add(Literal.create(null, IntegerType), c2), null, row) - checkEvaluation(Add(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) + checkEvaluation( + Add(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(-c1, -1, row) checkEvaluation(c1 + c2, 3, row) @@ -919,7 +946,8 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(Add(c1, c2), 3.1, row) checkEvaluation(Add(c1, Literal.create(null, DoubleType)), null, row) checkEvaluation(Add(Literal.create(null, DoubleType), c2), null, row) - checkEvaluation(Add(Literal.create(null, DoubleType), Literal.create(null, DoubleType)), null, row) + checkEvaluation( + Add(Literal.create(null, DoubleType), Literal.create(null, DoubleType)), null, row) checkEvaluation(-c1, -1.1, row) checkEvaluation(c1 + c2, 3.1, row) @@ -942,7 +970,8 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(LessThan(c1, c2), true, row) checkEvaluation(LessThan(c1, Literal.create(null, IntegerType)), null, row) checkEvaluation(LessThan(Literal.create(null, IntegerType), c2), null, row) - checkEvaluation(LessThan(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) + checkEvaluation( + LessThan(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(c1 < c2, true, row) checkEvaluation(c1 <= c2, true, row) @@ -985,54 +1014,84 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { val s = 'a.string.at(0) // substring from zero position with less-than-full length - checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)), "ex", row) - checkEvaluation(Substring(s, Literal.create(1, IntegerType), Literal.create(2, IntegerType)), "ex", row) + checkEvaluation( + Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)), "ex", row) + checkEvaluation( + Substring(s, Literal.create(1, IntegerType), Literal.create(2, IntegerType)), "ex", row) // substring from zero position with full length - checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(7, IntegerType)), "example", row) - checkEvaluation(Substring(s, Literal.create(1, IntegerType), Literal.create(7, IntegerType)), "example", row) + checkEvaluation( + Substring(s, Literal.create(0, IntegerType), Literal.create(7, IntegerType)), "example", row) + checkEvaluation( + Substring(s, Literal.create(1, IntegerType), Literal.create(7, IntegerType)), "example", row) // substring from zero position with greater-than-full length - checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(100, IntegerType)), "example", row) - checkEvaluation(Substring(s, Literal.create(1, IntegerType), Literal.create(100, IntegerType)), "example", row) + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(100, IntegerType)), + "example", row) + checkEvaluation(Substring(s, Literal.create(1, IntegerType), Literal.create(100, IntegerType)), + "example", row) // substring from nonzero position with less-than-full length - checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(2, IntegerType)), "xa", row) + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(2, IntegerType)), + "xa", row) // substring from nonzero position with full length - checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(6, IntegerType)), "xample", row) + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(6, IntegerType)), + "xample", row) // substring from nonzero position with greater-than-full length - checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(100, IntegerType)), "xample", row) + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(100, IntegerType)), + "xample", row) // zero-length substring (within string bounds) - checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(0, IntegerType)), "", row) + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(0, IntegerType)), + "", row) // zero-length substring (beyond string bounds) - checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), "", row) + checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), + "", row) // substring(null, _, _) -> null - checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), null, new GenericRow(Array[Any](null))) + checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), + null, new GenericRow(Array[Any](null))) // substring(_, null, _) -> null - checkEvaluation(Substring(s, Literal.create(null, IntegerType), Literal.create(4, IntegerType)), null, row) + checkEvaluation(Substring(s, Literal.create(null, IntegerType), Literal.create(4, IntegerType)), + null, row) // substring(_, _, null) -> null - checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(null, IntegerType)), null, row) + checkEvaluation( + Substring(s, Literal.create(100, IntegerType), Literal.create(null, IntegerType)), + null, + row) // 2-arg substring from zero position - checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), "example", row) - checkEvaluation(Substring(s, Literal.create(1, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), "example", row) + checkEvaluation( + Substring(s, Literal.create(0, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), + "example", + row) + checkEvaluation( + Substring(s, Literal.create(1, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), + "example", + row) // 2-arg substring from nonzero position - checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), "xample", row) + checkEvaluation( + Substring(s, Literal.create(2, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), + "xample", + row) val s_notNull = 'a.string.notNull.at(0) - assert(Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable === true) - assert(Substring(s_notNull, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable === false) - assert(Substring(s_notNull, Literal.create(null, IntegerType), Literal.create(2, IntegerType)).nullable === true) - assert(Substring(s_notNull, Literal.create(0, IntegerType), Literal.create(null, IntegerType)).nullable === true) + assert(Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable + === true) + assert( + Substring(s_notNull, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable + === false) + assert(Substring(s_notNull, + Literal.create(null, IntegerType), Literal.create(2, IntegerType)).nullable === true) + assert(Substring(s_notNull, + Literal.create(0, IntegerType), Literal.create(null, IntegerType)).nullable === true) checkEvaluation(s.substr(0, 2), "ex", row) checkEvaluation(s.substr(0), "example", row) @@ -1065,17 +1124,20 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(BitwiseAnd(c1, c4), null, row) checkEvaluation(BitwiseAnd(c1, c2), 0, row) checkEvaluation(BitwiseAnd(c1, Literal.create(null, IntegerType)), null, row) - checkEvaluation(BitwiseAnd(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) + checkEvaluation( + BitwiseAnd(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(BitwiseOr(c1, c4), null, row) checkEvaluation(BitwiseOr(c1, c2), 3, row) checkEvaluation(BitwiseOr(c1, Literal.create(null, IntegerType)), null, row) - checkEvaluation(BitwiseOr(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) + checkEvaluation( + BitwiseOr(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(BitwiseXor(c1, c4), null, row) checkEvaluation(BitwiseXor(c1, c2), 3, row) checkEvaluation(BitwiseXor(c1, Literal.create(null, IntegerType)), null, row) - checkEvaluation(BitwiseXor(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) + checkEvaluation( + BitwiseXor(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(BitwiseNot(c4), null, row) checkEvaluation(BitwiseNot(c1), -2, row) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index a0efe9e2e7f6b..4396bd0dda9a9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -176,40 +176,39 @@ class ConstantFoldingSuite extends PlanTest { } test("Constant folding test: expressions have null literals") { - val originalQuery = - testRelation - .select( - IsNull(Literal(null)) as 'c1, - IsNotNull(Literal(null)) as 'c2, + val originalQuery = testRelation.select( + IsNull(Literal(null)) as 'c1, + IsNotNull(Literal(null)) as 'c2, - GetItem(Literal.create(null, ArrayType(IntegerType)), 1) as 'c3, - GetItem(Literal.create(Seq(1), ArrayType(IntegerType)), Literal.create(null, IntegerType)) as 'c4, - UnresolvedGetField( - Literal.create(null, StructType(Seq(StructField("a", IntegerType, true)))), - "a") as 'c5, + GetItem(Literal.create(null, ArrayType(IntegerType)), 1) as 'c3, + GetItem( + Literal.create(Seq(1), ArrayType(IntegerType)), Literal.create(null, IntegerType)) as 'c4, + UnresolvedGetField( + Literal.create(null, StructType(Seq(StructField("a", IntegerType, true)))), + "a") as 'c5, - UnaryMinus(Literal.create(null, IntegerType)) as 'c6, - Cast(Literal(null), IntegerType) as 'c7, - Not(Literal.create(null, BooleanType)) as 'c8, + UnaryMinus(Literal.create(null, IntegerType)) as 'c6, + Cast(Literal(null), IntegerType) as 'c7, + Not(Literal.create(null, BooleanType)) as 'c8, - Add(Literal.create(null, IntegerType), 1) as 'c9, - Add(1, Literal.create(null, IntegerType)) as 'c10, + Add(Literal.create(null, IntegerType), 1) as 'c9, + Add(1, Literal.create(null, IntegerType)) as 'c10, - EqualTo(Literal.create(null, IntegerType), 1) as 'c11, - EqualTo(1, Literal.create(null, IntegerType)) as 'c12, + EqualTo(Literal.create(null, IntegerType), 1) as 'c11, + EqualTo(1, Literal.create(null, IntegerType)) as 'c12, - Like(Literal.create(null, StringType), "abc") as 'c13, - Like("abc", Literal.create(null, StringType)) as 'c14, + Like(Literal.create(null, StringType), "abc") as 'c13, + Like("abc", Literal.create(null, StringType)) as 'c14, - Upper(Literal.create(null, StringType)) as 'c15, + Upper(Literal.create(null, StringType)) as 'c15, - Substring(Literal.create(null, StringType), 0, 1) as 'c16, - Substring("abc", Literal.create(null, IntegerType), 1) as 'c17, - Substring("abc", 0, Literal.create(null, IntegerType)) as 'c18, + Substring(Literal.create(null, StringType), 0, 1) as 'c16, + Substring("abc", Literal.create(null, IntegerType), 1) as 'c17, + Substring("abc", 0, Literal.create(null, IntegerType)) as 'c18, - Contains(Literal.create(null, StringType), "abc") as 'c19, - Contains("abc", Literal.create(null, StringType)) as 'c20 - ) + Contains(Literal.create(null, StringType), "abc") as 'c19, + Contains("abc", Literal.create(null, StringType)) as 'c20 + ) val optimized = Optimize(originalQuery.analyze) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 55c6766520a1e..1448098c770aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -432,7 +432,8 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { z.join(x.join(y)) - .where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1) && ("z.a".attr >= 3) && ("z.a".attr === "x.b".attr)) + .where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1) && + ("z.a".attr >= 3) && ("z.a".attr === "x.b".attr)) } val optimized = Optimize(originalQuery.analyze) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 233e329cb2038..966bc9ada1e6e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -52,7 +52,7 @@ class OptimizeInSuite extends PlanTest { val optimized = Optimize(originalQuery.analyze) val correctAnswer = testRelation - .where(InSet(UnresolvedAttribute("a"), HashSet[Any]()+1+2)) + .where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2)) .analyze comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 129d091ca03e3..e7cafcc96de87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -45,12 +45,13 @@ class PlanTest extends FunSuite { protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { val normalized1 = normalizeExprIds(plan1) val normalized2 = normalizeExprIds(plan2) - if (normalized1 != normalized2) + if (normalized1 != normalized2) { fail( s""" |== FAIL: Plans do not match === |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} - """.stripMargin) + """.stripMargin) + } } /** Fails the test if the two expressions do not match */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala index 11e6831b24768..1273921f6394c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala @@ -32,7 +32,7 @@ class SameResultSuite extends FunSuite { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int) - def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true) = { + def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = { val aAnalyzed = a.analyze val bAnalyzed = b.analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 274f3ede0045c..4eb8708335dcf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -25,12 +25,12 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{StringType, NullType} case class Dummy(optKey: Option[Expression]) extends Expression { - def children = optKey.toSeq - def nullable = true - def dataType = NullType + def children: Seq[Expression] = optKey.toSeq + def nullable: Boolean = true + def dataType: NullType = NullType override lazy val resolved = true type EvaluatedType = Any - def eval(input: Row) = null.asInstanceOf[Any] + def eval(input: Row): Any = null.asInstanceOf[Any] } class TreeNodeSuite extends FunSuite { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala index 1ba21b64603ac..169125264a803 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala @@ -34,10 +34,12 @@ class DataTypeParserSuite extends FunSuite { } checkDataType("int", IntegerType) + checkDataType("integer", IntegerType) checkDataType("BooLean", BooleanType) checkDataType("tinYint", ByteType) checkDataType("smallINT", ShortType) checkDataType("INT", IntegerType) + checkDataType("INTEGER", IntegerType) checkDataType("bigint", LongType) checkDataType("float", FloatType) checkDataType("dOUBle", DoubleType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 5c6016a4a2ce2..94ae2d65fd0e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -33,7 +33,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.sql.catalyst.{ScalaReflection, SqlParser} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} @@ -713,7 +713,7 @@ class DataFrame private[sql]( val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributes = schema.toAttributes val rowFunction = - f.andThen(_.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row])) + f.andThen(_.map(CatalystTypeConverters.convertToCatalyst(_, schema).asInstanceOf[Row])) val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr)) Generate(generator, join = true, outer = false, None, logicalPlan) @@ -734,7 +734,7 @@ class DataFrame private[sql]( val dataType = ScalaReflection.schemaFor[B].dataType val attributes = AttributeReference(outputColumn, dataType)() :: Nil def rowFunction(row: Row): TraversableOnce[Row] = { - f(row(0).asInstanceOf[A]).map(o => Row(ScalaReflection.convertToCatalyst(o, dataType))) + f(row(0).asInstanceOf[A]).map(o => Row(CatalystTypeConverters.convertToCatalyst(o, dataType))) } val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil) @@ -961,7 +961,10 @@ class DataFrame private[sql]( lazy val rdd: RDD[Row] = { // use a local variable to make sure the map closure doesn't capture the whole DataFrame val schema = this.schema - queryExecution.executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema)) + queryExecution.executedPlan.execute().mapPartitions { rows => + val converter = CatalystTypeConverters.createToScalaConverter(schema) + rows.map(converter(_).asInstanceOf[Row]) + } } /** @@ -1206,7 +1209,7 @@ class DataFrame private[sql]( @Experimental def insertInto(tableName: String, overwrite: Boolean): Unit = { sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)), - Map.empty, logicalPlan, overwrite)).toRdd + Map.empty, logicalPlan, overwrite, ifNotExists = false)).toRdd } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index bf3c3fe876873..481ed4924857e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -192,6 +192,127 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { */ def fill(valueMap: Map[String, Any]): DataFrame = fill0(valueMap.toSeq) + /** + * Replaces values matching keys in `replacement` map with the corresponding values. + * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * If `col` is "*", then the replacement is applied on all string columns or numeric columns. + * + * {{{ + * import com.google.common.collect.ImmutableMap; + * + * // Replaces all occurrences of 1.0 with 2.0 in column "height". + * df.replace("height", ImmutableMap.of(1.0, 2.0)); + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name". + * df.replace("name", ImmutableMap.of("UNKNOWN", "unnamed")); + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns. + * df.replace("*", ImmutableMap.of("UNKNOWN", "unnamed")); + * }}} + * + * @param col name of the column to apply the value replacement + * @param replacement value replacement map, as explained above + */ + def replace[T](col: String, replacement: java.util.Map[T, T]): DataFrame = { + replace[T](col, replacement.toMap : Map[T, T]) + } + + /** + * Replaces values matching keys in `replacement` map with the corresponding values. + * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * + * {{{ + * import com.google.common.collect.ImmutableMap; + * + * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". + * df.replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0, 2.0)); + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname". + * df.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed")); + * }}} + * + * @param cols list of columns to apply the value replacement + * @param replacement value replacement map, as explained above + */ + def replace[T](cols: Array[String], replacement: java.util.Map[T, T]): DataFrame = { + replace(cols.toSeq, replacement.toMap) + } + + /** + * (Scala-specific) Replaces values matching keys in `replacement` map. + * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * If `col` is "*", then the replacement is applied on all string columns or numeric columns. + * + * {{{ + * // Replaces all occurrences of 1.0 with 2.0 in column "height". + * df.replace("height", Map(1.0 -> 2.0)) + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name". + * df.replace("name", Map("UNKNOWN" -> "unnamed") + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns. + * df.replace("*", Map("UNKNOWN" -> "unnamed") + * }}} + * + * @param col name of the column to apply the value replacement + * @param replacement value replacement map, as explained above + */ + def replace[T](col: String, replacement: Map[T, T]): DataFrame = { + if (col == "*") { + replace0(df.columns, replacement) + } else { + replace0(Seq(col), replacement) + } + } + + /** + * (Scala-specific) Replaces values matching keys in `replacement` map. + * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * + * {{{ + * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". + * df.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0)); + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname". + * df.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed"); + * }}} + * + * @param cols list of columns to apply the value replacement + * @param replacement value replacement map, as explained above + */ + def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = replace0(cols, replacement) + + private def replace0[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = { + if (replacement.isEmpty || cols.isEmpty) { + return df + } + + // replacementMap is either Map[String, String] or Map[Double, Double] + val replacementMap: Map[_, _] = replacement.head._2 match { + case v: String => replacement + case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) } + } + + // targetColumnType is either DoubleType or StringType + val targetColumnType = replacement.head._1 match { + case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long => DoubleType + case _: String => StringType + } + + val columnEquals = df.sqlContext.analyzer.resolver + val projections = df.schema.fields.map { f => + val shouldReplace = cols.exists(colName => columnEquals(colName, f.name)) + if (f.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType && shouldReplace) { + replaceCol(f, replacementMap) + } else if (f.dataType == targetColumnType && shouldReplace) { + replaceCol(f, replacementMap) + } else { + df.col(f.name) + } + } + df.select(projections : _*) + } + private def fill0(values: Seq[(String, Any)]): DataFrame = { // Error handling values.foreach { case (colName, replaceValue) => @@ -228,4 +349,27 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { private def fillCol[T](col: StructField, replacement: T): Column = { coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name) } + + /** + * Returns a [[Column]] expression that replaces value matching key in `replacementMap` with + * value in `replacementMap`, using [[CaseWhen]]. + * + * TODO: This can be optimized to use broadcast join when replacementMap is large. + */ + private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = { + val branches: Seq[Expression] = replacementMap.flatMap { case (source, target) => + df.col(col.name).equalTo(lit(source).cast(col.dataType)).expr :: + lit(target).cast(col.dataType).expr :: Nil + }.toSeq + new Column(CaseWhen(branches ++ Seq(df.col(col.name).expr))).as(col.name) + } + + private def convertToDouble(v: Any): Double = v match { + case v: Float => v.toDouble + case v: Double => v + case v: Long => v.toDouble + case v: Int => v.toDouble + case v => throw new IllegalArgumentException( + s"Unsupported value type ${v.getClass.getName} ($v).") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index a5e6b638d2150..53ad67372e024 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.NumericType @Experimental class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) { - private[this] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { + private[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { val namedGroupingExprs = groupingExprs.map { case expr: NamedExpression => expr case expr: Expression => Alias(expr, expr.prettyString)() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 4815620c6fe57..ee641bdfeb2d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -39,6 +39,8 @@ private[spark] object SQLConf { val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.parquet.filterPushdown" val PARQUET_USE_DATA_SOURCE_API = "spark.sql.parquet.useDataSourceApi" + val HIVE_VERIFY_PARTITIONPATH = "spark.sql.hive.verifyPartitionPath" + val COLUMN_NAME_OF_CORRUPT_RECORD = "spark.sql.columnNameOfCorruptRecord" val BROADCAST_TIMEOUT = "spark.sql.broadcastTimeout" @@ -119,6 +121,10 @@ private[sql] class SQLConf extends Serializable { private[spark] def parquetUseDataSourceApi = getConf(PARQUET_USE_DATA_SOURCE_API, "true").toBoolean + /** When true uses verifyPartitionPath to prune the path which is not exists. */ + private[spark] def verifyPartitionPath = + getConf(HIVE_VERIFY_PARTITIONPATH, "true").toBoolean + /** When true the planner will use the external sort, which may spill to disk. */ private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, "false").toBoolean diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 39dd14e796f06..c25ef58e6f62a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -31,9 +31,9 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, OneRowRelation} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.{ScalaReflection, expressions} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, expressions} import org.apache.spark.sql.execution.{Filter, _} import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} import org.apache.spark.sql.json._ @@ -404,7 +404,8 @@ class SQLContext(@transient val sparkContext: SparkContext) // TODO: use MutableProjection when rowRDD is another DataFrame and the applied // schema differs from the existing schema on any field data type. val catalystRows = if (needsConversion) { - rowRDD.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row]) + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + rowRDD.map(converter(_).asInstanceOf[Row]) } else { rowRDD } @@ -459,7 +460,7 @@ class SQLContext(@transient val sparkContext: SparkContext) iter.map { row => new GenericRow( extractors.zip(attributeSeq).map { case (e, attr) => - DataTypeConversions.convertJavaToCatalyst(e.invoke(row), attr.dataType) + CatalystTypeConverters.convertToCatalyst(e.invoke(row), attr.dataType) }.toArray[Any] ) : Row } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala new file mode 100644 index 0000000000000..d1ea7cc3e9162 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.r + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.api.r.SerDe +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode} + +private[r] object SQLUtils { + def createSQLContext(jsc: JavaSparkContext): SQLContext = { + new SQLContext(jsc) + } + + def getJavaSparkContext(sqlCtx: SQLContext): JavaSparkContext = { + new JavaSparkContext(sqlCtx.sparkContext) + } + + def toSeq[T](arr: Array[T]): Seq[T] = { + arr.toSeq + } + + def createDF(rdd: RDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { + val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + val num = schema.fields.size + val rowRDD = rdd.map(bytesToRow) + sqlContext.createDataFrame(rowRDD, schema) + } + + // A helper to include grouping columns in Agg() + def aggWithGrouping(gd: GroupedData, exprs: Column*): DataFrame = { + val aggExprs = exprs.map { col => + col.expr match { + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.simpleString)() + } + } + gd.toDF(aggExprs) + } + + def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = { + df.map(r => rowToRBytes(r)) + } + + private[this] def bytesToRow(bytes: Array[Byte]): Row = { + val bis = new ByteArrayInputStream(bytes) + val dis = new DataInputStream(bis) + val num = SerDe.readInt(dis) + Row.fromSeq((0 until num).map { i => + SerDe.readObject(dis) + }.toSeq) + } + + private[this] def rowToRBytes(row: Row): Array[Byte] = { + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + SerDe.writeInt(dos, row.length) + (0 until row.length).map { idx => + val obj: Object = row(idx).asInstanceOf[Object] + SerDe.writeObject(dos, obj) + } + bos.toByteArray() + } + + def dfToCols(df: DataFrame): Array[Array[Byte]] = { + // localDF is Array[Row] + val localDF = df.collect() + val numCols = df.columns.length + // dfCols is Array[Array[Any]] + val dfCols = convertRowsToColumns(localDF, numCols) + + dfCols.map { col => + colToRBytes(col) + } + } + + def convertRowsToColumns(localDF: Array[Row], numCols: Int): Array[Array[Any]] = { + (0 until numCols).map { colIdx => + localDF.map { row => + row(colIdx) + } + }.toArray + } + + def colToRBytes(col: Array[Any]): Array[Byte] = { + val numRows = col.length + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + SerDe.writeInt(dos, numRows) + + col.map { item => + val obj: Object = item.asInstanceOf[Object] + SerDe.writeObject(dos, obj) + } + bos.toByteArray() + } + + def saveMode(mode: String): SaveMode = { + mode match { + case "append" => SaveMode.Append + case "overwrite" => SaveMode.Overwrite + case "error" => SaveMode.ErrorIfExists + case "ignore" => SaveMode.Ignore + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index d8955725e59b1..656bdd7212f56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -20,14 +20,12 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, Attribute} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.types.StructType -import scala.collection.immutable - /** * :: DeveloperApi :: */ @@ -39,13 +37,15 @@ object RDDConversions { Iterator.empty } else { val bufferedIterator = iterator.buffered - val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity) + val mutableRow = new SpecificMutableRow(schema.fields.map(_.dataType)) val schemaFields = schema.fields.toArray + val converters = schemaFields.map { + f => CatalystTypeConverters.createToCatalystConverter(f.dataType) + } bufferedIterator.map { r => var i = 0 while (i < mutableRow.length) { - mutableRow(i) = - ScalaReflection.convertToCatalyst(r.productElement(i), schemaFields(i).dataType) + mutableRow(i) = converters(i)(r.productElement(i)) i += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index a8018b9213f2b..b1ef6556de1e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -68,6 +68,8 @@ case class GeneratedAggregate( a.collect { case agg: AggregateExpression => agg} } + // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite + // (in test "aggregation with codegen"). val computeFunctions = aggregatesToCompute.map { case c @ Count(expr) => // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its @@ -95,11 +97,14 @@ case class GeneratedAggregate( val currentSum = AttributeReference("currentSum", calcType, nullable = true)() val initialValue = Literal.create(null, calcType) - // Coalasce avoids double calculation... + // Coalesce avoids double calculation... // but really, common sub expression elimination would be better.... val zero = Cast(Literal(0), calcType) val updateFunction = Coalesce( - Add(Coalesce(currentSum :: zero :: Nil), Cast(expr, calcType)) :: currentSum :: Nil) + Add( + Coalesce(currentSum :: zero :: Nil), + Cast(expr, calcType) + ) :: currentSum :: zero :: Nil) val result = expr.dataType match { case DecimalType.Fixed(_, _) => @@ -109,8 +114,8 @@ case class GeneratedAggregate( AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - case a @ Average(expr) => - val calcType = + case cs @ CombineSum(expr) => + val calcType = expr.dataType expr.dataType match { case DecimalType.Fixed(_, _) => DecimalType.Unlimited @@ -118,42 +123,36 @@ case class GeneratedAggregate( expr.dataType } - val currentCount = AttributeReference("currentCount", LongType, nullable = false)() - val currentSum = AttributeReference("currentSum", calcType, nullable = false)() - val initialCount = Literal(0L) - val initialSum = Cast(Literal(0L), calcType) + val currentSum = AttributeReference("currentSum", calcType, nullable = true)() + val initialValue = Literal.create(null, calcType) + // Coalasce avoids double calculation... + // but really, common sub expression elimination would be better.... + val zero = Cast(Literal(0), calcType) // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its // UnscaledValue will be null if and only if x is null; helps with Average on decimals - val toCount = expr match { + val actualExpr = expr match { case UnscaledValue(e) => e case _ => expr } - - val updateCount = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount) - val updateSum = Coalesce(Add(Cast(expr, calcType), currentSum) :: currentSum :: Nil) - + // partial sum result can be null only when no input rows present + val updateFunction = If( + IsNotNull(actualExpr), + Coalesce( + Add( + Coalesce(currentSum :: zero :: Nil), + Cast(expr, calcType)) :: currentSum :: zero :: Nil), + currentSum) + val result = expr.dataType match { case DecimalType.Fixed(_, _) => - If(EqualTo(currentCount, Literal(0L)), - Literal.create(null, a.dataType), - Cast(Divide( - Cast(currentSum, DecimalType.Unlimited), - Cast(currentCount, DecimalType.Unlimited)), a.dataType)) - case _ => - If(EqualTo(currentCount, Literal(0L)), - Literal.create(null, a.dataType), - Divide(Cast(currentSum, a.dataType), Cast(currentCount, a.dataType))) + Cast(currentSum, cs.dataType) + case _ => currentSum } - AggregateEvaluation( - currentCount :: currentSum :: Nil, - initialCount :: initialSum :: Nil, - updateCount :: updateSum :: Nil, - result - ) - + AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) + case m @ Max(expr) => val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)() val initialValue = Literal.create(null, expr.dataType) @@ -165,8 +164,20 @@ case class GeneratedAggregate( updateMax :: Nil, currentMax) + case m @ Min(expr) => + val currentMin = AttributeReference("currentMin", expr.dataType, nullable = true)() + val initialValue = Literal.create(null, expr.dataType) + val updateMin = MinOf(currentMin, expr) + + AggregateEvaluation( + currentMin :: Nil, + initialValue :: Nil, + updateMin :: Nil, + currentMin) + case CollectHashSet(Seq(expr)) => - val set = AttributeReference("hashSet", ArrayType(expr.dataType), nullable = false)() + val set = + AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)() val initialValue = NewSet(expr.dataType) val addToSet = AddItemToSet(expr, set) @@ -177,9 +188,10 @@ case class GeneratedAggregate( set) case CombineSetsAndCount(inputSet) => - val ArrayType(inputType, _) = inputSet.dataType - val set = AttributeReference("hashSet", inputSet.dataType, nullable = false)() - val initialValue = NewSet(inputType) + val elementType = inputSet.dataType.asInstanceOf[OpenHashSetUDT].elementType + val set = + AttributeReference("hashSet", new OpenHashSetUDT(elementType), nullable = false)() + val initialValue = NewSet(elementType) val collectSets = CombineSets(set, inputSet) AggregateEvaluation( @@ -187,6 +199,8 @@ case class GeneratedAggregate( initialValue :: Nil, collectSets :: Nil, CountSet(set)) + + case o => sys.error(s"$o can't be codegened.") } val computationSchema = computeFunctions.flatMap(_.schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index 5bd699a2fa949..8a8c3a404323a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.Attribute @@ -32,9 +32,15 @@ case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNo override def execute(): RDD[Row] = rdd - override def executeCollect(): Array[Row] = - rows.map(ScalaReflection.convertRowToScala(_, schema)).toArray - override def executeTake(limit: Int): Array[Row] = - rows.map(ScalaReflection.convertRowToScala(_, schema)).take(limit).toArray + override def executeCollect(): Array[Row] = { + val converter = CatalystTypeConverters.createToScalaConverter(schema) + rows.map(converter(_).asInstanceOf[Row]).toArray + } + + + override def executeTake(limit: Int): Array[Row] = { + val converter = CatalystTypeConverters.createToScalaConverter(schema) + rows.map(converter(_).asInstanceOf[Row]).take(limit).toArray + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index d239637cd4b4e..fabcf6b4a0570 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.{ScalaReflection, trees} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan @@ -80,8 +80,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** * Runs this query returning the result as an array. */ + def executeCollect(): Array[Row] = { - execute().map(ScalaReflection.convertRowToScala(_, schema)).collect() + execute().mapPartitions { iter => + val converter = CatalystTypeConverters.createToScalaConverter(schema) + iter.map(converter(_).asInstanceOf[Row]) + }.collect() } /** @@ -125,7 +129,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ partsScanned += numPartsToTry } - buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema)) + val converter = CatalystTypeConverters.createToScalaConverter(schema) + buf.toArray.map(converter(_).asInstanceOf[Row]) } protected def newProjection( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 967bd76b302d8..914f387dec78f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import java.nio.ByteBuffer +import java.util.{HashMap => JavaHashMap} import org.apache.spark.sql.types.Decimal @@ -26,14 +27,13 @@ import scala.reflect.ClassTag import com.clearspring.analytics.stream.cardinality.HyperLogLog import com.esotericsoftware.kryo.io.{Input, Output} import com.esotericsoftware.kryo.{Serializer, Kryo} -import com.twitter.chill.{AllScalaRegistrar, ResourcePool} +import com.twitter.chill.ResourcePool import org.apache.spark.{SparkEnv, SparkConf} import org.apache.spark.serializer.{SerializerInstance, KryoSerializer} import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.util.collection.OpenHashSet import org.apache.spark.util.MutablePair -import org.apache.spark.util.Utils import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHashSet} @@ -55,6 +55,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]], new OpenHashSetSerializer) kryo.register(classOf[Decimal]) + kryo.register(classOf[JavaHashMap[_, _]]) kryo.setReferences(false) kryo diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f754fa770d1b5..5b99e40c2f491 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -155,7 +155,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists { - case _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false + case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => false // The generated set implementation is pretty limited ATM. case CollectHashSet(exprs) if exprs.size == 1 && Seq(IntegerType, LongType).contains(exprs.head.dataType) => false @@ -211,9 +211,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { ParquetRelation.create(path, child, sparkContext.hadoopConfiguration, sqlContext) // Note: overwrite=false because otherwise the metadata we just created will be deleted InsertIntoParquetTable(relation, planLater(child), overwrite = false) :: Nil - case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) => + case logical.InsertIntoTable( + table: ParquetRelation, partition, child, overwrite, ifNotExists) => InsertIntoParquetTable(table, planLater(child), overwrite) :: Nil case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => + val partitionColNames = relation.partitioningAttributes.map(_.name).toSet + val filtersToPush = filters.filter { pred => + val referencedColNames = pred.references.map(_.name).toSet + referencedColNames.intersect(partitionColNames).isEmpty + } val prunePushedDownFilters = if (sqlContext.conf.parquetFilterPushDown) { (predicates: Seq[Expression]) => { @@ -225,6 +231,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // "A AND B" in the higher-level filter, not just "B". predicates.map(p => p -> ParquetFilters.createFilter(p)).collect { case (predicate, None) => predicate + // Filter needs to be applied above when it contains partitioning + // columns + case (predicate, _) if(!predicate.references.map(_.name).toSet + .intersect (partitionColNames).isEmpty) => predicate } } } else { @@ -237,7 +247,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { ParquetTableScan( _, relation, - if (sqlContext.conf.parquetFilterPushDown) filters else Nil)) :: Nil + if (sqlContext.conf.parquetFilterPushDown) filtersToPush else Nil)) :: Nil case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 1f5251a20376f..f8221f41bc6c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -21,7 +21,7 @@ import org.apache.spark.{SparkEnv, HashPartitioner, SparkConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -139,9 +139,10 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) private def collectData(): Array[Row] = child.execute().map(_.copy()).takeOrdered(limit)(ord) - // TODO: Is this copying for no reason? - override def executeCollect(): Array[Row] = - collectData().map(ScalaReflection.convertRowToScala(_, this.schema)) + override def executeCollect(): Array[Row] = { + val converter = CatalystTypeConverters.createToScalaConverter(schema) + collectData().map(converter(_).asInstanceOf[Row]) + } // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. @@ -193,7 +194,7 @@ case class ExternalSort( child.execute().mapPartitions( { iterator => val ordering = newOrdering(sortOrder, child.output) val sorter = new ExternalSorter[Row, Null, Row](ordering = Some(ordering)) - sorter.insertAll(iterator.map(r => (r, null))) + sorter.insertAll(iterator.map(r => (r.copy, null))) val baseIterator = sorter.iterator.map(_._1) // TODO(marmbrus): The complex type signature below thwarts inference for no reason. CompletionIterator[Row, Iterator[Row]](baseIterator, sorter.stop()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 2fa1cf5add3b5..ab84c123e0c0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.execution.joins +import java.io.{ObjectInput, ObjectOutput, Externalizable} import java.util.{HashMap => JavaHashMap} import org.apache.spark.sql.catalyst.expressions.{Projection, Row} +import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.util.collection.CompactBuffer @@ -29,16 +31,43 @@ import org.apache.spark.util.collection.CompactBuffer */ private[joins] sealed trait HashedRelation { def get(key: Row): CompactBuffer[Row] + + // This is a helper method to implement Externalizable, and is used by + // GeneralHashedRelation and UniqueKeyHashedRelation + protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = { + out.writeInt(serialized.length) // Write the length of serialized bytes first + out.write(serialized) + } + + // This is a helper method to implement Externalizable, and is used by + // GeneralHashedRelation and UniqueKeyHashedRelation + protected def readBytes(in: ObjectInput): Array[Byte] = { + val serializedSize = in.readInt() // Read the length of serialized bytes first + val bytes = new Array[Byte](serializedSize) + in.readFully(bytes) + bytes + } } /** * A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values. */ -private[joins] final class GeneralHashedRelation(hashTable: JavaHashMap[Row, CompactBuffer[Row]]) - extends HashedRelation with Serializable { +private[joins] final class GeneralHashedRelation( + private var hashTable: JavaHashMap[Row, CompactBuffer[Row]]) + extends HashedRelation with Externalizable { + + def this() = this(null) // Needed for serialization override def get(key: Row): CompactBuffer[Row] = hashTable.get(key) + + override def writeExternal(out: ObjectOutput): Unit = { + writeBytes(out, SparkSqlSerializer.serialize(hashTable)) + } + + override def readExternal(in: ObjectInput): Unit = { + hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + } } @@ -46,8 +75,10 @@ private[joins] final class GeneralHashedRelation(hashTable: JavaHashMap[Row, Com * A specialized [[HashedRelation]] that maps key into a single value. This implementation * assumes the key is unique. */ -private[joins] final class UniqueKeyHashedRelation(hashTable: JavaHashMap[Row, Row]) - extends HashedRelation with Serializable { +private[joins] final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[Row, Row]) + extends HashedRelation with Externalizable { + + def this() = this(null) // Needed for serialization override def get(key: Row): CompactBuffer[Row] = { val v = hashTable.get(key) @@ -55,6 +86,14 @@ private[joins] final class UniqueKeyHashedRelation(hashTable: JavaHashMap[Row, R } def getValue(key: Row): Row = hashTable.get(key) + + override def writeExternal(out: ObjectOutput): Unit = { + writeBytes(out, SparkSqlSerializer.serialize(hashTable)) + } + + override def readExternal(in: ObjectInput): Unit = { + hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 111e751588a8b..ff91e1d74bc2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -22,7 +22,7 @@ import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -605,4 +605,23 @@ object functions { } // scalastyle:on + + /** + * Call an user-defined function. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + * val sqlContext = df.sqlContext + * sqlContext.udf.register("simpleUdf", (v: Int) => v * v) + * df.select($"id", callUdf("simpleUdf", $"value")) + * }}} + * + * @group udf_funcs + */ + def callUdf(udfName: String, cols: Column*): Column = { + UnresolvedFunction(udfName, cols.map(_.expr)) + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 0b770f2251943..b1e8521383756 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -391,7 +391,7 @@ private[sql] object JsonRDD extends Logging { value match { // only support string as date case value: java.lang.String => - DateUtils.millisToDays(DataTypeConversions.stringToTime(value).getTime) + DateUtils.millisToDays(DateUtils.stringToTime(value).getTime) case value: java.sql.Date => DateUtils.fromJavaDate(value) } } @@ -400,7 +400,7 @@ private[sql] object JsonRDD extends Logging { value match { case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong) case value: java.lang.Long => new Timestamp(value) - case value: java.lang.String => toTimestamp(DataTypeConversions.stringToTime(value).getTime) + case value: java.lang.String => toTimestamp(DateUtils.stringToTime(value).getTime) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala new file mode 100644 index 0000000000000..25a66cb488103 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter + +import parquet.Log +import parquet.hadoop.util.ContextUtil +import parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter} + +private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { + val LOG = Log.getLog(classOf[ParquetOutputCommitter]) + + override def getWorkPath(): Path = outputPath + override def abortTask(taskContext: TaskAttemptContext): Unit = {} + override def commitTask(taskContext: TaskAttemptContext): Unit = {} + override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = true + override def setupJob(jobContext: JobContext): Unit = {} + override def setupTask(taskContext: TaskAttemptContext): Unit = {} + + override def commitJob(jobContext: JobContext) { + try { + val configuration = ContextUtil.getConfiguration(jobContext) + val fileSystem = outputPath.getFileSystem(configuration) + val outputStatus = fileSystem.getFileStatus(outputPath) + val footers = ParquetFileReader.readAllFootersInParallel(configuration, outputStatus) + try { + ParquetFileWriter.writeMetadataFile(configuration, outputPath, footers) + if (configuration.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true)) { + val successPath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME) + fileSystem.create(successPath).close() + } + } catch { + case e: Exception => { + LOG.warn("could not write summary file for " + outputPath, e) + val metadataPath = new Path(outputPath, ParquetFileWriter.PARQUET_METADATA_FILE) + if (fileSystem.exists(metadataPath)) { + fileSystem.delete(metadataPath, true) + } + } + } + } catch { + case e: Exception => LOG.warn("could not write summary file for " + outputPath, e) + } + } + +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 1c868da23e060..3724bda829d30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -379,6 +379,8 @@ private[sql] case class InsertIntoParquetTable( */ private[parquet] class AppendingParquetOutputFormat(offset: Int) extends parquet.hadoop.ParquetOutputFormat[Row] { + var committer: OutputCommitter = null + // override to accept existing directories as valid output directory override def checkOutputSpecs(job: JobContext): Unit = {} @@ -403,6 +405,26 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) private def getTaskAttemptID(context: TaskAttemptContext): TaskAttemptID = { context.getClass.getMethod("getTaskAttemptID").invoke(context).asInstanceOf[TaskAttemptID] } + + // override to create output committer from configuration + override def getOutputCommitter(context: TaskAttemptContext): OutputCommitter = { + if (committer == null) { + val output = getOutputPath(context) + val cls = context.getConfiguration.getClass("spark.sql.parquet.output.committer.class", + classOf[ParquetOutputCommitter], classOf[ParquetOutputCommitter]) + val ctor = cls.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) + committer = ctor.newInstance(output, context).asInstanceOf[ParquetOutputCommitter] + } + committer + } + + // FileOutputFormat.getOutputPath takes JobConf in hadoop-1 but JobContext in hadoop-2 + private def getOutputPath(context: TaskAttemptContext): Path = { + context.getConfiguration().get("mapred.output.dir") match { + case null => null + case name => new Path(name) + } + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 0dce3623a66df..20fdf5e58ef82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -432,7 +432,10 @@ private[sql] case class ParquetRelation2( // FileInputFormat cannot handle empty lists. if (selectedFiles.nonEmpty) { - FileInputFormat.setInputPaths(job, selectedFiles.map(_.getPath): _*) + // In order to encode the authority of a Path containning special characters such as /, + // we need to use the string retruned by the URI of the path to create a new Path. + val selectedPaths = selectedFiles.map(status => new Path(status.getPath.toUri.toString)) + FileInputFormat.setInputPaths(job, selectedPaths: _*) } // Try to push down filters when filter push-down is enabled. @@ -484,10 +487,31 @@ private[sql] case class ParquetRelation2( val cacheMetadata = useCache @transient - val cachedStatus = selectedFiles + val cachedStatus = selectedFiles.map { st => + // In order to encode the authority of a Path containning special characters such as /, + // we need to use the string retruned by the URI of the path to create a new Path. + val newPath = new Path(st.getPath.toUri.toString) + + new FileStatus( + st.getLen, + st.isDir, + st.getReplication, + st.getBlockSize, + st.getModificationTime, + st.getAccessTime, + st.getPermission, + st.getOwner, + st.getGroup, + newPath) + } @transient - val cachedFooters = selectedFooters + val cachedFooters = selectedFooters.map { f => + // In order to encode the authority of a Path containning special characters such as /, + // we need to use the string retruned by the URI of the path to create a new Path. + new Footer(new Path(f.getFile.toUri.toString), f.getParquetMetadata) + } + // Overridden so we can inject our own cached files statuses. override def getPartitions: Array[SparkPartition] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index e13759b7feb7b..34d048e426d10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -56,7 +56,7 @@ private[sql] object DataSourceStrategy extends Strategy { execution.PhysicalRDD(l.output, t.buildScan()) :: Nil case i @ logical.InsertIntoTable( - l @ LogicalRelation(t: InsertableRelation), part, query, overwrite) if part.isEmpty => + l @ LogicalRelation(t: InsertableRelation), part, query, overwrite, false) if part.isEmpty => execution.ExecutedCommand(InsertIntoDataSource(l, query, overwrite)) :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala index 5a78001117d1b..6ed68d179edc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala @@ -37,7 +37,7 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] { // We are inserting into an InsertableRelation. case i @ InsertIntoTable( - l @ LogicalRelation(r: InsertableRelation), partition, child, overwrite) => { + l @ LogicalRelation(r: InsertableRelation), partition, child, overwrite, ifNotExists) => { // First, make sure the data to be inserted have the same number of fields with the // schema of the relation. if (l.output.size != child.output.size) { @@ -84,7 +84,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => def apply(plan: LogicalPlan): Unit = { plan.foreach { case i @ logical.InsertIntoTable( - l @ LogicalRelation(t: InsertableRelation), partition, query, overwrite) => + l @ LogicalRelation(t: InsertableRelation), partition, query, overwrite, ifNotExists) => // Right now, we do not support insert into a data source table with partition specs. if (partition.nonEmpty) { failAnalysis(s"Insert into a partition is not allowed because $l is not partitioned.") @@ -102,7 +102,8 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => } case i @ logical.InsertIntoTable( - l: LogicalRelation, partition, query, overwrite) if !l.isInstanceOf[InsertableRelation] => + l: LogicalRelation, partition, query, overwrite, ifNotExists) + if !l.isInstanceOf[InsertableRelation] => // The relation in l is not an InsertableRelation. failAnalysis(s"$l does not allow insertion.") diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 1ff2d5a190521..6d0fbe83c2f36 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -20,6 +20,8 @@ import java.io.Serializable; import java.util.Arrays; +import scala.collection.Seq; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -127,6 +129,12 @@ public void testCreateDataFrameFromJavaBeans() { schema.apply("b")); Row first = df.select("a", "b").first(); Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); - Assert.assertArrayEquals(bean.getB(), first.getAs(1)); + // Now Java lists and maps are converetd to Scala Seq's and Map's. Once we get a Seq below, + // verify that it has the expected length, and contains expected elements. + Seq result = first.getAs(1); + Assert.assertEquals(bean.getB().length, result.length()); + for (int i = 0; i < result.length(); i++) { + Assert.assertEquals(bean.getB()[i], result.apply(i)); + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index c240f2be955ca..f7b5f08beb92f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -92,7 +92,8 @@ class CachedTableSuite extends QueryTest { test("too big for memory") { val data = "*" * 10000 - sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF().registerTempTable("bigData") + sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() + .registerTempTable("bigData") table("bigData").persist(StorageLevel.MEMORY_AND_DISK) assert(table("bigData").count() === 200000L) table("bigData").unpersist(blocking = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 0896f175c056f..41b4f02e6a294 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -154,4 +154,38 @@ class DataFrameNaFunctionsSuite extends QueryTest { ))), Row("test", null, 1, 2.2)) } + + test("replace") { + val input = createDF() + + // Replace two numeric columns: age and height + val out = input.na.replace(Seq("age", "height"), Map( + 16 -> 61, + 60 -> 6, + 164.3 -> 461.3 // Alice is really tall + )) + + checkAnswer( + out, + Row("Bob", 61, 176.5) :: + Row("Alice", null, 461.3) :: + Row("David", 6, null) :: + Row("Amy", null, null) :: + Row(null, null, null) :: Nil) + + // Replace only the age column + val out1 = input.na.replace("age", Map( + 16 -> 61, + 60 -> 6, + 164.3 -> 461.3 // Alice is really tall + )) + + checkAnswer( + out1, + Row("Bob", 61, 176.5) :: + Row("Alice", null, 164.3) :: + Row("David", 6, null) :: + Row("Amy", null, null) :: + Row(null, null, null) :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1db0cf7daac03..b26e22f6229fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -329,8 +329,9 @@ class DataFrameSuite extends QueryTest { checkAnswer( decimalData.agg(avg('a cast DecimalType(10, 2))), Row(new java.math.BigDecimal(2.0))) + // non-partial checkAnswer( - decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial + decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) } @@ -439,6 +440,15 @@ class DataFrameSuite extends QueryTest { ) } + test("call udf in SQLContext") { + val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + val sqlctx = df.sqlContext + sqlctx.udf.register("simpleUdf", (v: Int) => v * v) + checkAnswer( + df.select($"id", callUdf("simpleUdf", $"value")), + Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil) + } + test("withColumn") { val df = testData.toDF().withColumn("newCol", col("key") + 1) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 9b4dd6c620fec..9a81fc5d72819 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -67,7 +67,7 @@ class QueryTest extends PlanTest { checkAnswer(df, Seq(expectedAnswer)) } - def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { + def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext) { test(sqlString) { checkAnswer(sqlContext.sql(sqlString), expectedAnswer) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 87e7cf8c8af9f..73fb791c3ead7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql +import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.test.TestSQLContext import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -102,11 +104,105 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT ABS(2.5)"), Row(2.5)) } - + test("aggregation with codegen") { val originalValue = conf.codegenEnabled setConf(SQLConf.CODEGEN_ENABLED, "true") - sql("SELECT key FROM testData GROUP BY key").collect() + // Prepare a table that we can group some rows. + table("testData") + .unionAll(table("testData")) + .unionAll(table("testData")) + .registerTempTable("testData3x") + + def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { + val df = sql(sqlText) + // First, check if we have GeneratedAggregate. + var hasGeneratedAgg = false + df.queryExecution.executedPlan.foreach { + case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true + case _ => + } + if (!hasGeneratedAgg) { + fail( + s""" + |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan. + |${df.queryExecution.simpleString} + """.stripMargin) + } + // Then, check results. + checkAnswer(df, expectedResults) + } + + // Just to group rows. + testCodeGen( + "SELECT key FROM testData3x GROUP BY key", + (1 to 100).map(Row(_))) + // COUNT + testCodeGen( + "SELECT key, count(value) FROM testData3x GROUP BY key", + (1 to 100).map(i => Row(i, 3))) + testCodeGen( + "SELECT count(key) FROM testData3x", + Row(300) :: Nil) + // COUNT DISTINCT ON int + testCodeGen( + "SELECT value, count(distinct key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 1))) + testCodeGen( + "SELECT count(distinct key) FROM testData3x", + Row(100) :: Nil) + // SUM + testCodeGen( + "SELECT value, sum(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 3 * i))) + testCodeGen( + "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", + Row(5050 * 3, 5050 * 3.0) :: Nil) + // AVERAGE + testCodeGen( + "SELECT value, avg(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT avg(key) FROM testData3x", + Row(50.5) :: Nil) + // MAX + testCodeGen( + "SELECT value, max(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT max(key) FROM testData3x", + Row(100) :: Nil) + // MIN + testCodeGen( + "SELECT value, min(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT min(key) FROM testData3x", + Row(1) :: Nil) + // Some combinations. + testCodeGen( + """ + |SELECT + | value, + | sum(key), + | max(key), + | min(key), + | avg(key), + | count(key), + | count(distinct key) + |FROM testData3x + |GROUP BY value + """.stripMargin, + (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1))) + testCodeGen( + "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x", + Row(100, 1, 50.5, 300, 100) :: Nil) + // Aggregate with Code generation handling all null values + testCodeGen( + "SELECT sum('a'), avg('a'), count(null) FROM testData", + Row(0, null, 0) :: Nil) + + dropTempTable("testData3x") setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) } @@ -182,7 +278,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.002"))) checkAnswer(sql( - "SELECT time FROM timestamps WHERE time IN ('1969-12-31 16:00:00.001','1969-12-31 16:00:00.002')"), + """ + |SELECT time FROM timestamps + |WHERE time IN ('1969-12-31 16:00:00.001','1969-12-31 16:00:00.002') + """.stripMargin), Seq(Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001")), Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.002")))) @@ -248,7 +347,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row("1")) } - def sortTest() = { + def sortTest(): Unit = { checkAnswer( sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"), Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) @@ -318,6 +417,20 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { mapData.collect().take(1).map(Row.fromTuple).toSeq) } + test("CTE feature") { + checkAnswer( + sql("with q1 as (select * from testData limit 10) select * from q1"), + testData.take(10).toSeq) + + checkAnswer( + sql(""" + |with q1 as (select * from testData where key= '5'), + |q2 as (select * from testData where key = '4') + |select * from q1 union all select * from q2""".stripMargin), + Row(5, "5") :: Row(4, "4") :: Nil) + + } + test("date row") { checkAnswer(sql( """select cast("2015-01-28" as date) from testData limit 1"""), @@ -327,7 +440,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("from follow multiple brackets") { checkAnswer(sql( - "select key from ((select * from testData limit 1) union all (select * from testData limit 1)) x limit 1"), + """ + |select key from ((select * from testData limit 1) + | union all (select * from testData limit 1)) x limit 1 + """.stripMargin), Row(1) ) @@ -337,7 +453,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { ) checkAnswer(sql( - "select key from (select * from testData limit 1 union all select * from testData limit 1) x limit 1"), + """ + |select key from + | (select * from testData limit 1 union all select * from testData limit 1) x + | limit 1 + """.stripMargin), Row(1) ) } @@ -384,7 +504,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Seq(Row(1, 0), Row(2, 1))) checkAnswer( - sql("SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"), + sql( + """ + |SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3 + """.stripMargin), Row(2, 1, 2, 2, 1)) } @@ -997,7 +1120,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-3483 Special chars in column names") { - val data = sparkContext.parallelize(Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) + val data = sparkContext.parallelize( + Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) jsonRDD(data).registerTempTable("records") sql("SELECT `key?number1` FROM records") } @@ -1082,8 +1206,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-6145: ORDER BY test for nested fields") { - jsonRDD(sparkContext.makeRDD( - """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)).registerTempTable("nestedOrder") + jsonRDD(sparkContext.makeRDD("""{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) + .registerTempTable("nestedOrder") checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1)) checkAnswer(sql("SELECT a.b FROM nestedOrder ORDER BY a.b"), Row(1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 17e923ca48502..3fa00fd9d0ccb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -80,7 +80,7 @@ class ScalaReflectionRelationSuite extends FunSuite { test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, - new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)) + new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)) val rdd = sparkContext.parallelize(data :: Nil) rdd.toDF().registerTempTable("reflectData") @@ -103,7 +103,8 @@ class ScalaReflectionRelationSuite extends FunSuite { val rdd = sparkContext.parallelize(data :: Nil) rdd.toDF().registerTempTable("reflectOptionalData") - assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) + assert(sql("SELECT * FROM reflectOptionalData").collect().head === + Row.fromSeq(Seq.fill(7)(null))) } // Equality is broken for Arrays, so we test that separately. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index fe618e0e8e767..2672e20deadc5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -23,13 +23,16 @@ import org.apache.spark.util.Utils import scala.beans.{BeanInfo, BeanProperty} +import com.clearspring.analytics.stream.cardinality.HyperLogLog + import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql} import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ - +import org.apache.spark.util.collection.OpenHashSet @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { @@ -63,7 +66,7 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { } } - override def userClass = classOf[MyDenseVector] + override def userClass: Class[MyDenseVector] = classOf[MyDenseVector] private[spark] override def asNullable: MyDenseVectorUDT = this } @@ -119,4 +122,23 @@ class UserDefinedTypeSuite extends QueryTest { df.limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) } + + test("HyperLogLogUDT") { + val hyperLogLogUDT = HyperLogLogUDT + val hyperLogLog = new HyperLogLog(0.4) + (1 to 10).foreach(i => hyperLogLog.offer(Row(i))) + + val actual = hyperLogLogUDT.deserialize(hyperLogLogUDT.serialize(hyperLogLog)) + assert(actual.cardinality() === hyperLogLog.cardinality()) + assert(java.util.Arrays.equals(actual.getBytes, hyperLogLog.getBytes)) + } + + test("OpenHashSetUDT") { + val openHashSetUDT = new OpenHashSetUDT(IntegerType) + val set = new OpenHashSet[Int] + (1 to 10).foreach(i => set.add(i)) + + val actual = openHashSetUDT.deserialize(openHashSetUDT.serialize(set)) + assert(actual.iterator.toSet === set.iterator.toSet) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index c7a40845db16c..b301818a008e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.types.{Decimal, DataType, NativeType} object ColumnarTestUtils { - def makeNullRow(length: Int) = { + def makeNullRow(length: Int): GenericMutableRow = { val row = new GenericMutableRow(length) (0 until length).foreach(row.setNullAt) row @@ -93,7 +93,7 @@ object ColumnarTestUtils { def makeUniqueValuesAndSingleValueRows[T <: NativeType]( columnType: NativeColumnType[T], - count: Int) = { + count: Int): (Seq[T#JvmType], Seq[GenericMutableRow]) = { val values = makeUniqueRandomValues(columnType, count) val rows = values.map { value => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 27dfabca90217..479210d1c9c43 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -42,7 +42,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { .toDF().registerTempTable("sizeTst") cacheTable("sizeTst") assert( - table("sizeTst").queryExecution.logical.statistics.sizeInBytes > + table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > conf.autoBroadcastJoinThreshold) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index bb305355276bf..a0702144f942c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -31,7 +31,8 @@ class TestNullableColumnAccessor[T <: DataType, JvmType]( with NullableColumnAccessor object TestNullableColumnAccessor { - def apply[T <: DataType, JvmType](buffer: ByteBuffer, columnType: ColumnType[T, JvmType]) = { + def apply[T <: DataType, JvmType](buffer: ByteBuffer, columnType: ColumnType[T, JvmType]) + : TestNullableColumnAccessor[T, JvmType] = { // Skips the column type ID buffer.getInt() new TestNullableColumnAccessor(buffer, columnType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index 75a47498683f4..3a5605d2335d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -27,7 +27,8 @@ class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T with NullableColumnBuilder object TestNullableColumnBuilder { - def apply[T <: DataType, JvmType](columnType: ColumnType[T, JvmType], initialSize: Int = 0) = { + def apply[T <: DataType, JvmType](columnType: ColumnType[T, JvmType], initialSize: Int = 0) + : TestNullableColumnBuilder[T, JvmType] = { val builder = new TestNullableColumnBuilder(columnType) builder.initialize(initialSize) builder diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala index 0b18b4119268f..fc8ff3b41d0e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala @@ -35,7 +35,7 @@ object TestCompressibleColumnBuilder { def apply[T <: NativeType]( columnStats: ColumnStats, columnType: NativeColumnType[T], - scheme: CompressionScheme) = { + scheme: CompressionScheme): TestCompressibleColumnBuilder[T] = { val builder = new TestCompressibleColumnBuilder(columnStats, columnType, Seq(scheme)) builder.initialize(0, "", useCompression = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 4e9472c60249e..358d8cf06e463 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -30,4 +30,4 @@ class DebuggingSuite extends FunSuite { test("DataFrame.typeCheck()") { testData.typeCheck() } -} \ No newline at end of file +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 592ed4b23b7d3..3596b183d4328 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -45,10 +45,12 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { conn = DriverManager.getConnection(url, properties) conn.prepareStatement("create schema test").executeUpdate() - conn.prepareStatement("create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() + conn.prepareStatement( + "create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() conn.prepareStatement("insert into test.people values ('fred', 1)").executeUpdate() conn.prepareStatement("insert into test.people values ('mary', 2)").executeUpdate() - conn.prepareStatement("insert into test.people values ('joe ''foo'' \"bar\"', 3)").executeUpdate() + conn.prepareStatement( + "insert into test.people values ('joe ''foo'' \"bar\"', 3)").executeUpdate() conn.commit() sql( @@ -132,25 +134,25 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { } test("SELECT *") { - assert(sql("SELECT * FROM foobar").collect().size == 3) + assert(sql("SELECT * FROM foobar").collect().size === 3) } test("SELECT * WHERE (simple predicates)") { - assert(sql("SELECT * FROM foobar WHERE THEID < 1").collect().size == 0) - assert(sql("SELECT * FROM foobar WHERE THEID != 2").collect().size == 2) - assert(sql("SELECT * FROM foobar WHERE THEID = 1").collect().size == 1) - assert(sql("SELECT * FROM foobar WHERE NAME = 'fred'").collect().size == 1) - assert(sql("SELECT * FROM foobar WHERE NAME > 'fred'").collect().size == 2) - assert(sql("SELECT * FROM foobar WHERE NAME != 'fred'").collect().size == 2) + assert(sql("SELECT * FROM foobar WHERE THEID < 1").collect().size === 0) + assert(sql("SELECT * FROM foobar WHERE THEID != 2").collect().size === 2) + assert(sql("SELECT * FROM foobar WHERE THEID = 1").collect().size === 1) + assert(sql("SELECT * FROM foobar WHERE NAME = 'fred'").collect().size === 1) + assert(sql("SELECT * FROM foobar WHERE NAME > 'fred'").collect().size === 2) + assert(sql("SELECT * FROM foobar WHERE NAME != 'fred'").collect().size === 2) } test("SELECT * WHERE (quoted strings)") { - assert(sql("select * from foobar").where('NAME === "joe 'foo' \"bar\"").collect().size == 1) + assert(sql("select * from foobar").where('NAME === "joe 'foo' \"bar\"").collect().size === 1) } test("SELECT first field") { val names = sql("SELECT NAME FROM foobar").collect().map(x => x.getString(0)).sortWith(_ < _) - assert(names.size == 3) + assert(names.size === 3) assert(names(0).equals("fred")) assert(names(1).equals("joe 'foo' \"bar\"")) assert(names(2).equals("mary")) @@ -158,10 +160,10 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { test("SELECT second field") { val ids = sql("SELECT THEID FROM foobar").collect().map(x => x.getInt(0)).sortWith(_ < _) - assert(ids.size == 3) - assert(ids(0) == 1) - assert(ids(1) == 2) - assert(ids(2) == 3) + assert(ids.size === 3) + assert(ids(0) === 1) + assert(ids(1) === 2) + assert(ids(2) === 3) } test("SELECT * partitioned") { @@ -169,46 +171,46 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { } test("SELECT WHERE (simple predicates) partitioned") { - assert(sql("SELECT * FROM parts WHERE THEID < 1").collect().size == 0) - assert(sql("SELECT * FROM parts WHERE THEID != 2").collect().size == 2) - assert(sql("SELECT THEID FROM parts WHERE THEID = 1").collect().size == 1) + assert(sql("SELECT * FROM parts WHERE THEID < 1").collect().size === 0) + assert(sql("SELECT * FROM parts WHERE THEID != 2").collect().size === 2) + assert(sql("SELECT THEID FROM parts WHERE THEID = 1").collect().size === 1) } test("SELECT second field partitioned") { val ids = sql("SELECT THEID FROM parts").collect().map(x => x.getInt(0)).sortWith(_ < _) - assert(ids.size == 3) - assert(ids(0) == 1) - assert(ids(1) == 2) - assert(ids(2) == 3) + assert(ids.size === 3) + assert(ids(0) === 1) + assert(ids(1) === 2) + assert(ids(2) === 3) } test("Basic API") { - assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE").collect.size == 3) + assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE").collect().size === 3) } test("Partitioning via JDBCPartitioningInfo API") { assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3) - .collect.size == 3) + .collect.size === 3) } test("Partitioning via list-of-where-clauses API") { val parts = Array[String]("THEID < 2", "THEID >= 2") - assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts).collect.size == 3) + assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts).collect().size === 3) } test("H2 integral types") { val rows = sql("SELECT * FROM inttypes WHERE A IS NOT NULL").collect() - assert(rows.size == 1) - assert(rows(0).getInt(0) == 1) - assert(rows(0).getBoolean(1) == false) - assert(rows(0).getInt(2) == 3) - assert(rows(0).getInt(3) == 4) - assert(rows(0).getLong(4) == 1234567890123L) + assert(rows.size === 1) + assert(rows(0).getInt(0) === 1) + assert(rows(0).getBoolean(1) === false) + assert(rows(0).getInt(2) === 3) + assert(rows(0).getInt(3) === 4) + assert(rows(0).getLong(4) === 1234567890123L) } test("H2 null entries") { val rows = sql("SELECT * FROM inttypes WHERE A IS NULL").collect() - assert(rows.size == 1) + assert(rows.size === 1) assert(rows(0).isNullAt(0)) assert(rows(0).isNullAt(1)) assert(rows(0).isNullAt(2)) @@ -230,27 +232,27 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { val rows = sql("SELECT * FROM timetypes").collect() val cal = new GregorianCalendar(java.util.Locale.ROOT) cal.setTime(rows(0).getAs[java.sql.Timestamp](0)) - assert(cal.get(Calendar.HOUR_OF_DAY) == 12) - assert(cal.get(Calendar.MINUTE) == 34) - assert(cal.get(Calendar.SECOND) == 56) + assert(cal.get(Calendar.HOUR_OF_DAY) === 12) + assert(cal.get(Calendar.MINUTE) === 34) + assert(cal.get(Calendar.SECOND) === 56) cal.setTime(rows(0).getAs[java.sql.Timestamp](1)) - assert(cal.get(Calendar.YEAR) == 1996) - assert(cal.get(Calendar.MONTH) == 0) - assert(cal.get(Calendar.DAY_OF_MONTH) == 1) + assert(cal.get(Calendar.YEAR) === 1996) + assert(cal.get(Calendar.MONTH) === 0) + assert(cal.get(Calendar.DAY_OF_MONTH) === 1) cal.setTime(rows(0).getAs[java.sql.Timestamp](2)) - assert(cal.get(Calendar.YEAR) == 2002) - assert(cal.get(Calendar.MONTH) == 1) - assert(cal.get(Calendar.DAY_OF_MONTH) == 20) - assert(cal.get(Calendar.HOUR) == 11) - assert(cal.get(Calendar.MINUTE) == 22) - assert(cal.get(Calendar.SECOND) == 33) - assert(rows(0).getAs[java.sql.Timestamp](2).getNanos == 543543543) + assert(cal.get(Calendar.YEAR) === 2002) + assert(cal.get(Calendar.MONTH) === 1) + assert(cal.get(Calendar.DAY_OF_MONTH) === 20) + assert(cal.get(Calendar.HOUR) === 11) + assert(cal.get(Calendar.MINUTE) === 22) + assert(cal.get(Calendar.SECOND) === 33) + assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543543) } test("H2 floating-point types") { val rows = sql("SELECT * FROM flttypes").collect() - assert(rows(0).getDouble(0) == 1.00000000000000022) // Yes, I meant ==. - assert(rows(0).getDouble(1) == 1.00000011920928955) // Yes, I meant ==. + assert(rows(0).getDouble(0) === 1.00000000000000022) // Yes, I meant ==. + assert(rows(0).getDouble(1) === 1.00000011920928955) // Yes, I meant ==. assert(rows(0).getAs[BigDecimal](2) .equals(new BigDecimal("123456789012345.54321543215432100000"))) } @@ -264,7 +266,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { | user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) val rows = sql("SELECT * FROM hack").collect() - assert(rows(0).getDouble(0) == 1.00000011920928955) // Yes, I meant ==. + assert(rows(0).getDouble(0) === 1.00000011920928955) // Yes, I meant ==. // For some reason, H2 computes this square incorrectly... assert(math.abs(rows(0).getDouble(1) - 1.00000023841859331) < 1e-12) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 706c966ee05f5..fd0e2746dc045 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -380,8 +380,10 @@ class JsonSuite extends QueryTest { sql("select * from jsonTable"), Row("true", 11L, null, 1.1, "13.1", "str1") :: Row("12", null, new java.math.BigDecimal("21474836470.9"), null, null, "true") :: - Row("false", 21474836470L, new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") :: - Row(null, 21474836570L, new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil + Row("false", 21474836470L, + new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") :: + Row(null, 21474836570L, + new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil ) // Number and Boolean conflict: resolve the type as number in this query. @@ -404,7 +406,8 @@ class JsonSuite extends QueryTest { // Widening to DecimalType checkAnswer( sql("select num_num_2 + 1.2 from jsonTable where num_num_2 > 1.1"), - Row(new java.math.BigDecimal("21474836472.1")) :: Row(new java.math.BigDecimal("92233720368547758071.2")) :: Nil + Row(new java.math.BigDecimal("21474836472.1")) :: + Row(new java.math.BigDecimal("92233720368547758071.2")) :: Nil ) // Widening to DoubleType @@ -892,8 +895,7 @@ class JsonSuite extends QueryTest { ) } - test("SPARK-4228 DataFrame to JSON") - { + test("SPARK-4228 DataFrame to JSON") { val schema1 = StructType( StructField("f1", IntegerType, false) :: StructField("f2", StringType, false) :: @@ -913,8 +915,10 @@ class JsonSuite extends QueryTest { df1.registerTempTable("applySchema1") val df2 = df1.toDF val result = df2.toJSON.collect() + // scalastyle:off assert(result(0) === "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}") assert(result(3) === "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}") + // scalastyle:on val schema2 = StructType( StructField("f1", StructType( @@ -968,7 +972,8 @@ class JsonSuite extends QueryTest { // Access elements of a BigInteger array (we use DecimalType internally). checkAnswer( - sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from complexTable"), + sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] " + + " from complexTable"), Row(new java.math.BigDecimal("922337203685477580700"), new java.math.BigDecimal("-922337203685477580800"), null) ) @@ -1008,7 +1013,8 @@ class JsonSuite extends QueryTest { // Access elements of an array field of a struct. checkAnswer( - sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from complexTable"), + sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] " + + "from complexTable"), Row(5, null) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 6a2c2a7c4080a..10d0ede4dc0dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -22,7 +22,7 @@ import parquet.filter2.predicate.Operators._ import parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, Predicate, Row} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.test.TestSQLContext @@ -350,4 +350,26 @@ class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with Before override protected def afterAll(): Unit = { sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) } + + test("SPARK-6742: don't push down predicates which reference partition columns") { + import sqlContext.implicits._ + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/part=1" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").saveAsParquetFile(path) + + // If the "part = 1" filter gets pushed down, this query will throw an exception since + // "part" is not a valid column in the actual Parquet file + val df = DataFrame(sqlContext, org.apache.spark.sql.parquet.ParquetRelation( + path, + Some(sqlContext.sparkContext.hadoopConfiguration), sqlContext, + Seq(AttributeReference("part", IntegerType, false)()) )) + + checkAnswer( + df.filter("a = 1 or part = 1"), + (1 to 3).map(i => Row(1, i, i.toString))) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 203bc79f153dd..4d0bf7cf99cdf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -218,7 +218,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } test("compression codec") { - def compressionCodecFor(path: String) = { + def compressionCodecFor(path: String): String = { val codecs = ParquetTypesConverter .readMetaData(new Path(path), Some(configuration)) .getBlocks @@ -381,6 +381,27 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } } } + + test("SPARK-6352 DirectParquetOutputCommitter") { + // Write to a parquet file and let it fail. + // _temporary should be missing if direct output committer works. + try { + configuration.set("spark.sql.parquet.output.committer.class", + "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") + sqlContext.udf.register("div0", (x: Int) => x / 0) + withTempPath { dir => + intercept[org.apache.spark.SparkException] { + sqlContext.sql("select div0(1)").saveAsParquetFile(dir.getCanonicalPath) + } + val path = new Path(dir.getCanonicalPath, "_temporary") + val fs = path.getFileSystem(configuration) + assert(!fs.exists(path)) + } + } + finally { + configuration.unset("spark.sql.parquet.output.committer.class") + } + } } class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index 61f1cf347ab0f..c964b6d984557 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -180,10 +180,12 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { val caseClassString = "StructType(List(StructField(c1,IntegerType,false), StructField(c2,BinaryType,true)))" + // scalastyle:off val jsonString = """ |{"type":"struct","fields":[{"name":"c1","type":"integer","nullable":false,"metadata":{}},{"name":"c2","type":"binary","nullable":true,"metadata":{}}]} """.stripMargin + // scalastyle:on val fromCaseClassString = ParquetTypesConverter.convertFromString(caseClassString) val fromJson = ParquetTypesConverter.convertFromString(jsonString) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 54af50c6e10ad..3f24a497390c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.sources +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.types._ @@ -31,7 +32,7 @@ class DDLScanSource extends RelationProvider { case class SimpleDDLScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) extends BaseRelation with TableScan { - override def schema = + override def schema: StructType = StructType(Seq( StructField("intType", IntegerType, nullable = false, new MetadataBuilder().putString("comment", "test comment").build()), @@ -57,8 +58,9 @@ case class SimpleDDLScan(from: Int, to: Int)(@transient val sqlContext: SQLConte )) - override def buildScan() = sqlContext.sparkContext.parallelize(from to to). - map(e => Row(s"people$e", e * 2)) + override def buildScan(): RDD[Row] = { + sqlContext.sparkContext.parallelize(from to to).map(e => Row(s"people$e", e * 2)) + } } class DDLTestSuite extends DataSourceTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 773bd1602d5e5..cb5e5147ff189 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.sources import scala.language.existentials +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.types._ @@ -41,7 +42,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL StructField("b", IntegerType, nullable = false) :: StructField("c", StringType, nullable = false) :: Nil) - override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = { + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { val rowBuilders = requiredColumns.map { case "a" => (i: Int) => Seq(i) case "b" => (i: Int) => Seq(i * 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index 08fb5380dc026..6a1ddf2f8e98b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.sources import scala.language.existentials +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.types._ @@ -34,12 +35,12 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo extends BaseRelation with PrunedScan { - override def schema = + override def schema: StructType = StructType( StructField("a", IntegerType, nullable = false) :: StructField("b", IntegerType, nullable = false) :: Nil) - override def buildScan(requiredColumns: Array[String]) = { + override def buildScan(requiredColumns: Array[String]): RDD[Row] = { val rowBuilders = requiredColumns.map { case "a" => (i: Int) => Seq(i) case "b" => (i: Int) => Seq(i * 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 43bc8eb2d11a7..cb287ba85c1f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -114,4 +114,4 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { message.contains("Append mode is not supported"), "We should complain that 'Append mode is not supported' for JSON source.") } -} \ No newline at end of file +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 7928600ac2fb5..60c8c00bda4d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.sources import java.sql.{Timestamp, Date} +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.types._ @@ -35,10 +36,10 @@ class SimpleScanSource extends RelationProvider { case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) extends BaseRelation with TableScan { - override def schema = + override def schema: StructType = StructType(StructField("i", IntegerType, nullable = false) :: Nil) - override def buildScan() = sqlContext.sparkContext.parallelize(from to to).map(Row(_)) + override def buildScan(): RDD[Row] = sqlContext.sparkContext.parallelize(from to to).map(Row(_)) } class AllDataTypesScanSource extends SchemaRelationProvider { @@ -57,9 +58,9 @@ case class AllDataTypesScan( extends BaseRelation with TableScan { - override def schema = userSpecifiedSchema + override def schema: StructType = userSpecifiedSchema - override def buildScan() = { + override def buildScan(): RDD[Row] = { sqlContext.sparkContext.parallelize(from to to).map { i => Row( s"str_$i", diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 6272cdedb3e48..62c061bef690a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -33,7 +33,7 @@ import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.processors.{SetProcessor, CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor, CommandProcessorFactory} import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.shims.ShimLoader import org.apache.thrift.transport.TSocket @@ -264,7 +264,8 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { val proc: CommandProcessor = HiveShim.getCommandProcessor(Array(tokens(0)), hconf) if (proc != null) { - if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor]) { + if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor] || + proc.isInstanceOf[AddResourceProcessor]) { val driver = new SparkSQLDriver driver.init() diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 158c225159720..97b46a01ba5b4 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -22,6 +22,7 @@ import scala.collection.JavaConversions._ import org.apache.spark.scheduler.StatsReportListener import org.apache.spark.sql.hive.{HiveShim, HiveContext} import org.apache.spark.{Logging, SparkConf, SparkContext} +import org.apache.spark.util.Utils /** A singleton object for the master program. The slaves should not access this. */ private[hive] object SparkSQLEnv extends Logging { @@ -37,7 +38,7 @@ private[hive] object SparkSQLEnv extends Logging { val maybeKryoReferenceTracking = sparkConf.getOption("spark.kryo.referenceTracking") sparkConf - .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}") + .setAppName(s"SparkSQL::${Utils.localHostName()}") .set("spark.sql.hive.version", HiveShim.version) .set( "spark.serializer", diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 75738fa22b572..6d1d7c3a4e698 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -1,13 +1,12 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index bf20acecb1f32..4cf95e7bdfb2b 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.thriftserver import java.io.File +import java.net.URL import java.sql.{Date, DriverManager, Statement} import scala.collection.mutable.ArrayBuffer @@ -41,7 +42,7 @@ import org.apache.spark.sql.hive.HiveShim import org.apache.spark.util.Utils object TestData { - def getTestDataFilePath(name: String) = { + def getTestDataFilePath(name: String): URL = { Thread.currentThread().getContextClassLoader.getResource(s"data/files/$name") } @@ -50,7 +51,7 @@ object TestData { } class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { - override def mode = ServerMode.binary + override def mode: ServerMode.Value = ServerMode.binary private def withCLIServiceClient(f: ThriftCLIServiceClient => Unit): Unit = { // Transport creation logics below mimics HiveConnection.createBinaryTransport @@ -337,7 +338,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { - override def mode = ServerMode.http + override def mode: ServerMode.Value = ServerMode.http test("JDBC query execution") { withJdbcStatement { statement => diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 2ae9d018e1b1b..81ee48ef4152f 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -532,6 +532,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "inputddl7", "inputddl8", "insert1", + "insert1_overwrite_partitions", "insert2_overwrite_partitions", "insert_compressed", "join0", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 315fab673da5c..f1c0bd92aa23d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -279,7 +279,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } } - if (metastoreRelation.hiveQlTable.isPartitioned) { + val result = if (metastoreRelation.hiveQlTable.isPartitioned) { val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys) val partitionColumnDataTypes = partitionSchema.map(_.dataType) val partitions = metastoreRelation.hiveQlPartitions.map { p => @@ -314,6 +314,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with parquetRelation } + + result.newInstance() } override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = synchronized { @@ -525,7 +527,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with // Collects all `MetastoreRelation`s which should be replaced val toBeReplaced = plan.collect { // Write path - case InsertIntoTable(relation: MetastoreRelation, _, _, _) + case InsertIntoTable(relation: MetastoreRelation, _, _, _, _) // Inserting into partitioned table is not supported in Parquet data source (yet). if !relation.hiveQlTable.isPartitioned && hive.convertMetastoreParquet && @@ -536,7 +538,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with (relation, parquetRelation, attributedRewrites) // Write path - case InsertIntoHiveTable(relation: MetastoreRelation, _, _, _) + case InsertIntoHiveTable(relation: MetastoreRelation, _, _, _, _) // Inserting into partitioned table is not supported in Parquet data source (yet). if !relation.hiveQlTable.isPartitioned && hive.convertMetastoreParquet && @@ -567,15 +569,15 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val alias = r.alias.getOrElse(r.tableName) Subquery(alias, parquetRelation) - case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite) + case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) if relationMap.contains(r) => val parquetRelation = relationMap(r) - InsertIntoTable(parquetRelation, partition, child, overwrite) + InsertIntoTable(parquetRelation, partition, child, overwrite, ifNotExists) - case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite) + case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) if relationMap.contains(r) => val parquetRelation = relationMap(r) - InsertIntoTable(parquetRelation, partition, child, overwrite) + InsertIntoTable(parquetRelation, partition, child, overwrite, ifNotExists) case other => other.transformExpressions { case a: Attribute if a.resolved => attributedRewrites.getOrElse(a, a) @@ -696,7 +698,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with // Wait until children are resolved. case p: LogicalPlan if !p.childrenResolved => p - case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) => + case p @ InsertIntoTable(table: MetastoreRelation, _, child, _, _) => castChildOutput(p, table, child) } @@ -713,7 +715,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with .forall { case (left, right) => left.sameType(right) }) { // If both types ignoring nullability of ArrayType, MapType, StructType are the same, // use InsertIntoHiveTable instead of InsertIntoTable. - InsertIntoHiveTable(p.table, p.partition, p.child, p.overwrite) + InsertIntoHiveTable(p.table, p.partition, p.child, p.overwrite, p.ifNotExists) } else { // Only do the casting when child output data types differ from table output data types. val castedChildOutput = child.output.zip(table.output).map { @@ -751,7 +753,8 @@ private[hive] case class InsertIntoHiveTable( table: LogicalPlan, partition: Map[String, Option[String]], child: LogicalPlan, - overwrite: Boolean) + overwrite: Boolean, + ifNotExists: Boolean) extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 077e64133faad..53a204b8c2932 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.hive import java.sql.Date +import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo} + import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.hive.conf.HiveConf @@ -111,13 +113,16 @@ private[hive] object HiveQl { "TOK_REVOKE", + "TOK_SHOW_COMPACTIONS", "TOK_SHOW_CREATETABLE", "TOK_SHOW_GRANT", "TOK_SHOW_ROLE_GRANT", + "TOK_SHOW_ROLE_PRINCIPALS", "TOK_SHOW_ROLES", "TOK_SHOW_SET_ROLE", "TOK_SHOW_TABLESTATUS", "TOK_SHOW_TBLPROPERTIES", + "TOK_SHOW_TRANSACTIONS", "TOK_SHOWCOLUMNS", "TOK_SHOWDATABASES", "TOK_SHOWFUNCTIONS", @@ -574,11 +579,23 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_QUERY", queryArgs) if Seq("TOK_FROM", "TOK_INSERT").contains(queryArgs.head.getText) => - val (fromClause: Option[ASTNode], insertClauses) = queryArgs match { - case Token("TOK_FROM", args: Seq[ASTNode]) :: insertClauses => - (Some(args.head), insertClauses) - case Token("TOK_INSERT", _) :: Nil => (None, queryArgs) - } + val (fromClause: Option[ASTNode], insertClauses, cteRelations) = + queryArgs match { + case Token("TOK_FROM", args: Seq[ASTNode]) :: insertClauses => + // check if has CTE + insertClauses.last match { + case Token("TOK_CTE", cteClauses) => + val cteRelations = cteClauses.map(node => { + val relation = nodeToRelation(node).asInstanceOf[Subquery] + (relation.alias, relation) + }).toMap + (Some(args.head), insertClauses.init, Some(cteRelations)) + + case _ => (Some(args.head), insertClauses, None) + } + + case Token("TOK_INSERT", _) :: Nil => (None, queryArgs, None) + } // Return one query for each insert clause. val queries = insertClauses.map { case Token("TOK_INSERT", singleInsert) => @@ -792,7 +809,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } // If there are multiple INSERTS just UNION them together into on query. - queries.reduceLeft(Union) + val query = queries.reduceLeft(Union) + + // return With plan if there is CTE + cteRelations.map(With(query, _)).getOrElse(query) case Token("TOK_UNION", left :: right :: Nil) => Union(nodeToPlan(left), nodeToPlan(right)) @@ -982,7 +1002,27 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C cleanIdentifier(key.toLowerCase) -> None }.toMap).getOrElse(Map.empty) - InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite) + InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, false) + + case Token(destinationToken(), + Token("TOK_TAB", + tableArgs) :: + Token("TOK_IFNOTEXISTS", + ifNotExists) :: Nil) => + val Some(tableNameParts) :: partitionClause :: Nil = + getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) + + val tableIdent = extractTableIdent(tableNameParts) + + val partitionKeys = partitionClause.map(_.getChildren.map { + // Parse partitions. We also make keys case insensitive. + case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => + cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value)) + case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) => + cleanIdentifier(key.toLowerCase) -> None + }.toMap).getOrElse(Map.empty) + + InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, true) case a: ASTNode => throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") @@ -1284,8 +1324,13 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C Explode(attributes, nodeToExpr(child)) case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => + val functionInfo: FunctionInfo = + Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse( + sys.error(s"Couldn't find function $functionName")) + val functionClassName = functionInfo.getFunctionClass.getName + HiveGenericUdtf( - new HiveFunctionWrapper(functionName), + new HiveFunctionWrapper(functionClassName), attributes, children.map(nodeToExpr)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 5f7e897295117..1ccb0c279c60e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -184,12 +184,14 @@ private[hive] trait HiveStrategies { object DataSinks extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) => + case logical.InsertIntoTable( + table: MetastoreRelation, partition, child, overwrite, ifNotExists) => execution.InsertIntoHiveTable( - table, partition, planLater(child), overwrite) :: Nil - case hive.InsertIntoHiveTable(table: MetastoreRelation, partition, child, overwrite) => + table, partition, planLater(child), overwrite, ifNotExists) :: Nil + case hive.InsertIntoHiveTable( + table: MetastoreRelation, partition, child, overwrite, ifNotExists) => execution.InsertIntoHiveTable( - table, partition, planLater(child), overwrite) :: Nil + table, partition, planLater(child), overwrite, ifNotExists) :: Nil case _ => Nil } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 3563472c7ae81..e556c74ffb015 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -35,6 +35,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.DateUtils +import org.apache.spark.util.Utils /** * A trait for subclasses that handle table scans. @@ -76,7 +77,9 @@ class HadoopTableReader( override def makeRDDForTable(hiveTable: HiveTable): RDD[Row] = makeRDDForTable( hiveTable, - relation.tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]], + Class.forName( + relation.tableDesc.getSerdeClassName, true, sc.sessionState.getConf.getClassLoader) + .asInstanceOf[Class[Deserializer]], filterOpt = None) /** @@ -142,7 +145,46 @@ class HadoopTableReader( partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]], filterOpt: Option[PathFilter]): RDD[Row] = { - val hivePartitionRDDs = partitionToDeserializer.map { case (partition, partDeserializer) => + + // SPARK-5068:get FileStatus and do the filtering locally when the path is not exists + def verifyPartitionPath( + partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]]): + Map[HivePartition, Class[_ <: Deserializer]] = { + if (!sc.conf.verifyPartitionPath) { + partitionToDeserializer + } else { + var existPathSet = collection.mutable.Set[String]() + var pathPatternSet = collection.mutable.Set[String]() + partitionToDeserializer.filter { + case (partition, partDeserializer) => + def updateExistPathSetByPathPattern(pathPatternStr: String) { + val pathPattern = new Path(pathPatternStr) + val fs = pathPattern.getFileSystem(sc.hiveconf) + val matches = fs.globStatus(pathPattern) + matches.foreach(fileStatus => existPathSet += fileStatus.getPath.toString) + } + // convert /demo/data/year/month/day to /demo/data/*/*/*/ + def getPathPatternByPath(parNum: Int, tempPath: Path): String = { + var path = tempPath + for (i <- (1 to parNum)) path = path.getParent + val tails = (1 to parNum).map(_ => "*").mkString("/", "/", "/") + path.toString + tails + } + + val partPath = HiveShim.getDataLocationPath(partition) + val partNum = Utilities.getPartitionDesc(partition).getPartSpec.size(); + var pathPatternStr = getPathPatternByPath(partNum, partPath) + if (!pathPatternSet.contains(pathPatternStr)) { + pathPatternSet += pathPatternStr + updateExistPathSetByPathPattern(pathPatternStr) + } + existPathSet.contains(partPath.toString) + } + } + } + + val hivePartitionRDDs = verifyPartitionPath(partitionToDeserializer) + .map { case (partition, partDeserializer) => val partDesc = Utilities.getPartitionDesc(partition) val partPath = HiveShim.getDataLocationPath(partition) val inputPathStr = applyFilterIfNeeded(partPath, filterOpt) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index fade9e5852eaa..76a1965f3cb25 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -67,7 +67,7 @@ case class CreateTableAsSelect( new org.apache.hadoop.hive.metastore.api.AlreadyExistsException(s"$database.$tableName") } } else { - hiveContext.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true)).toRdd + hiveContext.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true, false)).toRdd } Seq.empty[Row] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 6c96747439683..89995a91b1a92 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -45,7 +45,8 @@ case class InsertIntoHiveTable( table: MetastoreRelation, partition: Map[String, Option[String]], child: SparkPlan, - overwrite: Boolean) extends UnaryNode with HiveInspectors { + overwrite: Boolean, + ifNotExists: Boolean) extends UnaryNode with HiveInspectors { @transient val sc: HiveContext = sqlContext.asInstanceOf[HiveContext] @transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass @@ -219,15 +220,25 @@ case class InsertIntoHiveTable( isSkewedStoreAsSubdir) } } else { - catalog.synchronized { - catalog.client.loadPartition( - outputPath, - qualifiedTableName, - orderedPartitionSpec, - overwrite, - holdDDLTime, - inheritTableSpecs, - isSkewedStoreAsSubdir) + // scalastyle:off + // ifNotExists is only valid with static partition, refer to + // https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DML#LanguageManualDML-InsertingdataintoHiveTablesfromqueries + // scalastyle:on + val oldPart = catalog.synchronized { + catalog.client.getPartition( + catalog.client.getTable(qualifiedTableName), partitionSpec, false) + } + if (oldPart == null || !ifNotExists) { + catalog.synchronized { + catalog.client.loadPartition( + outputPath, + qualifiedTableName, + orderedPartitionSpec, + overwrite, + holdDDLTime, + inheritTableSpecs, + isSkewedStoreAsSubdir) + } } } } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 99dc58646ddd6..902a12785e3e9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -80,7 +80,7 @@ case class AddJar(path: String) extends RunnableCommand { val hiveContext = sqlContext.asInstanceOf[HiveContext] hiveContext.runSqlHive(s"ADD JAR $path") hiveContext.sparkContext.addJar(path) - Seq.empty[Row] + Seq(Row(0)) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index a3497eadd67f6..6570fa1043900 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -262,12 +262,6 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { |WITH SERDEPROPERTIES ('field.delim'='\\t') """.stripMargin.cmd, "INSERT OVERWRITE TABLE serdeins SELECT * FROM src".cmd), - TestTable("sales", - s"""CREATE TABLE IF NOT EXISTS sales (key STRING, value INT) - |ROW FORMAT SERDE '${classOf[RegexSerDe].getCanonicalName}' - |WITH SERDEPROPERTIES ("input.regex" = "([^ ]*)\t([^ ]*)") - """.stripMargin.cmd, - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/sales.txt")}' INTO TABLE sales".cmd), TestTable("episodes", s"""CREATE TABLE episodes (title STRING, air_date STRING, doctor INT) |ROW FORMAT SERDE '${classOf[AvroSerDe].getCanonicalName}' diff --git a/sql/hive/src/test/resources/golden/CTE feature #1-0-eedabbfe6ba8799f7b7782fb47a82768 b/sql/hive/src/test/resources/golden/CTE feature #1-0-eedabbfe6ba8799f7b7782fb47a82768 new file mode 100644 index 0000000000000..f6ba75da254ca --- /dev/null +++ b/sql/hive/src/test/resources/golden/CTE feature #1-0-eedabbfe6ba8799f7b7782fb47a82768 @@ -0,0 +1,3 @@ +5 +5 +5 diff --git a/sql/hive/src/test/resources/golden/CTE feature #2-0-aa03d104251f97e36bc52279cb9931c9 b/sql/hive/src/test/resources/golden/CTE feature #2-0-aa03d104251f97e36bc52279cb9931c9 new file mode 100644 index 0000000000000..ca7b591095e28 --- /dev/null +++ b/sql/hive/src/test/resources/golden/CTE feature #2-0-aa03d104251f97e36bc52279cb9931c9 @@ -0,0 +1,4 @@ +val_4 +val_5 +val_5 +val_5 diff --git a/sql/hive/src/test/resources/golden/CTE feature #3-0-b5d4bf3c0ee92b2fda0ca24f422383f2 b/sql/hive/src/test/resources/golden/CTE feature #3-0-b5d4bf3c0ee92b2fda0ca24f422383f2 new file mode 100644 index 0000000000000..b8626c4cff284 --- /dev/null +++ b/sql/hive/src/test/resources/golden/CTE feature #3-0-b5d4bf3c0ee92b2fda0ca24f422383f2 @@ -0,0 +1 @@ +4 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-0-d5edc0daa94b33915df794df3b710774 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-0-d5edc0daa94b33915df794df3b710774 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-1-9eb9372f4855928fae16f5fa554b3a62 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-1-9eb9372f4855928fae16f5fa554b3a62 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-10-ec2cef3d37146c450c60202a572f5cab b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-10-ec2cef3d37146c450c60202a572f5cab new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-11-8854d6001200fc11529b2e2da755e5a2 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-11-8854d6001200fc11529b2e2da755e5a2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-12-71ff68fda0aa7a36cb50d8fab0d70d25 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-12-71ff68fda0aa7a36cb50d8fab0d70d25 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-13-7e4e7d7003fc6ef17bc19c3461ad899 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-13-7e4e7d7003fc6ef17bc19c3461ad899 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-14-ec2cef3d37146c450c60202a572f5cab b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-14-ec2cef3d37146c450c60202a572f5cab new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-15-a3b2e230efde74e970ae8a3b55f383fc b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-15-a3b2e230efde74e970ae8a3b55f383fc new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-2-8396c17a66e3d9a374d4361873b9bfe3 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-2-8396c17a66e3d9a374d4361873b9bfe3 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-3-3876bb356dd8af7e78d061093d555457 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-3-3876bb356dd8af7e78d061093d555457 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-4-528e23afb272c2e69004c86ddaa70ee b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-4-528e23afb272c2e69004c86ddaa70ee new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-5-de5d56456c28d63775554e56355911d2 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-5-de5d56456c28d63775554e56355911d2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-6-3efdc331b3b4bdac3e60c757600fff53 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-6-3efdc331b3b4bdac3e60c757600fff53 new file mode 100644 index 0000000000000..185a91c110d6f --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-6-3efdc331b3b4bdac3e60c757600fff53 @@ -0,0 +1,5 @@ +98 val_98 +98 val_98 +97 val_97 +97 val_97 +96 val_96 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-7-92f6af82704504968de078c133f222f8 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-7-92f6af82704504968de078c133f222f8 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-8-316cad7c63ddd4fb043be2affa5b0a67 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-8-316cad7c63ddd4fb043be2affa5b0a67 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-9-3efdc331b3b4bdac3e60c757600fff53 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-9-3efdc331b3b4bdac3e60c757600fff53 new file mode 100644 index 0000000000000..185a91c110d6f --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-9-3efdc331b3b4bdac3e60c757600fff53 @@ -0,0 +1,5 @@ +98 val_98 +98 val_98 +97 val_97 +97 val_97 +96 val_96 diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-10-89737a8857b5b61cc909e0c797f86aea b/sql/hive/src/test/resources/golden/leftsemijoin-10-89737a8857b5b61cc909e0c797f86aea index 25ce912507d55..a1963ba81e0da 100644 --- a/sql/hive/src/test/resources/golden/leftsemijoin-10-89737a8857b5b61cc909e0c797f86aea +++ b/sql/hive/src/test/resources/golden/leftsemijoin-10-89737a8857b5b61cc909e0c797f86aea @@ -1,4 +1,2 @@ Hank 2 -Hank 2 -Joe 2 Joe 2 diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-8-73cad58a10a1483ccb15e94a857013 b/sql/hive/src/test/resources/golden/leftsemijoin-8-73cad58a10a1483ccb15e94a857013 index 25ce912507d55..a1963ba81e0da 100644 --- a/sql/hive/src/test/resources/golden/leftsemijoin-8-73cad58a10a1483ccb15e94a857013 +++ b/sql/hive/src/test/resources/golden/leftsemijoin-8-73cad58a10a1483ccb15e94a857013 @@ -1,4 +1,2 @@ Hank 2 -Hank 2 -Joe 2 Joe 2 diff --git a/sql/hive/src/test/resources/hive-hcatalog-core-0.13.1.jar b/sql/hive/src/test/resources/hive-hcatalog-core-0.13.1.jar new file mode 100644 index 0000000000000..37af9aafad8a4 Binary files /dev/null and b/sql/hive/src/test/resources/hive-hcatalog-core-0.13.1.jar differ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index 968557c9c4686..d960a30e00738 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -136,7 +136,7 @@ class ErrorPositionSuite extends QueryTest with BeforeAndAfter { * @param query the query to analyze * @param token a unique token in the string that should be indicated by the exception */ - def positionTest(name: String, query: String, token: String) = { + def positionTest(name: String, query: String, token: String): Unit = { def parseTree = Try(quietly(HiveQl.dumpTree(HiveQl.getAst(query)))).getOrElse("") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index c482c6de8a736..2a7374cc172b7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -116,21 +116,20 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { } def checkDataType(dt1: Seq[DataType], dt2: Seq[DataType]): Unit = { - dt1.zip(dt2).map { - case (dd1, dd2) => - assert(dd1.getClass === dd2.getClass) // DecimalType doesn't has the default precision info + dt1.zip(dt2).foreach { case (dd1, dd2) => + assert(dd1.getClass === dd2.getClass) // DecimalType doesn't has the default precision info } } def checkValues(row1: Seq[Any], row2: Seq[Any]): Unit = { - row1.zip(row2).map { - case (r1, r2) => checkValue(r1, r2) + row1.zip(row2).foreach { case (r1, r2) => + checkValue(r1, r2) } } def checkValues(row1: Seq[Any], row2: Row): Unit = { - row1.zip(row2.toSeq).map { - case (r1, r2) => checkValue(r1, r2) + row1.zip(row2.toSeq).foreach { case (r1, r2) => + checkValue(r1, r2) } } @@ -141,7 +140,7 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { assert(r1.compare(r2) === 0) case (r1: Array[Byte], r2: Array[Byte]) if r1 != null && r2 != null && r1.length == r2.length => - r1.zip(r2).map { case (b1, b2) => assert(b1 === b2) } + r1.zip(r2).foreach { case (b1, b2) => assert(b1 === b2) } case (r1, r2) => assert(r1 === r2) } } @@ -166,7 +165,8 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { val constantData = constantExprs.map(_.eval()) val constantNullData = constantData.map(_ => null) val constantWritableOIs = constantExprs.map(e => toWritableInspector(e.dataType)) - val constantNullWritableOIs = constantExprs.map(e => toInspector(Literal.create(null, e.dataType))) + val constantNullWritableOIs = + constantExprs.map(e => toInspector(Literal.create(null, e.dataType))) checkValues(constantData, constantData.zip(constantWritableOIs).map { case (d, oi) => unwrap(wrap(d, oi), oi) @@ -202,7 +202,8 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { case (t, idx) => StructField(s"c_$idx", t) }) - checkValues(row, unwrap(wrap(Row.fromSeq(row), toInspector(dt)), toInspector(dt)).asInstanceOf[Row]) + checkValues(row, + unwrap(wrap(Row.fromSeq(row), toInspector(dt)), toInspector(dt)).asInstanceOf[Row]) checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) } @@ -212,8 +213,10 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { val d = row(0) :: row(0) :: Nil checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) - checkValue(d, unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) - checkValue(d, unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + checkValue(d, + unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + checkValue(d, + unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) } test("wrap / unwrap Map Type") { @@ -222,7 +225,9 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { val d = Map(row(0) -> row(1)) checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) - checkValue(d, unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) - checkValue(d, unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + checkValue(d, + unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + checkValue(d, + unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 8011952e0d535..ecb990e8aac91 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -115,11 +115,36 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("SPARK-4203:random partition directory order") { sql("CREATE TABLE tmp_table (key int, value string)") val tmpDir = Utils.createTempDir() - sql(s"CREATE TABLE table_with_partition(c1 string) PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) location '${tmpDir.toURI.toString}' ") - sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='1') SELECT 'blarr' FROM tmp_table") - sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='2') SELECT 'blarr' FROM tmp_table") - sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='3') SELECT 'blarr' FROM tmp_table") - sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='4') SELECT 'blarr' FROM tmp_table") + sql( + s""" + |CREATE TABLE table_with_partition(c1 string) + |PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) + |location '${tmpDir.toURI.toString}' + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (p1='a',p2='b',p3='c',p4='c',p5='1') + |SELECT 'blarr' FROM tmp_table + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (p1='a',p2='b',p3='c',p4='c',p5='2') + |SELECT 'blarr' FROM tmp_table + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (p1='a',p2='b',p3='c',p4='c',p5='3') + |SELECT 'blarr' FROM tmp_table + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (p1='a',p2='b',p3='c',p4='c',p5='4') + |SELECT 'blarr' FROM tmp_table + """.stripMargin) def listFolders(path: File, acc: List[String]): List[List[String]] = { val dir = path.listFiles() val folders = dir.filter(_.isDirectory).toList @@ -196,34 +221,42 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { testData.registerTempTable("testData") val testDatawithNull = TestHive.sparkContext.parallelize( - (1 to 10).map(i => ThreeCloumntable(i, i.toString,null))).toDF() + (1 to 10).map(i => ThreeCloumntable(i, i.toString, null))).toDF() val tmpDir = Utils.createTempDir() - sql(s"CREATE TABLE table_with_partition(key int,value string) PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') SELECT key,value FROM testData") + sql( + s""" + |CREATE TABLE table_with_partition(key int,value string) + |PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (ds='1') SELECT key,value FROM testData + """.stripMargin) // test schema the same between partition and table sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") checkAnswer(sql("select key,value from table_with_partition where ds='1' "), - testData.collect.toSeq + testData.collect().toSeq ) // test difference type of field sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") checkAnswer(sql("select key,value from table_with_partition where ds='1' "), - testData.collect.toSeq + testData.collect().toSeq ) // add column to table sql("ALTER TABLE table_with_partition ADD COLUMNS(key1 string)") checkAnswer(sql("select key,value,key1 from table_with_partition where ds='1' "), - testDatawithNull.collect.toSeq + testDatawithNull.collect().toSeq ) // change column name to table sql("ALTER TABLE table_with_partition CHANGE COLUMN key keynew BIGINT") checkAnswer(sql("select keynew,value from table_with_partition where ds='1' "), - testData.collect.toSeq + testData.collect().toSeq ) sql("DROP TABLE table_with_partition") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala new file mode 100644 index 0000000000000..a787fa5546e76 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import com.google.common.io.Files + +import org.apache.spark.sql.{QueryTest, _} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.util.Utils + + +class QueryPartitionSuite extends QueryTest { + import org.apache.spark.sql.hive.test.TestHive.implicits._ + + test("SPARK-5068: query data when path doesn't exists"){ + val testData = TestHive.sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString))).toDF() + testData.registerTempTable("testData") + + val tmpDir = Files.createTempDir() + // create the table for test + sql(s"CREATE TABLE table_with_partition(key int,value string) " + + s"PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + + "SELECT key,value FROM testData") + + // test for the exist path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect + ++ testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect) + + // delete the path of one partition + val folders = tmpDir.listFiles.filter(_.isDirectory) + Utils.deleteRecursively(folders(0)) + + // test for after delete the path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect + ++ testData.toSchemaRDD.collect) + + sql("DROP TABLE table_with_partition") + sql("DROP TABLE createAndInsertTest") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index ccd0e5aa51f95..00a69de9e4262 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -142,7 +142,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { after: () => Unit, query: String, expectedAnswer: Seq[Row], - ct: ClassTag[_]) = { + ct: ClassTag[_]): Unit = { before() var df = sql(query) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala index 42a82c1fbf5c7..a3f5921a0cb23 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.hive.test.TestHive._ class BigDataBenchmarkSuite extends HiveComparisonTest { val testDataDirectory = new File("target" + File.separator + "big-data-benchmark-testdata") + val userVisitPath = new File(testDataDirectory, "uservisits").getCanonicalPath val testTables = Seq( TestTable( "rankings", @@ -63,7 +64,7 @@ class BigDataBenchmarkSuite extends HiveComparisonTest { | searchWord STRING, | duration INT) | ROW FORMAT DELIMITED FIELDS TERMINATED BY "," - | STORED AS TEXTFILE LOCATION "${new File(testDataDirectory, "uservisits").getCanonicalPath}" + | STORED AS TEXTFILE LOCATION "$userVisitPath" """.stripMargin.cmd), TestTable( "documents", @@ -83,7 +84,10 @@ class BigDataBenchmarkSuite extends HiveComparisonTest { "SELECT pageURL, pageRank FROM rankings WHERE pageRank > 1") createQueryTest("query2", - "SELECT SUBSTR(sourceIP, 1, 10), SUM(adRevenue) FROM uservisits GROUP BY SUBSTR(sourceIP, 1, 10)") + """ + |SELECT SUBSTR(sourceIP, 1, 10), SUM(adRevenue) FROM uservisits + |GROUP BY SUBSTR(sourceIP, 1, 10) + """.stripMargin) createQueryTest("query3", """ @@ -113,8 +117,8 @@ class BigDataBenchmarkSuite extends HiveComparisonTest { |CREATE TABLE url_counts_total AS | SELECT SUM(count) AS totalCount, destpage | FROM url_counts_partial GROUP BY destpage - |-- The following queries run, but generate different results in HIVE likely because the UDF is not deterministic - |-- given different input splits. + |-- The following queries run, but generate different results in HIVE + |-- likely because the UDF is not deterministic given different input splits. |-- SELECT CAST(SUM(count) AS INT) FROM url_counts_partial |-- SELECT COUNT(*) FROM url_counts_partial |-- SELECT * FROM url_counts_partial diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index a5ec312ee430c..027056d4b865f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -255,8 +255,9 @@ abstract class HiveComparisonTest .filterNot(_ contains "hive.outerjoin.supports.filters") .filterNot(_ contains "hive.exec.post.hooks") - if (allQueries != queryList) + if (allQueries != queryList) { logWarning(s"Simplifications made on unsupported operations for test $testCaseName") + } lazy val consoleTestCase = { val quotes = "\"\"\"" @@ -305,13 +306,16 @@ abstract class HiveComparisonTest try { // Hooks often break the harness and don't really affect our test anyway, don't // even try running them. - if (installHooksCommand.findAllMatchIn(queryString).nonEmpty) + if (installHooksCommand.findAllMatchIn(queryString).nonEmpty) { sys.error("hive exec hooks not supported for tests.") + } - logWarning(s"Running query ${i+1}/${queryList.size} with hive.") + logWarning(s"Running query ${i + 1}/${queryList.size} with hive.") // Analyze the query with catalyst to ensure test tables are loaded. val answer = hiveQuery.analyzed match { - case _: ExplainCommand => Nil // No need to execute EXPLAIN queries as we don't check the output. + case _: ExplainCommand => + // No need to execute EXPLAIN queries as we don't check the output. + Nil case _ => TestHive.runSqlHive(queryString) } @@ -394,21 +398,24 @@ abstract class HiveComparisonTest case tf: org.scalatest.exceptions.TestFailedException => throw tf case originalException: Exception => if (System.getProperty("spark.hive.canarytest") != null) { - // When we encounter an error we check to see if the environment is still okay by running a simple query. - // If this fails then we halt testing since something must have gone seriously wrong. + // When we encounter an error we check to see if the environment is still + // okay by running a simple query. If this fails then we halt testing since + // something must have gone seriously wrong. try { new TestHive.HiveQLQueryExecution("SELECT key FROM src").stringResult() TestHive.runSqlHive("SELECT key FROM src") } catch { case e: Exception => - logError(s"FATAL ERROR: Canary query threw $e This implies that the testing environment has likely been corrupted.") - // The testing setup traps exits so wait here for a long time so the developer can see when things started - // to go wrong. + logError(s"FATAL ERROR: Canary query threw $e This implies that the " + + "testing environment has likely been corrupted.") + // The testing setup traps exits so wait here for a long time so the developer + // can see when things started to go wrong. Thread.sleep(1000000) } } - // If the canary query didn't fail then the environment is still okay, so just throw the original exception. + // If the canary query didn't fail then the environment is still okay, + // so just throw the original exception. throw originalException } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala index 02518d516261b..f7b37dae0a5f3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala @@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.util._ /** * A framework for running the query tests that are listed as a set of text files. * - * TestSuites that derive from this class must provide a map of testCaseName -> testCaseFiles that should be included. - * Additionally, there is support for whitelisting and blacklisting tests as development progresses. + * TestSuites that derive from this class must provide a map of testCaseName -> testCaseFiles + * that should be included. Additionally, there is support for whitelisting and blacklisting + * tests as development progresses. */ abstract class HiveQueryFileTest extends HiveComparisonTest { /** A list of tests deemed out of scope and thus completely disregarded */ @@ -54,15 +55,17 @@ abstract class HiveQueryFileTest extends HiveComparisonTest { case (testCaseName, testCaseFile) => if (blackList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_)) { logDebug(s"Blacklisted test skipped $testCaseName") - } else if (realWhiteList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_) || runAll) { + } else if (realWhiteList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_) || + runAll) { // Build a test case and submit it to scala test framework... val queriesString = fileToString(testCaseFile) createQueryTest(testCaseName, queriesString) } else { // Only output warnings for the built in whitelist as this clutters the output when the user // trying to execute a single test from the commandline. - if(System.getProperty(whiteListProperty) == null && !runAll) + if (System.getProperty(whiteListProperty) == null && !runAll) { ignore(testCaseName) {} + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index de140fc72a2c3..300b1f7920473 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -37,7 +37,8 @@ import org.apache.spark.sql.hive.test.TestHive._ case class TestData(a: Int, b: String) /** - * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. + * A set of test cases expressed in Hive QL that are not covered by the tests + * included in the hive distribution. */ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { private val originalTimeZone = TimeZone.getDefault @@ -237,7 +238,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } createQueryTest("modulus", - "SELECT 11 % 10, IF((101.1 % 100.0) BETWEEN 1.01 AND 1.11, \"true\", \"false\"), (101 / 2) % 10 FROM src LIMIT 1") + "SELECT 11 % 10, IF((101.1 % 100.0) BETWEEN 1.01 AND 1.11, \"true\", \"false\"), " + + "(101 / 2) % 10 FROM src LIMIT 1") test("Query expressed in SQL") { setConf("spark.sql.dialect", "sql") @@ -309,7 +311,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { "SELECT * FROM src a JOIN src b ON a.key = b.key") createQueryTest("small.cartesian", - "SELECT a.key, b.key FROM (SELECT key FROM src WHERE key < 1) a JOIN (SELECT key FROM src WHERE key = 2) b") + "SELECT a.key, b.key FROM (SELECT key FROM src WHERE key < 1) a JOIN " + + "(SELECT key FROM src WHERE key = 2) b") createQueryTest("length.udf", "SELECT length(\"test\") FROM src LIMIT 1") @@ -457,6 +460,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("lateral view3", "FROM src SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX") + // scalastyle:off createQueryTest("lateral view4", """ |create table src_lv1 (key string, value string); @@ -466,6 +470,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |insert overwrite table src_lv1 SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX |insert overwrite table src_lv2 SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX """.stripMargin) + // scalastyle:on createQueryTest("lateral view5", "FROM src SELECT explode(array(key+3, key+4))") @@ -537,6 +542,21 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("select null from table", "SELECT null FROM src LIMIT 1") + createQueryTest("CTE feature #1", + "with q1 as (select key from src) select * from q1 where key = 5") + + createQueryTest("CTE feature #2", + """with q1 as (select * from src where key= 5), + |q2 as (select * from src s2 where key = 4) + |select value from q1 union all select value from q2 + """.stripMargin) + + createQueryTest("CTE feature #3", + """with q1 as (select key from src) + |from q1 + |select * where key = 4 + """.stripMargin) + test("predicates contains an empty AttributeSet() references") { sql( """ @@ -584,7 +604,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } - def isExplanation(result: DataFrame) = { + def isExplanation(result: DataFrame): Boolean = { val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } explanation.contains("== Physical Plan ==") } @@ -793,6 +813,21 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql("DROP TABLE alter1") } + test("ADD JAR command 2") { + // this is a test case from mapjoin_addjar.q + val testJar = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath + val testData = TestHive.getHiveFile("data/files/sample.json").getCanonicalPath + if (HiveShim.version == "0.13.1") { + sql(s"ADD JAR $testJar") + sql( + """CREATE TABLE t1(a string, b string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'""".stripMargin) + sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE t1""") + sql("select * from src join t1 on src.key = t1.a") + sql("DROP TABLE t1") + } + } + test("ADD FILE command") { val testFile = TestHive.getHiveFile("data/files/v1.txt").getCanonicalFile sql(s"ADD FILE $testFile") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index f4440e5b7846a..8ad3627504229 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -25,7 +25,8 @@ case class Nested(a: Int, B: Int) case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) /** - * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. + * A set of test cases expressed in Hive QL that are not covered by the tests + * included in the hive distribution. */ class HiveResolutionSuite extends HiveComparisonTest { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala index 7486bfa82b00b..5586a793618bd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala @@ -25,17 +25,25 @@ import org.apache.spark.sql.hive.test.TestHive * A set of tests that validates support for Hive SerDe. */ class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll { - - override def beforeAll() = { + override def beforeAll(): Unit = { + import TestHive._ + import org.apache.hadoop.hive.serde2.RegexSerDe + super.beforeAll() TestHive.cacheTables = false + sql(s"""CREATE TABLE IF NOT EXISTS sales (key STRING, value INT) + |ROW FORMAT SERDE '${classOf[RegexSerDe].getCanonicalName}' + |WITH SERDEPROPERTIES ("input.regex" = "([^ ]*)\t([^ ]*)") + """.stripMargin) + sql(s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/sales.txt")}' INTO TABLE sales") } + // table sales is not a cache table, and will be clear after reset + createQueryTest("Read with RegexSerDe", "SELECT * FROM sales", false) + createQueryTest( "Read and write with LazySimpleSerDe (tab separated)", "SELECT * from serdeins") - createQueryTest("Read with RegexSerDe", "SELECT * FROM sales") - createQueryTest("Read with AvroSerDe", "SELECT * FROM episodes") createQueryTest("Read Partitioned with AvroSerDe", "SELECT * FROM episodes_part") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index ab0e0443c7faa..f0f04f8c73fb4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -35,8 +35,10 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { val nullVal = "null" baseTypes.init.foreach { i => - createQueryTest(s"case when then $i else $nullVal end ", s"SELECT case when true then $i else $nullVal end FROM src limit 1") - createQueryTest(s"case when then $nullVal else $i end ", s"SELECT case when true then $nullVal else $i end FROM src limit 1") + createQueryTest(s"case when then $i else $nullVal end ", + s"SELECT case when true then $i else $nullVal end FROM src limit 1") + createQueryTest(s"case when then $nullVal else $i end ", + s"SELECT case when true then $nullVal else $i end FROM src limit 1") } test("[SPARK-2210] boolean cast on boolean value should be removed") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index d7c5d1a25a82b..7f49eac490572 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -123,9 +123,10 @@ class HiveUdfSuite extends QueryTest { IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil).toDF() testData.registerTempTable("integerTable") - sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '${classOf[UDFIntegerToString].getName}'") + val udfName = classOf[UDFIntegerToString].getName + sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '$udfName'") checkAnswer( - sql("SELECT testUDFIntegerToString(i) FROM integerTable"), //.collect(), + sql("SELECT testUDFIntegerToString(i) FROM integerTable"), Seq(Row("1"), Row("2"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") @@ -141,7 +142,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'") checkAnswer( - sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), //.collect(), + sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), Seq(Row(0), Row(2), Row(13))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") @@ -156,7 +157,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'") checkAnswer( - sql("SELECT testUDFListString(l) FROM listStringTable"), //.collect(), + sql("SELECT testUDFListString(l) FROM listStringTable"), Seq(Row("a,b,c"), Row("d,e"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") @@ -170,7 +171,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'") checkAnswer( - sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), //.collect(), + sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), Seq(Row("hello world"), Row("hello goodbye"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf") @@ -187,7 +188,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") checkAnswer( - sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), //.collect(), + sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), Seq(Row("0, 0"), Row("2, 2"), Row("13, 13"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") @@ -247,7 +248,8 @@ class PairUdf extends GenericUDF { override def initialize(p1: Array[ObjectInspector]): ObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector( Seq("id", "value"), - Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, PrimitiveObjectInspectorFactory.javaIntObjectInspector) + Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector) ) override def evaluate(args: Array[DeferredObject]): AnyRef = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 8474d850c9c6c..067b577f1560e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -143,7 +143,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { sql: String, expectedOutputColumns: Seq[String], expectedScannedColumns: Seq[String], - expectedPartValues: Seq[Seq[String]]) = { + expectedPartValues: Seq[Seq[String]]): Unit = { test(s"$testCaseName - pruning test") { val plan = new TestHive.HiveQLQueryExecution(sql).executedPlan val actualOutputColumns = plan.output.map(_.name) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 817b9dcb8f505..47b4cb9ca61ff 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -34,12 +34,95 @@ case class Nested3(f3: Int) case class NestedArray2(b: Seq[Int]) case class NestedArray1(a: NestedArray2) +case class Order( + id: Int, + make: String, + `type`: String, + price: Int, + pdate: String, + customer: String, + city: String, + state: String, + month: Int) + /** * A collection of hive query tests where we generate the answers ourselves instead of depending on * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is * valid, but Hive currently cannot execute it. */ class SQLQuerySuite extends QueryTest { + test("SPARK-6835: udtf in lateral view") { + val df = Seq((1, 1)).toDF("c1", "c2") + df.registerTempTable("table1") + val query = sql("SELECT c1, v FROM table1 LATERAL VIEW stack(3, 1, c1 + 1, c1 + 2) d AS v") + checkAnswer(query, Row(1, 1) :: Row(1, 2) :: Row(1, 3) :: Nil) + } + + test("SPARK-6851: Self-joined converted parquet tables") { + val orders = Seq( + Order(1, "Atlas", "MTB", 234, "2015-01-07", "John D", "Pacifica", "CA", 20151), + Order(3, "Swift", "MTB", 285, "2015-01-17", "John S", "Redwood City", "CA", 20151), + Order(4, "Atlas", "Hybrid", 303, "2015-01-23", "Jones S", "San Mateo", "CA", 20151), + Order(7, "Next", "MTB", 356, "2015-01-04", "Jane D", "Daly City", "CA", 20151), + Order(10, "Next", "YFlikr", 187, "2015-01-09", "John D", "Fremont", "CA", 20151), + Order(11, "Swift", "YFlikr", 187, "2015-01-23", "John D", "Hayward", "CA", 20151), + Order(2, "Next", "Hybrid", 324, "2015-02-03", "Jane D", "Daly City", "CA", 20152), + Order(5, "Next", "Street", 187, "2015-02-08", "John D", "Fremont", "CA", 20152), + Order(6, "Atlas", "Street", 154, "2015-02-09", "John D", "Pacifica", "CA", 20152), + Order(8, "Swift", "Hybrid", 485, "2015-02-19", "John S", "Redwood City", "CA", 20152), + Order(9, "Atlas", "Split", 303, "2015-02-28", "Jones S", "San Mateo", "CA", 20152)) + + val orderUpdates = Seq( + Order(1, "Atlas", "MTB", 434, "2015-01-07", "John D", "Pacifica", "CA", 20151), + Order(11, "Swift", "YFlikr", 137, "2015-01-23", "John D", "Hayward", "CA", 20151)) + + orders.toDF.registerTempTable("orders1") + orderUpdates.toDF.registerTempTable("orderupdates1") + + sql( + """CREATE TABLE orders( + | id INT, + | make String, + | type String, + | price INT, + | pdate String, + | customer String, + | city String) + |PARTITIONED BY (state STRING, month INT) + |STORED AS PARQUET + """.stripMargin) + + sql( + """CREATE TABLE orderupdates( + | id INT, + | make String, + | type String, + | price INT, + | pdate String, + | customer String, + | city String) + |PARTITIONED BY (state STRING, month INT) + |STORED AS PARQUET + """.stripMargin) + + sql("set hive.exec.dynamic.partition.mode=nonstrict") + sql("INSERT INTO TABLE orders PARTITION(state, month) SELECT * FROM orders1") + sql("INSERT INTO TABLE orderupdates PARTITION(state, month) SELECT * FROM orderupdates1") + + checkAnswer( + sql( + """ + |select orders.state, orders.month + |from orders + |join ( + | select distinct orders.state,orders.month + | from orders + | join orderupdates + | on orderupdates.id = orders.id) ao + | on ao.state = orders.state and ao.month = orders.month + """.stripMargin), + (1 to 6).map(_ => Row("CA", 20151))) + } test("SPARK-5371: union with null and sum") { val df = Seq((1, 1)).toDF("c1", "c2") @@ -422,7 +505,7 @@ class SQLQuerySuite extends QueryTest { } test("resolve udtf with single alias") { - val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i+1}]}""")) + val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) jsonRDD(rdd).registerTempTable("data") val df = sql("SELECT explode(a) AS val FROM data") val col = df("val") @@ -435,7 +518,7 @@ class SQLQuerySuite extends QueryTest { // is not in a valid state (cannot be executed). Because of this bug, the analysis rule of // PreInsertionCasts will actually start to work before ImplicitGenerate and then // generates an invalid query plan. - val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i+1}]}""")) + val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) jsonRDD(rdd).registerTempTable("data") val originalConf = getConf("spark.sql.hive.convertCTAS", "false") setConf("spark.sql.hive.convertCTAS", "false") @@ -478,5 +561,4 @@ class SQLQuerySuite extends QueryTest { sql("select d from dn union all select d * 2 from dn") .queryExecution.analyzed } - } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 5f71e1bbc2d2e..d5dd0bf58e702 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -1,4 +1,3 @@ - /* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with @@ -887,7 +886,11 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll test(s"SPARK-5775 read struct from $table") { checkAnswer( - sql(s"SELECT p, structField.intStructField, structField.stringStructField FROM $table WHERE p = 1"), + sql( + s""" + |SELECT p, structField.intStructField, structField.stringStructField + |FROM $table WHERE p = 1 + """.stripMargin), (1 to 10).map(i => Row(1, i, f"${i}_string"))) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 28703ef8129b3..0a50485118588 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.{SparkException, SparkConf, Logging} import org.apache.spark.io.CompressionCodec -import org.apache.spark.util.MetadataCleaner +import org.apache.spark.util.{MetadataCleaner, Utils} import org.apache.spark.streaming.scheduler.JobGenerator @@ -139,8 +139,11 @@ class CheckpointWriter( // Write checkpoint to temp file fs.delete(tempFile, true) // just in case it exists val fos = fs.create(tempFile) - fos.write(bytes) - fos.close() + Utils.tryWithSafeFinally { + fos.write(bytes) + } { + fos.close() + } // If the checkpoint file exists, back it up // If the backup exists as well, just delete it, otherwise rename will fail @@ -187,9 +190,11 @@ class CheckpointWriter( val bos = new ByteArrayOutputStream() val zos = compressionCodec.compressedOutputStream(bos) val oos = new ObjectOutputStream(zos) - oos.writeObject(checkpoint) - oos.close() - bos.close() + Utils.tryWithSafeFinally { + oos.writeObject(checkpoint) + } { + oos.close() + } try { executor.execute(new CheckpointWriteHandler( checkpoint.checkpointTime, bos.toByteArray, clearCheckpointDataLater)) @@ -248,18 +253,24 @@ object CheckpointReader extends Logging { checkpointFiles.foreach(file => { logInfo("Attempting to load checkpoint from file " + file) try { - val fis = fs.open(file) - // ObjectInputStream uses the last defined user-defined class loader in the stack - // to find classes, which maybe the wrong class loader. Hence, a inherited version - // of ObjectInputStream is used to explicitly use the current thread's default class - // loader to find and load classes. This is a well know Java issue and has popped up - // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) - val zis = compressionCodec.compressedInputStream(fis) - val ois = new ObjectInputStreamWithLoader(zis, - Thread.currentThread().getContextClassLoader) - val cp = ois.readObject.asInstanceOf[Checkpoint] - ois.close() - fs.close() + var ois: ObjectInputStreamWithLoader = null + var cp: Checkpoint = null + Utils.tryWithSafeFinally { + val fis = fs.open(file) + // ObjectInputStream uses the last defined user-defined class loader in the stack + // to find classes, which maybe the wrong class loader. Hence, a inherited version + // of ObjectInputStream is used to explicitly use the current thread's default class + // loader to find and load classes. This is a well know Java issue and has popped up + // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) + val zis = compressionCodec.compressedInputStream(fis) + ois = new ObjectInputStreamWithLoader(zis, + Thread.currentThread().getContextClassLoader) + cp = ois.readObject.asInstanceOf[Checkpoint] + } { + if (ois != null) { + ois.close() + } + } cp.validate() logInfo("Checkpoint successfully loaded from file " + file) logInfo("Checkpoint was generated at time " + cp.checkpointTime) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 42514d8b47dcf..f4963a78e1d18 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.util.RecurringTimer -import org.apache.spark.util.SystemClock +import org.apache.spark.util.{SystemClock, Utils} /** Listener object for BlockGenerator events */ private[streaming] trait BlockGeneratorListener { @@ -79,9 +79,9 @@ private[streaming] class BlockGenerator( private case class Block(id: StreamBlockId, buffer: ArrayBuffer[Any]) private val clock = new SystemClock() - private val blockInterval = conf.getLong("spark.streaming.blockInterval", 200) + private val blockIntervalMs = conf.getTimeAsMs("spark.streaming.blockInterval", "200ms") private val blockIntervalTimer = - new RecurringTimer(clock, blockInterval, updateCurrentBuffer, "BlockGenerator") + new RecurringTimer(clock, blockIntervalMs, updateCurrentBuffer, "BlockGenerator") private val blockQueueSize = conf.getInt("spark.streaming.blockQueueSize", 10) private val blocksForPushing = new ArrayBlockingQueue[Block](blockQueueSize) private val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } @@ -132,7 +132,7 @@ private[streaming] class BlockGenerator( val newBlockBuffer = currentBuffer currentBuffer = new ArrayBuffer[Any] if (newBlockBuffer.size > 0) { - val blockId = StreamBlockId(receiverId, time - blockInterval) + val blockId = StreamBlockId(receiverId, time - blockIntervalMs) val newBlock = new Block(blockId, newBlockBuffer) listener.onGenerateBlock(blockId) blocksForPushing.put(newBlock) // put is blocking when queue is full diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 4946806d2ee95..58e56638a2dca 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -24,7 +24,7 @@ import akka.actor.{ActorRef, Props, Actor} import org.apache.spark.{SparkEnv, Logging} import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time} import org.apache.spark.streaming.util.RecurringTimer -import org.apache.spark.util.{Clock, ManualClock} +import org.apache.spark.util.{Clock, ManualClock, Utils} /** Event classes for JobGenerator */ private[scheduler] sealed trait JobGeneratorEvent @@ -104,17 +104,15 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { if (processReceivedData) { logInfo("Stopping JobGenerator gracefully") val timeWhenStopStarted = System.currentTimeMillis() - val stopTimeout = conf.getLong( - "spark.streaming.gracefulStopTimeout", - 10 * ssc.graph.batchDuration.milliseconds - ) + val stopTimeoutMs = conf.getTimeAsMs( + "spark.streaming.gracefulStopTimeout", s"${10 * ssc.graph.batchDuration.milliseconds}ms") val pollTime = 100 // To prevent graceful stop to get stuck permanently def hasTimedOut: Boolean = { - val timedOut = (System.currentTimeMillis() - timeWhenStopStarted) > stopTimeout + val timedOut = (System.currentTimeMillis() - timeWhenStopStarted) > stopTimeoutMs if (timedOut) { - logWarning("Timed out while stopping the job generator (timeout = " + stopTimeout + ")") + logWarning("Timed out while stopping the job generator (timeout = " + stopTimeoutMs + ")") } timedOut } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index d6a93acbe711b..95f1857b4c377 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -105,6 +105,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { if (jobSet.jobs.isEmpty) { logInfo("No jobs added for time " + jobSet.time) } else { + listenerBus.post(StreamingListenerBatchSubmitted(jobSet.toBatchInfo)) jobSets.put(jobSet.time, jobSet) jobSet.jobs.foreach(job => jobExecutor.execute(new JobHandler(job))) logInfo("Added jobs for time " + jobSet.time) @@ -134,10 +135,13 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { private def handleJobStart(job: Job) { val jobSet = jobSets.get(job.time) - if (!jobSet.hasStarted) { + val isFirstJobOfJobSet = !jobSet.hasStarted + jobSet.handleJobStart(job) + if (isFirstJobOfJobSet) { + // "StreamingListenerBatchStarted" should be posted after calling "handleJobStart" to get the + // correct "jobSet.processingStartTime". listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo)) } - jobSet.handleJobStart(job) logInfo("Starting job " + job.id + " from job set of time " + jobSet.time) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index e4bd067cacb77..be1e8686cf9fa 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -33,7 +33,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) private val waitingBatchInfos = new HashMap[Time, BatchInfo] private val runningBatchInfos = new HashMap[Time, BatchInfo] - private val completedaBatchInfos = new Queue[BatchInfo] + private val completedBatchInfos = new Queue[BatchInfo] private val batchInfoLimit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 100) private var totalCompletedBatches = 0L private var totalReceivedRecords = 0L @@ -62,7 +62,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = { synchronized { - runningBatchInfos(batchSubmitted.batchInfo.batchTime) = batchSubmitted.batchInfo + waitingBatchInfos(batchSubmitted.batchInfo.batchTime) = batchSubmitted.batchInfo } } @@ -79,8 +79,8 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) synchronized { waitingBatchInfos.remove(batchCompleted.batchInfo.batchTime) runningBatchInfos.remove(batchCompleted.batchInfo.batchTime) - completedaBatchInfos.enqueue(batchCompleted.batchInfo) - if (completedaBatchInfos.size > batchInfoLimit) completedaBatchInfos.dequeue() + completedBatchInfos.enqueue(batchCompleted.batchInfo) + if (completedBatchInfos.size > batchInfoLimit) completedBatchInfos.dequeue() totalCompletedBatches += 1L batchCompleted.batchInfo.receivedBlockInfo.foreach { case (_, infos) => @@ -118,7 +118,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } def retainedCompletedBatches: Seq[BatchInfo] = synchronized { - completedaBatchInfos.toSeq + completedBatchInfos.toSeq } def processingDelayDistribution: Option[Distribution] = synchronized { @@ -149,7 +149,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) }.toMap } - def lastReceivedBatchRecords: Map[Int, Long] = { + def lastReceivedBatchRecords: Map[Int, Long] = synchronized { val lastReceivedBlockInfoOption = lastReceivedBatch.map(_.receivedBlockInfo) lastReceivedBlockInfoOption.map { lastReceivedBlockInfo => (0 until numReceivers).map { receiverId => @@ -160,24 +160,24 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } } - def receiverInfo(receiverId: Int): Option[ReceiverInfo] = { + def receiverInfo(receiverId: Int): Option[ReceiverInfo] = synchronized { receiverInfos.get(receiverId) } - def lastCompletedBatch: Option[BatchInfo] = { - completedaBatchInfos.sortBy(_.batchTime)(Time.ordering).lastOption + def lastCompletedBatch: Option[BatchInfo] = synchronized { + completedBatchInfos.sortBy(_.batchTime)(Time.ordering).lastOption } - def lastReceivedBatch: Option[BatchInfo] = { + def lastReceivedBatch: Option[BatchInfo] = synchronized { retainedBatches.lastOption } - private def retainedBatches: Seq[BatchInfo] = synchronized { + private def retainedBatches: Seq[BatchInfo] = { (waitingBatchInfos.values.toSeq ++ - runningBatchInfos.values.toSeq ++ completedaBatchInfos).sortBy(_.batchTime)(Time.ordering) + runningBatchInfos.values.toSeq ++ completedBatchInfos).sortBy(_.batchTime)(Time.ordering) } private def extractDistribution(getMetric: BatchInfo => Option[Long]): Option[Distribution] = { - Distribution(completedaBatchInfos.flatMap(getMetric(_)).map(_.toDouble)) + Distribution(completedBatchInfos.flatMap(getMetric(_)).map(_.toDouble)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index bfe8086fcf8fe..b6dcb62bfeec8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -37,11 +37,12 @@ private[ui] class StreamingPage(parent: StreamingTab) /** Render the page */ def render(request: HttpServletRequest): Seq[Node] = { - val content = + val content = listener.synchronized { generateBasicStats() ++

    ++

    Statistics over last {listener.retainedCompletedBatches.size} processed batches

    ++ generateReceiverStats() ++ generateBatchStatsTable() + } UIUtils.headerSparkPage("Streaming", content, parent, Some(5000)) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala index a7850812bd612..ca2f319f174a2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala @@ -72,7 +72,8 @@ object RawTextSender extends Logging { } catch { case e: IOException => logError("Client disconnected") - socket.close() + } finally { + socket.close() } } } diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/streaming/src/test/resources/log4j.properties +++ b/streaming/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index cf191715d29d6..87bc20f79c3cd 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -171,7 +171,9 @@ class BasicOperationsSuite extends TestSuiteBase { test("flatMapValues") { testOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), - (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)), + (s: DStream[String]) => { + s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)) + }, Seq( Seq(("a", 2), ("a", 12), ("b", 1), ("b", 11)), Seq(("", 2), ("", 12)), Seq() ), true ) @@ -474,7 +476,7 @@ class BasicOperationsSuite extends TestSuiteBase { stream.foreachRDD(_ => {}) // Dummy output stream ssc.start() Thread.sleep(2000) - def getInputFromSlice(fromMillis: Long, toMillis: Long) = { + def getInputFromSlice(fromMillis: Long, toMillis: Long): Set[Int] = { stream.slice(new Time(fromMillis), new Time(toMillis)).flatMap(_.collect()).toSet } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 91a2b2bba461d..54c30440a6e8d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -43,7 +43,7 @@ class CheckpointSuite extends TestSuiteBase { var ssc: StreamingContext = null - override def batchDuration = Milliseconds(500) + override def batchDuration: Duration = Milliseconds(500) override def beforeFunction() { super.beforeFunction() @@ -72,7 +72,7 @@ class CheckpointSuite extends TestSuiteBase { val input = (1 to 10).map(_ => Seq("a")).toSeq val operation = (st: DStream[String]) => { val updateFunc = (values: Seq[Int], state: Option[Int]) => { - Some((values.sum + state.getOrElse(0))) + Some(values.sum + state.getOrElse(0)) } st.map(x => (x, 1)) .updateStateByKey(updateFunc) @@ -199,7 +199,12 @@ class CheckpointSuite extends TestSuiteBase { testCheckpointedOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ), (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), - Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), Seq() ), 3 ) } @@ -212,7 +217,8 @@ class CheckpointSuite extends TestSuiteBase { val n = 10 val w = 4 val input = (1 to n).map(_ => Seq("a")).toSeq - val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) + val output = Seq( + Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) val operation = (st: DStream[String]) => { st.map(x => (x, 1)) .reduceByKeyAndWindow(_ + _, _ - _, batchDuration * w, batchDuration) @@ -236,7 +242,13 @@ class CheckpointSuite extends TestSuiteBase { classOf[TextOutputFormat[Text, IntWritable]]) output }, - Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq()), 3 ) } finally { @@ -259,7 +271,13 @@ class CheckpointSuite extends TestSuiteBase { classOf[NewTextOutputFormat[Text, IntWritable]]) output }, - Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq()), 3 ) } finally { @@ -298,7 +316,13 @@ class CheckpointSuite extends TestSuiteBase { output } }, - Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq()), 3 ) } finally { @@ -533,7 +557,8 @@ class CheckpointSuite extends TestSuiteBase { * Advances the manual clock on the streaming scheduler by given number of batches. * It also waits for the expected amount of time for each batch. */ - def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = { + def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = + { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] logInfo("Manual clock before advancing = " + clock.getTimeMillis()) for (i <- 1 to numBatches.toInt) { @@ -543,7 +568,7 @@ class CheckpointSuite extends TestSuiteBase { logInfo("Manual clock after advancing = " + clock.getTimeMillis()) Thread.sleep(batchDuration.milliseconds) - val outputStream = ssc.graph.getOutputStreams.filter { dstream => + val outputStream = ssc.graph.getOutputStreams().filter { dstream => dstream.isInstanceOf[TestOutputStreamWithPartitions[V]] }.head.asInstanceOf[TestOutputStreamWithPartitions[V]] outputStream.output.map(_.flatten) @@ -552,4 +577,4 @@ class CheckpointSuite extends TestSuiteBase { private object CheckpointSuite extends Serializable { var batchThreeShouldBlockIndefinitely: Boolean = true -} \ No newline at end of file +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala index 26435d8515815..0c4c06534a693 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -29,9 +29,9 @@ class FailureSuite extends TestSuiteBase with Logging { val directory = Utils.createTempDir() val numBatches = 30 - override def batchDuration = Milliseconds(1000) + override def batchDuration: Duration = Milliseconds(1000) - override def useManualClock = false + override def useManualClock: Boolean = false override def afterFunction() { Utils.deleteRecursively(directory) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 7ed6320a3d0bc..e6ac4975c5e68 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -52,7 +52,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] val outputStream = new TestOutputStream(networkStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) + def output: ArrayBuffer[String] = outputBuffer.flatMap(x => x) outputStream.register() ssc.start() @@ -164,7 +164,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val countStream = networkStream.count val outputBuffer = new ArrayBuffer[Seq[Long]] with SynchronizedBuffer[Seq[Long]] val outputStream = new TestOutputStream(countStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) + def output: ArrayBuffer[Long] = outputBuffer.flatMap(x => x) outputStream.register() ssc.start() @@ -196,7 +196,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val queueStream = ssc.queueStream(queue, oneAtATime = true) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] val outputStream = new TestOutputStream(queueStream, outputBuffer) - def output = outputBuffer.filter(_.size > 0) + def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) outputStream.register() ssc.start() @@ -204,7 +204,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq("1", "2", "3", "4", "5") val expectedOutput = input.map(Seq(_)) - //Thread.sleep(1000) + val inputIterator = input.toIterator for (i <- 0 until input.size) { // Enqueue more than 1 item per tick but they should dequeue one at a time @@ -239,7 +239,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val queueStream = ssc.queueStream(queue, oneAtATime = false) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] val outputStream = new TestOutputStream(queueStream, outputBuffer) - def output = outputBuffer.filter(_.size > 0) + def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) outputStream.register() ssc.start() @@ -352,7 +352,8 @@ class TestServer(portToBind: Int = 0) extends Logging { logInfo("New connection") try { clientSocket.setTcpNoDelay(true) - val outputStream = new BufferedWriter(new OutputStreamWriter(clientSocket.getOutputStream)) + val outputStream = new BufferedWriter( + new OutputStreamWriter(clientSocket.getOutputStream)) while(clientSocket.isConnected) { val msg = queue.poll(100, TimeUnit.MILLISECONDS) @@ -384,7 +385,7 @@ class TestServer(portToBind: Int = 0) extends Logging { def stop() { servingThread.interrupt() } - def port = serverSocket.getLocalPort + def port: Int = serverSocket.getLocalPort } /** This is a receiver to test multiple threads inserting data using block generator */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index ef4873de2f5a9..c090eaec2928d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -96,7 +96,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche testBlockStoring(handler) { case (data, blockIds, storeResults) => // Verify the data in block manager is correct val storedData = blockIds.flatMap { blockId => - blockManager.getLocal(blockId).map { _.data.map {_.toString}.toList }.getOrElse(List.empty) + blockManager.getLocal(blockId).map(_.data.map(_.toString).toList).getOrElse(List.empty) }.toList storedData shouldEqual data @@ -120,7 +120,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche testBlockStoring(handler) { case (data, blockIds, storeResults) => // Verify the data in block manager is correct val storedData = blockIds.flatMap { blockId => - blockManager.getLocal(blockId).map { _.data.map {_.toString}.toList }.getOrElse(List.empty) + blockManager.getLocal(blockId).map(_.data.map(_.toString).toList).getOrElse(List.empty) }.toList storedData shouldEqual data diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index 42fad769f0c1a..b63b37d9f9cef 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -228,7 +228,8 @@ class ReceivedBlockTrackerSuite * Get all the data written in the given write ahead log files. By default, it will read all * files in the test log directory. */ - def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles): Seq[ReceivedBlockTrackerLogEvent] = { + def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles) + : Seq[ReceivedBlockTrackerLogEvent] = { logFiles.flatMap { file => new WriteAheadLogReader(file, hadoopConf).toSeq }.map { byteBuffer => @@ -244,7 +245,8 @@ class ReceivedBlockTrackerSuite } /** Create batch allocation object from the given info */ - def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]): BatchAllocationEvent = { + def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]) + : BatchAllocationEvent = { BatchAllocationEvent(time, AllocatedBlocks(Map((streamId -> blockInfos)))) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index aa20ad0b5374e..91261a9db7360 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -131,11 +131,11 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { test("block generator") { val blockGeneratorListener = new FakeBlockGeneratorListener - val blockInterval = 200 - val conf = new SparkConf().set("spark.streaming.blockInterval", blockInterval.toString) + val blockIntervalMs = 200 + val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms") val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf) val expectedBlocks = 5 - val waitTime = expectedBlocks * blockInterval + (blockInterval / 2) + val waitTime = expectedBlocks * blockIntervalMs + (blockIntervalMs / 2) val generatedData = new ArrayBuffer[Int] // Generate blocks @@ -157,15 +157,15 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { test("block generator throttling") { val blockGeneratorListener = new FakeBlockGeneratorListener - val blockInterval = 100 + val blockIntervalMs = 100 val maxRate = 100 - val conf = new SparkConf().set("spark.streaming.blockInterval", blockInterval.toString). + val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms"). set("spark.streaming.receiver.maxRate", maxRate.toString) val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf) val expectedBlocks = 20 - val waitTime = expectedBlocks * blockInterval + val waitTime = expectedBlocks * blockIntervalMs val expectedMessages = maxRate * waitTime / 1000 - val expectedMessagesPerBlock = maxRate * blockInterval / 1000 + val expectedMessagesPerBlock = maxRate * blockIntervalMs / 1000 val generatedData = new ArrayBuffer[Int] // Generate blocks @@ -308,7 +308,7 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { val errors = new ArrayBuffer[Throwable] /** Check if all data structures are clean */ - def isAllEmpty = { + def isAllEmpty: Boolean = { singles.isEmpty && byteBuffers.isEmpty && iterators.isEmpty && arrayBuffers.isEmpty && errors.isEmpty } @@ -320,24 +320,21 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { def pushBytes( bytes: ByteBuffer, optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] - ) { + optionalBlockId: Option[StreamBlockId]) { byteBuffers += bytes } def pushIterator( iterator: Iterator[_], optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] - ) { + optionalBlockId: Option[StreamBlockId]) { iterators += iterator } def pushArrayBuffer( arrayBuffer: ArrayBuffer[_], optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] - ) { + optionalBlockId: Option[StreamBlockId]) { arrayBuffers += arrayBuffer } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 2e5005ef6ff14..58353a5f97c8a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -73,9 +73,9 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("from conf with settings") { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) - myConf.set("spark.cleaner.ttl", "10") + myConf.set("spark.cleaner.ttl", "10s") ssc = new StreamingContext(myConf, batchDuration) - assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === 10) + assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) } test("from existing SparkContext") { @@ -85,24 +85,26 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("from existing SparkContext with settings") { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) - myConf.set("spark.cleaner.ttl", "10") + myConf.set("spark.cleaner.ttl", "10s") ssc = new StreamingContext(myConf, batchDuration) - assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === 10) + assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) } test("from checkpoint") { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) - myConf.set("spark.cleaner.ttl", "10") + myConf.set("spark.cleaner.ttl", "10s") val ssc1 = new StreamingContext(myConf, batchDuration) addInputStream(ssc1).register() ssc1.start() val cp = new Checkpoint(ssc1, Time(1000)) - assert(cp.sparkConfPairs.toMap.getOrElse("spark.cleaner.ttl", "-1") === "10") + assert( + Utils.timeStringAsSeconds(cp.sparkConfPairs + .toMap.getOrElse("spark.cleaner.ttl", "-1")) === 10) ssc1.stop() val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp)) - assert(newCp.createSparkConf().getInt("spark.cleaner.ttl", -1) === 10) + assert(newCp.createSparkConf().getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) ssc = new StreamingContext(null, newCp, null) - assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === 10) + assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) } test("start and stop state check") { @@ -176,7 +178,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("stop gracefully") { val conf = new SparkConf().setMaster(master).setAppName(appName) - conf.set("spark.cleaner.ttl", "3600") + conf.set("spark.cleaner.ttl", "3600s") sc = new SparkContext(conf) for (i <- 1 to 4) { logInfo("==================================\n\n\n") @@ -207,13 +209,13 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("stop slow receiver gracefully") { val conf = new SparkConf().setMaster(master).setAppName(appName) - conf.set("spark.streaming.gracefulStopTimeout", "20000") + conf.set("spark.streaming.gracefulStopTimeout", "20000s") sc = new SparkContext(conf) logInfo("==================================\n\n\n") ssc = new StreamingContext(sc, Milliseconds(100)) var runningCount = 0 SlowTestReceiver.receivedAllRecords = false - //Create test receiver that sleeps in onStop() + // Create test receiver that sleeps in onStop() val totalNumRecords = 15 val recordsPerSecond = 1 val input = ssc.receiverStream(new SlowTestReceiver(totalNumRecords, recordsPerSecond)) @@ -370,7 +372,8 @@ object TestReceiver { } /** Custom receiver for testing whether a slow receiver can be shutdown gracefully or not */ -class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int) extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging { +class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int) + extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging { var receivingThreadOption: Option[Thread] = None diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index f52562b0a0f73..7210439509541 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -38,18 +38,46 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { // To make sure that the processing start and end times in collected // information are different for successive batches - override def batchDuration = Milliseconds(100) - override def actuallyWait = true + override def batchDuration: Duration = Milliseconds(100) + override def actuallyWait: Boolean = true test("batch info reporting") { val ssc = setupStreams(input, operation) val collector = new BatchInfoCollector ssc.addStreamingListener(collector) runStreams(ssc, input.size, input.size) - val batchInfos = collector.batchInfos - batchInfos should have size 4 - batchInfos.foreach(info => { + // SPARK-6766: batch info should be submitted + val batchInfosSubmitted = collector.batchInfosSubmitted + batchInfosSubmitted should have size 4 + + batchInfosSubmitted.foreach(info => { + info.schedulingDelay should be (None) + info.processingDelay should be (None) + info.totalDelay should be (None) + }) + + isInIncreasingOrder(batchInfosSubmitted.map(_.submissionTime)) should be (true) + + // SPARK-6766: processingStartTime of batch info should not be None when starting + val batchInfosStarted = collector.batchInfosStarted + batchInfosStarted should have size 4 + + batchInfosStarted.foreach(info => { + info.schedulingDelay should not be None + info.schedulingDelay.get should be >= 0L + info.processingDelay should be (None) + info.totalDelay should be (None) + }) + + isInIncreasingOrder(batchInfosStarted.map(_.submissionTime)) should be (true) + isInIncreasingOrder(batchInfosStarted.map(_.processingStartTime.get)) should be (true) + + // test onBatchCompleted + val batchInfosCompleted = collector.batchInfosCompleted + batchInfosCompleted should have size 4 + + batchInfosCompleted.foreach(info => { info.schedulingDelay should not be None info.processingDelay should not be None info.totalDelay should not be None @@ -58,9 +86,9 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { info.totalDelay.get should be >= 0L }) - isInIncreasingOrder(batchInfos.map(_.submissionTime)) should be (true) - isInIncreasingOrder(batchInfos.map(_.processingStartTime.get)) should be (true) - isInIncreasingOrder(batchInfos.map(_.processingEndTime.get)) should be (true) + isInIncreasingOrder(batchInfosCompleted.map(_.submissionTime)) should be (true) + isInIncreasingOrder(batchInfosCompleted.map(_.processingStartTime.get)) should be (true) + isInIncreasingOrder(batchInfosCompleted.map(_.processingEndTime.get)) should be (true) } test("receiver info reporting") { @@ -99,9 +127,20 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { /** Listener that collects information on processed batches */ class BatchInfoCollector extends StreamingListener { - val batchInfos = new ArrayBuffer[BatchInfo] + val batchInfosCompleted = new ArrayBuffer[BatchInfo] + val batchInfosStarted = new ArrayBuffer[BatchInfo] + val batchInfosSubmitted = new ArrayBuffer[BatchInfo] + + override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted) { + batchInfosSubmitted += batchSubmitted.batchInfo + } + + override def onBatchStarted(batchStarted: StreamingListenerBatchStarted) { + batchInfosStarted += batchStarted.batchInfo + } + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { - batchInfos += batchCompleted.batchInfo + batchInfosCompleted += batchCompleted.batchInfo } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 3565d621e8a6c..c3cae8aeb6d15 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -53,8 +53,9 @@ class TestInputStream[T: ClassTag](ssc_ : StreamingContext, input: Seq[Seq[T]], val selectedInput = if (index < input.size) input(index) else Seq[T]() // lets us test cases where RDDs are not created - if (selectedInput == null) + if (selectedInput == null) { return None + } val rdd = ssc.sc.makeRDD(selectedInput, numPartitions) logInfo("Created RDD " + rdd.id + " with " + selectedInput) @@ -104,7 +105,9 @@ class TestOutputStreamWithPartitions[T: ClassTag](parent: DStream[T], output.clear() } - def toTestOutputStream = new TestOutputStream[T](this.parent, this.output.map(_.flatten)) + def toTestOutputStream: TestOutputStream[T] = { + new TestOutputStream[T](this.parent, this.output.map(_.flatten)) + } } /** @@ -148,34 +151,34 @@ class BatchCounter(ssc: StreamingContext) { trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Name of the framework for Spark context - def framework = this.getClass.getSimpleName + def framework: String = this.getClass.getSimpleName // Master for Spark context - def master = "local[2]" + def master: String = "local[2]" // Batch duration - def batchDuration = Seconds(1) + def batchDuration: Duration = Seconds(1) // Directory where the checkpoint data will be saved - lazy val checkpointDir = { + lazy val checkpointDir: String = { val dir = Utils.createTempDir() logDebug(s"checkpointDir: $dir") dir.toString } // Number of partitions of the input parallel collections created for testing - def numInputPartitions = 2 + def numInputPartitions: Int = 2 // Maximum time to wait before the test times out - def maxWaitTimeMillis = 10000 + def maxWaitTimeMillis: Int = 10000 // Whether to use manual clock or not - def useManualClock = true + def useManualClock: Boolean = true // Whether to actually wait in real time before changing manual clock - def actuallyWait = false + def actuallyWait: Boolean = false - //// A SparkConf to use in tests. Can be modified before calling setupStreams to configure things. + // A SparkConf to use in tests. Can be modified before calling setupStreams to configure things. val conf = new SparkConf() .setMaster(master) .setAppName(framework) @@ -346,7 +349,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Wait until expected number of output items have been generated val startTime = System.currentTimeMillis() - while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { + while (output.size < numExpectedOutput && + System.currentTimeMillis() - startTime < maxWaitTimeMillis) { logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput) ssc.awaitTerminationOrTimeout(50) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index 87a0395efbf2a..998426ebb82e5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -32,7 +32,8 @@ import org.apache.spark._ /** * Selenium tests for the Spark Web UI. */ -class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { +class UISeleniumSuite + extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { implicit var webDriver: WebDriver = _ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala index a5d2bb2fde16c..c39ad05f41520 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala @@ -22,9 +22,9 @@ import org.apache.spark.storage.StorageLevel class WindowOperationsSuite extends TestSuiteBase { - override def maxWaitTimeMillis = 20000 // large window tests can sometimes take longer + override def maxWaitTimeMillis: Int = 20000 // large window tests can sometimes take longer - override def batchDuration = Seconds(1) // making sure its visible in this class + override def batchDuration: Duration = Seconds(1) // making sure its visible in this class val largerSlideInput = Seq( Seq(("a", 1)), diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index 7a6a2f3e577dd..c3602a5b73732 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -28,10 +28,13 @@ import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBloc import org.apache.spark.streaming.util.{WriteAheadLogFileSegment, WriteAheadLogWriter} import org.apache.spark.util.Utils -class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { +class WriteAheadLogBackedBlockRDDSuite + extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { + val conf = new SparkConf() .setMaster("local[2]") .setAppName(this.getClass.getSimpleName) + val hadoopConf = new Configuration() var sparkContext: SparkContext = null @@ -86,7 +89,8 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w * @param numPartitionsInWAL Number of partitions to write to the Write Ahead Log * @param testStoreInBM Test whether blocks read from log are stored back into block manager */ - private def testRDD(numPartitionsInBM: Int, numPartitionsInWAL: Int, testStoreInBM: Boolean = false) { + private def testRDD( + numPartitionsInBM: Int, numPartitionsInWAL: Int, testStoreInBM: Boolean = false) { val numBlocks = numPartitionsInBM + numPartitionsInWAL val data = Seq.fill(numBlocks, 10)(scala.util.Random.nextString(50)) @@ -110,7 +114,7 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w "Unexpected blocks in BlockManager" ) - // Make sure that the right `numPartitionsInWAL` blocks are in write ahead logs, and other are not + // Make sure that the right `numPartitionsInWAL` blocks are in WALs, and other are not require( segments.takeRight(numPartitionsInWAL).forall(s => new File(s.path.stripPrefix("file://")).exists()), @@ -152,6 +156,6 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w } private def generateFakeSegments(count: Int): Seq[WriteAheadLogFileSegment] = { - Array.fill(count)(new WriteAheadLogFileSegment("random", 0l, 0)) + Array.fill(count)(new WriteAheadLogFileSegment("random", 0L, 0)) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala index 4150b60635ed6..7865b06c2e3c2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala @@ -90,7 +90,7 @@ class JobGeneratorSuite extends TestSuiteBase { val receiverTracker = ssc.scheduler.receiverTracker // Get the blocks belonging to a batch - def getBlocksOfBatch(batchTime: Long) = { + def getBlocksOfBatch(batchTime: Long): Seq[ReceivedBlockInfo] = { receiverTracker.getBlocksOfBatchAndStream(Time(batchTime), inputStream.id) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala new file mode 100644 index 0000000000000..94b1985116feb --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.ui + +import org.scalatest.Matchers + +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.{Duration, Time, Milliseconds, TestSuiteBase} + +class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { + + val input = (1 to 4).map(Seq(_)).toSeq + val operation = (d: DStream[Int]) => d.map(x => x) + + override def batchDuration: Duration = Milliseconds(100) + + test("onBatchSubmitted, onBatchStarted, onBatchCompleted, " + + "onReceiverStarted, onReceiverError, onReceiverStopped") { + val ssc = setupStreams(input, operation) + val listener = new StreamingJobProgressListener(ssc) + + val receivedBlockInfo = Map( + 0 -> Array(ReceivedBlockInfo(0, 100, null), ReceivedBlockInfo(0, 200, null)), + 1 -> Array(ReceivedBlockInfo(1, 300, null)) + ) + + // onBatchSubmitted + val batchInfoSubmitted = BatchInfo(Time(1000), receivedBlockInfo, 1000, None, None) + listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) + listener.waitingBatches should be (List(batchInfoSubmitted)) + listener.runningBatches should be (Nil) + listener.retainedCompletedBatches should be (Nil) + listener.lastCompletedBatch should be (None) + listener.numUnprocessedBatches should be (1) + listener.numTotalCompletedBatches should be (0) + listener.numTotalProcessedRecords should be (0) + listener.numTotalReceivedRecords should be (0) + + // onBatchStarted + val batchInfoStarted = BatchInfo(Time(1000), receivedBlockInfo, 1000, Some(2000), None) + listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) + listener.waitingBatches should be (Nil) + listener.runningBatches should be (List(batchInfoStarted)) + listener.retainedCompletedBatches should be (Nil) + listener.lastCompletedBatch should be (None) + listener.numUnprocessedBatches should be (1) + listener.numTotalCompletedBatches should be (0) + listener.numTotalProcessedRecords should be (0) + listener.numTotalReceivedRecords should be (600) + + // onBatchCompleted + val batchInfoCompleted = BatchInfo(Time(1000), receivedBlockInfo, 1000, Some(2000), None) + listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) + listener.waitingBatches should be (Nil) + listener.runningBatches should be (Nil) + listener.retainedCompletedBatches should be (List(batchInfoCompleted)) + listener.lastCompletedBatch should be (Some(batchInfoCompleted)) + listener.numUnprocessedBatches should be (0) + listener.numTotalCompletedBatches should be (1) + listener.numTotalProcessedRecords should be (600) + listener.numTotalReceivedRecords should be (600) + + // onReceiverStarted + val receiverInfoStarted = ReceiverInfo(0, "test", null, true, "localhost") + listener.onReceiverStarted(StreamingListenerReceiverStarted(receiverInfoStarted)) + listener.receiverInfo(0) should be (Some(receiverInfoStarted)) + listener.receiverInfo(1) should be (None) + + // onReceiverError + val receiverInfoError = ReceiverInfo(1, "test", null, true, "localhost") + listener.onReceiverError(StreamingListenerReceiverError(receiverInfoError)) + listener.receiverInfo(0) should be (Some(receiverInfoStarted)) + listener.receiverInfo(1) should be (Some(receiverInfoError)) + listener.receiverInfo(2) should be (None) + + // onReceiverStopped + val receiverInfoStopped = ReceiverInfo(2, "test", null, true, "localhost") + listener.onReceiverStopped(StreamingListenerReceiverStopped(receiverInfoStopped)) + listener.receiverInfo(0) should be (Some(receiverInfoStarted)) + listener.receiverInfo(1) should be (Some(receiverInfoError)) + listener.receiverInfo(2) should be (Some(receiverInfoStopped)) + listener.receiverInfo(3) should be (None) + } + + test("Remove the old completed batches when exceeding the limit") { + val ssc = setupStreams(input, operation) + val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 100) + val listener = new StreamingJobProgressListener(ssc) + + val receivedBlockInfo = Map( + 0 -> Array(ReceivedBlockInfo(0, 100, null), ReceivedBlockInfo(0, 200, null)), + 1 -> Array(ReceivedBlockInfo(1, 300, null)) + ) + val batchInfoCompleted = BatchInfo(Time(1000), receivedBlockInfo, 1000, Some(2000), None) + + for(_ <- 0 until (limit + 10)) { + listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) + } + + listener.retainedCompletedBatches.size should be (limit) + listener.numTotalCompletedBatches should be(limit + 10) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 8335659667f22..a3919c43b95b4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -291,7 +291,7 @@ object WriteAheadLogSuite { manager } - /** Read data from a segments of a log file directly and return the list of byte buffers.*/ + /** Read data from a segments of a log file directly and return the list of byte buffers. */ def readDataManually(segments: Seq[WriteAheadLogFileSegment]): Seq[String] = { segments.map { segment => val reader = HdfsUtils.getInputStream(segment.path, hadoopConf) diff --git a/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala b/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala index d0bf328f2b74d..d66750463033a 100644 --- a/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala @@ -25,7 +25,8 @@ package org.apache.spark.streamingtest */ class ImplicitSuite { - // We only want to test if `implict` works well with the compiler, so we don't need a real DStream. + // We only want to test if `implicit` works well with the compiler, + // so we don't need a real DStream. def mockDStream[T]: org.apache.spark.streaming.dstream.DStream[T] = null def testToPairDStreamFunctions(): Unit = { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 24a1e02795218..c357b7ae9d4da 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -223,6 +223,7 @@ private[spark] class ApplicationMaster( val appId = client.getAttemptId().getApplicationId().toString() val historyAddress = sparkConf.getOption("spark.yarn.historyServer.address") + .map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) } .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}" } .getOrElse("") @@ -295,7 +296,7 @@ private[spark] class ApplicationMaster( // we want to be reasonably responsive without causing too many requests to RM. val schedulerInterval = - sparkConf.getLong("spark.yarn.scheduler.heartbeat.interval-ms", 5000) + sparkConf.getTimeAsMs("spark.yarn.scheduler.heartbeat.interval-ms", "5s") // must be <= expiryInterval / 2. val interval = math.max(0, math.min(expiryInterval / 2, schedulerInterval)) @@ -378,7 +379,8 @@ private[spark] class ApplicationMaster( logWarning( "spark.yarn.applicationMaster.waitTries is deprecated, use spark.yarn.am.waitTime") } - val totalWaitTime = sparkConf.getLong("spark.yarn.am.waitTime", waitTries.getOrElse(100000L)) + val totalWaitTime = sparkConf.getTimeAsMs("spark.yarn.am.waitTime", + s"${waitTries.getOrElse(100000L)}ms") val deadline = System.currentTimeMillis() + totalWaitTime while (sparkContextRef.get() == null && System.currentTimeMillis < deadline && !finished) { @@ -403,8 +405,8 @@ private[spark] class ApplicationMaster( // Spark driver should already be up since it launched us, but we don't want to // wait forever, so wait 100 seconds max to match the cluster mode setting. - val totalWaitTime = sparkConf.getLong("spark.yarn.am.waitTime", 100000L) - val deadline = System.currentTimeMillis + totalWaitTime + val totalWaitTimeMs = sparkConf.getTimeAsMs("spark.yarn.am.waitTime", "100s") + val deadline = System.currentTimeMillis + totalWaitTimeMs while (!driverUp && !finished && System.currentTimeMillis < deadline) { try { @@ -469,6 +471,9 @@ private[spark] class ApplicationMaster( System.setProperty("spark.submit.pyFiles", PythonRunner.formatPaths(args.pyFiles).mkString(",")) } + if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { + // TODO(davies): add R dependencies here + } val mainMethod = userClassLoader.loadClass(args.userClass) .getMethod("main", classOf[Array[String]]) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index e1a992af3aae7..ae6dc1094d724 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -25,6 +25,7 @@ class ApplicationMasterArguments(val args: Array[String]) { var userJar: String = null var userClass: String = null var primaryPyFile: String = null + var primaryRFile: String = null var pyFiles: String = null var userArgs: Seq[String] = Seq[String]() var executorMemory = 1024 @@ -54,6 +55,10 @@ class ApplicationMasterArguments(val args: Array[String]) { primaryPyFile = value args = tail + case ("--primary-r-file") :: value :: tail => + primaryRFile = value + args = tail + case ("--py-files") :: value :: tail => pyFiles = value args = tail @@ -79,6 +84,11 @@ class ApplicationMasterArguments(val args: Array[String]) { } } + if (primaryPyFile != null && primaryRFile != null) { + System.err.println("Cannot have primary-py-file and primary-r-file at the same time") + System.exit(-1) + } + userArgs = userArgsBuffer.readOnly } @@ -92,6 +102,7 @@ class ApplicationMasterArguments(val args: Array[String]) { | --jar JAR_PATH Path to your application's JAR file | --class CLASS_NAME Name of your application's main class | --primary-py-file A main Python file + | --primary-r-file A main R file | --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to | place on the PYTHONPATH for Python apps. | --args ARGS Arguments to be passed to your application's main class. diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 79d55a09eb671..1091ff54b0463 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -22,17 +22,21 @@ import java.nio.ByteBuffer import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap, ListBuffer, Map} +import scala.reflect.runtime.universe import scala.util.{Try, Success, Failure} import com.google.common.base.Objects import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission +import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.Master import org.apache.hadoop.mapreduce.MRJobConfig import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.hadoop.security.token.Token import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment @@ -40,6 +44,7 @@ import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.{YarnClient, YarnClientApplication} import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException import org.apache.hadoop.yarn.util.Records import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException} @@ -219,6 +224,7 @@ private[spark] class Client( val dst = new Path(fs.getHomeDirectory(), appStagingDir) val nns = getNameNodesToAccess(sparkConf) + dst obtainTokensForNamenodes(nns, hadoopConf, credentials) + obtainTokenForHiveMetastore(hadoopConf, credentials) val replication = sparkConf.getInt("spark.yarn.submit.file.replication", fs.getDefaultReplication(dst)).toShort @@ -490,6 +496,12 @@ private[spark] class Client( } else { Nil } + val primaryRFile = + if (args.primaryRFile != null) { + Seq("--primary-r-file", args.primaryRFile) + } else { + Nil + } val amClass = if (isClusterMode) { Class.forName("org.apache.spark.deploy.yarn.ApplicationMaster").getName @@ -499,12 +511,15 @@ private[spark] class Client( if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) { args.userArgs = ArrayBuffer(args.primaryPyFile, args.pyFiles) ++ args.userArgs } + if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { + args.userArgs = ArrayBuffer(args.primaryRFile) ++ args.userArgs + } val userArgs = args.userArgs.flatMap { arg => Seq("--arg", YarnSparkHadoopUtil.escapeForShell(arg)) } val amArgs = - Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ pyFiles ++ userArgs ++ - Seq( + Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ pyFiles ++ primaryRFile ++ + userArgs ++ Seq( "--executor-memory", args.executorMemory.toString + "m", "--executor-cores", args.executorCores.toString, "--num-executors ", args.numExecutors.toString) @@ -561,7 +576,14 @@ private[spark] class Client( var lastState: YarnApplicationState = null while (true) { Thread.sleep(interval) - val report = getApplicationReport(appId) + val report: ApplicationReport = + try { + getApplicationReport(appId) + } catch { + case e: ApplicationNotFoundException => + logError(s"Application $appId not found.") + return (YarnApplicationState.KILLED, FinalApplicationStatus.KILLED) + } val state = report.getYarnApplicationState if (logApplicationReport) { @@ -919,6 +941,64 @@ object Client extends Logging { } } + /** + * Obtains token for the Hive metastore and adds them to the credentials. + */ + private def obtainTokenForHiveMetastore(conf: Configuration, credentials: Credentials) { + if (UserGroupInformation.isSecurityEnabled) { + val mirror = universe.runtimeMirror(getClass.getClassLoader) + + try { + val hiveClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.ql.metadata.Hive") + val hive = hiveClass.getMethod("get").invoke(null) + + val hiveConf = hiveClass.getMethod("getConf").invoke(hive) + val hiveConfClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") + + val hiveConfGet = (param:String) => Option(hiveConfClass + .getMethod("get", classOf[java.lang.String]) + .invoke(hiveConf, param)) + + val metastore_uri = hiveConfGet("hive.metastore.uris") + + // Check for local metastore + if (metastore_uri != None && metastore_uri.get.toString.size > 0) { + val metastore_kerberos_principal_conf_var = mirror.classLoader + .loadClass("org.apache.hadoop.hive.conf.HiveConf$ConfVars") + .getField("METASTORE_KERBEROS_PRINCIPAL").get("varname").toString + + val principal = hiveConfGet(metastore_kerberos_principal_conf_var) + + val username = Option(UserGroupInformation.getCurrentUser().getUserName) + if (principal != None && username != None) { + val tokenStr = hiveClass.getMethod("getDelegationToken", + classOf[java.lang.String], classOf[java.lang.String]) + .invoke(hive, username.get, principal.get).asInstanceOf[java.lang.String] + + val hive2Token = new Token[DelegationTokenIdentifier]() + hive2Token.decodeFromUrlString(tokenStr) + credentials.addToken(new Text("hive.server2.delegation.token"),hive2Token) + logDebug("Added hive.Server2.delegation.token to conf.") + hiveClass.getMethod("closeCurrent").invoke(null) + } else { + logError("Username or principal == NULL") + logError(s"""username=${username.getOrElse("(NULL)")}""") + logError(s"""principal=${principal.getOrElse("(NULL)")}""") + throw new IllegalArgumentException("username and/or principal is equal to null!") + } + } else { + logDebug("HiveMetaStore configured in localmode") + } + } catch { + case e:java.lang.NoSuchMethodException => { logInfo("Hive Method not found " + e); return } + case e:java.lang.ClassNotFoundException => { logInfo("Hive Class not found " + e); return } + case e:Exception => { logError("Unexpected Exception " + e) + throw new RuntimeException("Unexpected exception", e) + } + } + } + } + /** * Return whether the two file systems are the same. */ diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 3bc7eb1abf341..da6798cb1b279 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -32,6 +32,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var userClass: String = null var pyFiles: String = null var primaryPyFile: String = null + var primaryRFile: String = null var userArgs: ArrayBuffer[String] = new ArrayBuffer[String]() var executorMemory = 1024 // MB var executorCores = 1 @@ -150,6 +151,10 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) primaryPyFile = value args = tail + case ("--primary-r-file") :: value :: tail => + primaryRFile = value + args = tail + case ("--args" | "--arg") :: value :: tail => if (args(0) == "--args") { println("--args is deprecated. Use --arg instead.") @@ -228,6 +233,11 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) throw new IllegalArgumentException(getUsageMessage(args)) } } + + if (primaryPyFile != null && primaryRFile != null) { + throw new IllegalArgumentException("Cannot have primary-py-file and primary-r-file" + + " at the same time") + } } private def getUsageMessage(unknownParam: List[String] = null): String = { @@ -240,6 +250,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) | mode) | --class CLASS_NAME Name of your application's main class (required) | --primary-py-file A main Python file + | --primary-r-file A main R file | --arg ARG Argument to be passed to your application's main class. | Multiple invocations are possible, each will be passed in order. | --num-executors NUM Number of executors to start (Default: 2) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 1ce10d906ab23..b06069c07f451 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -290,10 +290,19 @@ class ExecutorRunnable( YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs) } + // lookup appropriate http scheme for container log urls + val yarnHttpPolicy = yarnConf.get( + YarnConfiguration.YARN_HTTP_POLICY_KEY, + YarnConfiguration.YARN_HTTP_POLICY_DEFAULT + ) + val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" + // Add log urls sys.env.get("SPARK_USER").foreach { user => - val baseUrl = "http://%s/node/containerlogs/%s/%s" - .format(container.getNodeHttpAddress, ConverterUtils.toString(container.getId), user) + val containerId = ConverterUtils.toString(container.getId) + val address = container.getNodeHttpAddress + val baseUrl = s"$httpScheme$address/node/containerlogs/$containerId/$user" + env("SPARK_LOG_URL_STDERR") = s"$baseUrl/stderr?start=0" env("SPARK_LOG_URL_STDOUT") = s"$baseUrl/stdout?start=0" } diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 8abdc26b43806..99c05329b4d73 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -34,7 +34,7 @@ private[spark] class YarnClientSchedulerBackend( private var client: Client = null private var appId: ApplicationId = null - @volatile private var stopping: Boolean = false + private var monitorThread: Thread = null /** * Create a Yarn client to submit an application to the ResourceManager. @@ -57,7 +57,8 @@ private[spark] class YarnClientSchedulerBackend( client = new Client(args, conf) appId = client.submitApplication() waitForApplication() - asyncMonitorApplication() + monitorThread = asyncMonitorApplication() + monitorThread.start() } /** @@ -123,34 +124,22 @@ private[spark] class YarnClientSchedulerBackend( * If the application has exited for any reason, stop the SparkContext. * This assumes both `client` and `appId` have already been set. */ - private def asyncMonitorApplication(): Unit = { + private def asyncMonitorApplication(): Thread = { assert(client != null && appId != null, "Application has not been submitted yet!") val t = new Thread { override def run() { - while (!stopping) { - var state: YarnApplicationState = null - try { - val report = client.getApplicationReport(appId) - state = report.getYarnApplicationState() - } catch { - case e: ApplicationNotFoundException => - state = YarnApplicationState.KILLED - } - if (state == YarnApplicationState.FINISHED || - state == YarnApplicationState.KILLED || - state == YarnApplicationState.FAILED) { - logError(s"Yarn application has already exited with state $state!") - sc.stop() - stopping = true - } - Thread.sleep(1000L) + try { + val (state, _) = client.monitorApplication(appId, logApplicationReport = false) + logError(s"Yarn application has already exited with state $state!") + sc.stop() + } catch { + case e: InterruptedException => logInfo("Interrupting monitor thread") } - Thread.currentThread().interrupt() } } t.setName("Yarn application state monitor") t.setDaemon(true) - t.start() + t } /** @@ -158,7 +147,7 @@ private[spark] class YarnClientSchedulerBackend( */ override def stop() { assert(client != null, "Attempted to stop this scheduler before starting it!") - stopping = true + monitorThread.interrupt() super.stop() client.stop() logInfo("Stopped") diff --git a/yarn/src/test/resources/log4j.properties b/yarn/src/test/resources/log4j.properties index aab41fa49430f..6b8a5dbf6373e 100644 --- a/yarn/src/test/resources/log4j.properties +++ b/yarn/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.apache.hadoop=WARN diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 92f04b4b859b3..c1b94ac9c5bdd 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -232,19 +232,26 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { testCode(conf) } - def newEnv = MutableHashMap[String, String]() + def newEnv: MutableHashMap[String, String] = MutableHashMap[String, String]() - def classpath(env: MutableHashMap[String, String]) = env(Environment.CLASSPATH.name).split(":|;|") + def classpath(env: MutableHashMap[String, String]): Array[String] = + env(Environment.CLASSPATH.name).split(":|;|") - def flatten(a: Option[Seq[String]], b: Option[Seq[String]]) = (a ++ b).flatten.toArray + def flatten(a: Option[Seq[String]], b: Option[Seq[String]]): Array[String] = + (a ++ b).flatten.toArray - def getFieldValue[A, B](clazz: Class[_], field: String, defaults: => B)(mapTo: A => B): B = - Try(clazz.getField(field)).map(_.get(null).asInstanceOf[A]).toOption.map(mapTo).getOrElse(defaults) + def getFieldValue[A, B](clazz: Class[_], field: String, defaults: => B)(mapTo: A => B): B = { + Try(clazz.getField(field)) + .map(_.get(null).asInstanceOf[A]) + .toOption + .map(mapTo) + .getOrElse(defaults) + } def getFieldValue2[A: ClassTag, A1: ClassTag, B]( clazz: Class[_], field: String, - defaults: => B)(mapTo: A => B)(mapTo1: A1 => B) : B = { + defaults: => B)(mapTo: A => B)(mapTo1: A1 => B): B = { Try(clazz.getField(field)).map(_.get(null)).map { case v: A => mapTo(v) case v1: A1 => mapTo1(v1) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index c09b01bafce37..455f1019d86dd 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -79,7 +79,7 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach } class MockSplitInfo(host: String) extends SplitInfo(null, host, null, 1, null) { - override def equals(other: Any) = false + override def equals(other: Any): Boolean = false } def createAllocator(maxExecutors: Int = 5): YarnAllocator = { @@ -118,7 +118,9 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach handler.getNumExecutorsRunning should be (1) handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1") handler.allocatedHostToContainersMap.get("host1").get should contain (container.getId) - rmClient.getMatchingRequests(container.getPriority, "host1", containerResource).size should be (0) + + val size = rmClient.getMatchingRequests(container.getPriority, "host1", containerResource).size + size should be (0) } test("some containers allocated") { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index 4194f36499e66..9395316b71ff4 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -46,7 +46,7 @@ class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { logWarning("Cannot execute bash, skipping bash tests.") } - def bashTest(name: String)(fn: => Unit) = + def bashTest(name: String)(fn: => Unit): Unit = if (hasBash) test(name)(fn) else ignore(name)(fn) bashTest("shell script escaping") {