diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R
index 1281c41213e32..cf2e5ddeb7a9d 100644
--- a/R/pkg/R/client.R
+++ b/R/pkg/R/client.R
@@ -34,24 +34,36 @@ connectBackend <- function(hostname, port, timeout = 6000) {
con
}
-launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts) {
+determineSparkSubmitBin <- function() {
if (.Platform$OS.type == "unix") {
sparkSubmitBinName = "spark-submit"
} else {
sparkSubmitBinName = "spark-submit.cmd"
}
+ sparkSubmitBinName
+}
+
+generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, packages) {
+ if (jars != "") {
+ jars <- paste("--jars", jars)
+ }
+
+ if (packages != "") {
+ packages <- paste("--packages", packages)
+ }
+ combinedArgs <- paste(jars, packages, sparkSubmitOpts, args, sep = " ")
+ combinedArgs
+}
+
+launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) {
+ sparkSubmitBinName <- determineSparkSubmitBin()
if (sparkHome != "") {
sparkSubmitBin <- file.path(sparkHome, "bin", sparkSubmitBinName)
} else {
sparkSubmitBin <- sparkSubmitBinName
}
-
- if (jars != "") {
- jars <- paste("--jars", jars)
- }
-
- combinedArgs <- paste(jars, sparkSubmitOpts, args, sep = " ")
+ combinedArgs <- generateSparkSubmitArgs(args, sparkHome, jars, sparkSubmitOpts, packages)
cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n")
invisible(system2(sparkSubmitBin, combinedArgs, wait = F))
}
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index dbde0c44c55d5..8f81d5640c1d0 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -81,6 +81,7 @@ sparkR.stop <- function() {
#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors.
#' @param sparkJars Character string vector of jar files to pass to the worker nodes.
#' @param sparkRLibDir The path where R is installed on the worker nodes.
+#' @param sparkPackages Character string vector of packages from spark-packages.org.
#' @export
#' @examples
#'\dontrun{
@@ -100,7 +101,8 @@ sparkR.init <- function(
sparkEnvir = list(),
sparkExecutorEnv = list(),
sparkJars = "",
- sparkRLibDir = "") {
+ sparkRLibDir = "",
+ sparkPackages = "") {
if (exists(".sparkRjsc", envir = .sparkREnv)) {
cat("Re-using existing Spark Context. Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n")
@@ -129,7 +131,8 @@ sparkR.init <- function(
args = path,
sparkHome = sparkHome,
jars = jars,
- sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"))
+ sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"),
+ packages = sparkPackages)
# wait at most 100 seconds for JVM to launch
wait <- 0.1
for (i in 1:25) {
diff --git a/R/pkg/inst/tests/test_client.R b/R/pkg/inst/tests/test_client.R
new file mode 100644
index 0000000000000..30b05c1a2afcd
--- /dev/null
+++ b/R/pkg/inst/tests/test_client.R
@@ -0,0 +1,32 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+context("functions in client.R")
+
+test_that("adding spark-testing-base as a package works", {
+ args <- generateSparkSubmitArgs("", "", "", "",
+ "holdenk:spark-testing-base:1.3.0_0.0.5")
+ expect_equal(gsub("[[:space:]]", "", args),
+ gsub("[[:space:]]", "",
+ "--packages holdenk:spark-testing-base:1.3.0_0.0.5"))
+})
+
+test_that("no package specified doesn't add packages flag", {
+ args <- generateSparkSubmitArgs("", "", "", "", "")
+ expect_equal(gsub("[[:space:]]", "", args),
+ "")
+})
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index 597d46a3d2223..9d8e7e9f03aea 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -17,29 +17,29 @@
package org.apache.spark.shuffle.hash
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.util.{Failure, Success, Try}
+import java.io.InputStream
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.util.{Failure, Success}
import org.apache.spark._
-import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.FetchFailedException
-import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
-import org.apache.spark.util.CompletionIterator
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator,
+ ShuffleBlockId}
private[hash] object BlockStoreShuffleFetcher extends Logging {
- def fetch[T](
+ def fetchBlockStreams(
shuffleId: Int,
reduceId: Int,
context: TaskContext,
- serializer: Serializer)
- : Iterator[T] =
+ blockManager: BlockManager,
+ mapOutputTracker: MapOutputTracker)
+ : Iterator[(BlockId, InputStream)] =
{
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
- val blockManager = SparkEnv.get.blockManager
val startTime = System.currentTimeMillis
- val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
+ val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId)
logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
shuffleId, reduceId, System.currentTimeMillis - startTime))
@@ -53,12 +53,21 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
(address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
}
- def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = {
+ val blockFetcherItr = new ShuffleBlockFetcherIterator(
+ context,
+ blockManager.shuffleClient,
+ blockManager,
+ blocksByAddress,
+ // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
+ SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
+
+ // Make sure that fetch failures are wrapped inside a FetchFailedException for the scheduler
+ blockFetcherItr.map { blockPair =>
val blockId = blockPair._1
val blockOption = blockPair._2
blockOption match {
- case Success(block) => {
- block.asInstanceOf[Iterator[T]]
+ case Success(inputStream) => {
+ (blockId, inputStream)
}
case Failure(e) => {
blockId match {
@@ -72,27 +81,5 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
}
}
}
-
- val blockFetcherItr = new ShuffleBlockFetcherIterator(
- context,
- SparkEnv.get.blockManager.shuffleClient,
- blockManager,
- blocksByAddress,
- serializer,
- // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
- SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
- val itr = blockFetcherItr.flatMap(unpackBlock)
-
- val completionIter = CompletionIterator[T, Iterator[T]](itr, {
- context.taskMetrics.updateShuffleReadMetrics()
- })
-
- new InterruptibleIterator[T](context, completionIter) {
- val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
- override def next(): T = {
- readMetrics.incRecordsRead(1)
- delegate.next()
- }
- }
}
}
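With this change BlockStoreShuffleFetcher no longer touches serialization at all; it simply hands back one raw stream per block. A minimal, hedged sketch of how a caller is expected to use the new contract (names such as shuffleId, reduceId, context, blockManager and mapOutputTracker are assumed to be in scope already, and processBlock is a hypothetical helper, not part of this patch):

    // Sketch only: iterate the (BlockId, InputStream) pairs from fetchBlockStreams and
    // close each stream as soon as it has been consumed, so the underlying ManagedBuffer
    // is released as early as possible.
    val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
      shuffleId, reduceId, context, blockManager, mapOutputTracker)
    blockStreams.foreach { case (blockId, inputStream) =>
      try {
        processBlock(blockId, inputStream)  // hypothetical: wrap for compression, deserialize, ...
      } finally {
        inputStream.close()                 // releases the buffer via BufferReleasingInputStream
      }
    }

HashShuffleReader below performs exactly this wrapping, but lets the deserializer's NextIterator call close() once all records of a block have been read.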
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index 41bafabde05b9..d5c9880659dd3 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -17,16 +17,20 @@
package org.apache.spark.shuffle.hash
-import org.apache.spark.{InterruptibleIterator, TaskContext}
+import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
+import org.apache.spark.storage.BlockManager
+import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
private[spark] class HashShuffleReader[K, C](
handle: BaseShuffleHandle[K, _, C],
startPartition: Int,
endPartition: Int,
- context: TaskContext)
+ context: TaskContext,
+ blockManager: BlockManager = SparkEnv.get.blockManager,
+ mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
extends ShuffleReader[K, C]
{
require(endPartition == startPartition + 1,
@@ -36,20 +40,52 @@ private[spark] class HashShuffleReader[K, C](
/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
+ val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
+ handle.shuffleId, startPartition, context, blockManager, mapOutputTracker)
+
+ // Wrap the streams for compression based on configuration
+ val wrappedStreams = blockStreams.map { case (blockId, inputStream) =>
+ blockManager.wrapForCompression(blockId, inputStream)
+ }
+
val ser = Serializer.getSerializer(dep.serializer)
- val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
+ val serializerInstance = ser.newInstance()
+
+ // Create a key/value iterator for each stream
+ val recordIter = wrappedStreams.flatMap { wrappedStream =>
+ // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
+ // NextIterator. The NextIterator makes sure that close() is called on the
+ // underlying InputStream when all records have been read.
+ serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+ }
+
+ // Update the context task metrics for each record read.
+ val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
+ val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
+ recordIter.map(record => {
+ readMetrics.incRecordsRead(1)
+ record
+ }),
+ context.taskMetrics().updateShuffleReadMetrics())
+
+ // An interruptible iterator must be used here in order to support task cancellation
+ val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
- new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
+ // We are reading values that are already combined
+ val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
+ dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
- new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
+ // We don't know the value type, but also don't care -- the dependency *should*
+ // have made sure it's compatible with this aggregator, which will convert the
+ // value type to the combined type C
+ val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
+ dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
-
- // Convert the Product2s to pairs since this is what downstream RDDs currently expect
- iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
+ interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
// Sort the output if there is a sort ordering defined.
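The metrics and interruption wrapping above follows a small reusable pattern: run a callback for every record produced and a completion hook once the iterator is exhausted. A self-contained sketch of the idea in plain Scala (this is not Spark's CompletionIterator or InterruptibleIterator, only an illustration of the shape):

    // Wrap an iterator so that onRecord fires for every element and onComplete fires once,
    // after the last element has been consumed.
    def withReadMetrics[A](records: Iterator[A])(onRecord: A => Unit)(onComplete: => Unit): Iterator[A] =
      new Iterator[A] {
        private var completed = false
        def hasNext: Boolean = {
          val more = records.hasNext
          if (!more && !completed) { completed = true; onComplete }
          more
        }
        def next(): A = {
          val r = records.next()
          onRecord(r)  // e.g. readMetrics.incRecordsRead(1)
          r
        }
      }

In the reader above the same two roles are played by CompletionIterator (the completion hook that flushes shuffle read metrics) and InterruptibleIterator (which additionally supports task cancellation).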
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index d0faab62c9e9e..e49e39679e940 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -17,23 +17,23 @@
package org.apache.spark.storage
+import java.io.InputStream
import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
import scala.util.{Failure, Try}
import org.apache.spark.{Logging, TaskContext}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.serializer.{SerializerInstance, Serializer}
-import org.apache.spark.util.{CompletionIterator, Utils}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
+import org.apache.spark.util.Utils
/**
* An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
* manager. For remote blocks, it fetches them using the provided BlockTransferService.
*
- * This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a
- * pipelined fashion as they are received.
+ * This creates an iterator of (BlockID, Try[InputStream]) tuples so the caller can handle blocks
+ * in a pipelined fashion as they are received.
*
 * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid
* using too much memory.
@@ -44,7 +44,6 @@ import org.apache.spark.util.{CompletionIterator, Utils}
* @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
* For each block we also require the size (in bytes as a long field) in
* order to throttle the memory usage.
- * @param serializer serializer used to deserialize the data.
* @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
*/
private[spark]
@@ -53,9 +52,8 @@ final class ShuffleBlockFetcherIterator(
shuffleClient: ShuffleClient,
blockManager: BlockManager,
blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
- serializer: Serializer,
maxBytesInFlight: Long)
- extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging {
+ extends Iterator[(BlockId, Try[InputStream])] with Logging {
import ShuffleBlockFetcherIterator._
@@ -83,7 +81,7 @@ final class ShuffleBlockFetcherIterator(
/**
* A queue to hold our results. This turns the asynchronous model provided by
- * [[BlockTransferService]] into a synchronous model (iterator).
+ * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator).
*/
private[this] val results = new LinkedBlockingQueue[FetchResult]
@@ -102,9 +100,7 @@ final class ShuffleBlockFetcherIterator(
/** Current bytes in flight from our requests */
private[this] var bytesInFlight = 0L
- private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
-
- private[this] val serializerInstance: SerializerInstance = serializer.newInstance()
+ private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency()
/**
* Whether the iterator is still active. If isZombie is true, the callback interface will no
@@ -114,17 +110,23 @@ final class ShuffleBlockFetcherIterator(
initialize()
- /**
- * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
- */
- private[this] def cleanup() {
- isZombie = true
+ // Decrements the buffer reference count.
+ // The currentResult is set to null to prevent releasing the buffer again on cleanup()
+ private[storage] def releaseCurrentResultBuffer(): Unit = {
// Release the current buffer if necessary
currentResult match {
case SuccessFetchResult(_, _, buf) => buf.release()
case _ =>
}
+ currentResult = null
+ }
+ /**
+ * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
+ */
+ private[this] def cleanup() {
+ isZombie = true
+ releaseCurrentResultBuffer()
// Release buffers in the results queue
val iter = results.iterator()
while (iter.hasNext) {
@@ -272,7 +274,13 @@ final class ShuffleBlockFetcherIterator(
override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
- override def next(): (BlockId, Try[Iterator[Any]]) = {
+ /**
+ * Fetches the next (BlockId, Try[InputStream]). If a task fails, the ManagedBuffers
+ * underlying each InputStream will be freed by the cleanup() method registered with the
+ * TaskCompletionListener. However, callers should close() these InputStreams
+ * as soon as they are no longer needed, in order to release memory as early as possible.
+ */
+ override def next(): (BlockId, Try[InputStream]) = {
numBlocksProcessed += 1
val startFetchWait = System.currentTimeMillis()
currentResult = results.take()
@@ -290,22 +298,15 @@ final class ShuffleBlockFetcherIterator(
sendRequest(fetchRequests.dequeue())
}
- val iteratorTry: Try[Iterator[Any]] = result match {
+ val iteratorTry: Try[InputStream] = result match {
case FailureFetchResult(_, e) =>
Failure(e)
case SuccessFetchResult(blockId, _, buf) =>
// There is a chance that createInputStream can fail (e.g. fetching a local file that does
// not exist, SPARK-4085). In that case, we should propagate the right exception so
// the scheduler gets a FetchFailedException.
- Try(buf.createInputStream()).map { is0 =>
- val is = blockManager.wrapForCompression(blockId, is0)
- val iter = serializerInstance.deserializeStream(is).asKeyValueIterator
- CompletionIterator[Any, Iterator[Any]](iter, {
- // Once the iterator is exhausted, release the buffer and set currentResult to null
- // so we don't release it again in cleanup.
- currentResult = null
- buf.release()
- })
+ Try(buf.createInputStream()).map { inputStream =>
+ new BufferReleasingInputStream(inputStream, this)
}
}
@@ -313,6 +314,39 @@ final class ShuffleBlockFetcherIterator(
}
}
+/**
+ * Helper class that ensures a ManagedBuffer is released upon InputStream.close()
+ */
+private class BufferReleasingInputStream(
+ private val delegate: InputStream,
+ private val iterator: ShuffleBlockFetcherIterator)
+ extends InputStream {
+ private[this] var closed = false
+
+ override def read(): Int = delegate.read()
+
+ override def close(): Unit = {
+ if (!closed) {
+ delegate.close()
+ iterator.releaseCurrentResultBuffer()
+ closed = true
+ }
+ }
+
+ override def available(): Int = delegate.available()
+
+ override def mark(readlimit: Int): Unit = delegate.mark(readlimit)
+
+ override def skip(n: Long): Long = delegate.skip(n)
+
+ override def markSupported(): Boolean = delegate.markSupported()
+
+ override def read(b: Array[Byte]): Int = delegate.read(b)
+
+ override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len)
+
+ override def reset(): Unit = delegate.reset()
+}
private[storage]
object ShuffleBlockFetcherIterator {
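BufferReleasingInputStream is a thin delegating wrapper whose only job is to tie buffer release to close(), exactly once. A self-contained sketch of the same close-once pattern under generic, assumed names (not Spark's class, just the idea):

    import java.io.{ByteArrayInputStream, InputStream}

    // A delegating stream that runs a release callback the first time it is closed.
    class CloseOnceInputStream(delegate: InputStream, onClose: () => Unit) extends InputStream {
      private var closed = false
      override def read(): Int = delegate.read()
      override def close(): Unit = {
        if (!closed) {
          delegate.close()
          onClose()  // e.g. release the ManagedBuffer backing this block
          closed = true
        }
      }
    }

    object CloseOnceExample extends App {
      var released = 0
      val in = new CloseOnceInputStream(new ByteArrayInputStream(Array[Byte](1, 2, 3)), () => released += 1)
      while (in.read() != -1) {}  // drain the stream
      in.close()
      in.close()                  // the second close does not release again
      assert(released == 1)
    }

This is why the deserialization path in HashShuffleReader can close the stream through its NextIterator while cleanup() remains safe to run afterwards.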
diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
index 063e2a1f8b18e..e2d25e36365fa 100644
--- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
+++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
@@ -35,6 +35,10 @@ private[spark] object ToolTips {
val OUTPUT = "Bytes and records written to Hadoop."
+ val STORAGE_MEMORY =
+ "Memory used / total available memory for storage of data " +
+ "like RDD partitions cached in memory. "
+
val SHUFFLE_WRITE =
"Bytes and records written to disk in order to be read by a shuffle in a future stage."
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
index b247e4cdc3bd4..01cddda4c62cd 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
@@ -67,7 +67,7 @@ private[ui] class ExecutorsPage(
The entry point into SparkR is the `SparkContext` which connects your R program to a Spark cluster.
You can create a `SparkContext` using `sparkR.init` and pass in options such as the application name
-etc. Further, to work with DataFrames we will need a `SQLContext`, which can be created from the
-SparkContext. If you are working from the SparkR shell, the `SQLContext` and `SparkContext` should
-already be created for you.
+, any Spark packages it depends on, etc. Further, to work with DataFrames we will need a `SQLContext`,
+which can be created from the SparkContext. If you are working from the SparkR shell, the
+`SQLContext` and `SparkContext` should already be created for you.
{% highlight r %}
sc <- sparkR.init()
@@ -62,7 +62,16 @@ head(df)
SparkR supports operating on a variety of data sources through the `DataFrame` interface. This section describes the general methods for loading and saving data using Data Sources. You can check the Spark SQL programming guide for more [specific options](sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources.
-The general method for creating DataFrames from data sources is `read.df`. This method takes in the `SQLContext`, the path for the file to load and the type of data source. SparkR supports reading JSON and Parquet files natively and through [Spark Packages](http://spark-packages.org/) you can find data source connectors for popular file formats like [CSV](http://spark-packages.org/package/databricks/spark-csv) and [Avro](http://spark-packages.org/package/databricks/spark-avro).
+The general method for creating DataFrames from data sources is `read.df`. This method takes in the `SQLContext`, the path for the file to load and the type of data source. SparkR supports reading JSON and Parquet files natively and through [Spark Packages](http://spark-packages.org/) you can find data source connectors for popular file formats like [CSV](http://spark-packages.org/package/databricks/spark-csv) and [Avro](http://spark-packages.org/package/databricks/spark-avro). These packages can be added either by
+specifying `--packages` with the `spark-submit` or `sparkR` commands, or, when creating a context through `init`,
+by specifying them with the `sparkPackages` argument.
+
+
+{% highlight r %}
+sc <- sparkR.init(sparkPackages="com.databricks:spark-csv_2.11:1.0.3")
+sqlContext <- sparkRSQL.init(sc)
+{% endhighlight %}
+
We can see how to use data sources using an example JSON input file. Note that the file that is used here is _not_ a typical JSON file. Each line in the file must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail.
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 63e2c79669763..e4932cfa7a4fc 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -306,6 +306,13 @@ def parse_args():
"--private-ips", action="store_true", default=False,
help="Use private IPs for instances rather than public if VPC/subnet " +
"requires that.")
+ parser.add_option(
+ "--instance-initiated-shutdown-behavior", default="stop",
+ choices=["stop", "terminate"],
+ help="Whether instances should terminate when shut down or just stop")
+ parser.add_option(
+ "--instance-profile-name", default=None,
+ help="IAM profile name to launch instances under")
(opts, args) = parser.parse_args()
if len(args) != 2:
@@ -602,7 +609,8 @@ def launch_cluster(conn, opts, cluster_name):
block_device_map=block_map,
subnet_id=opts.subnet_id,
placement_group=opts.placement_group,
- user_data=user_data_content)
+ user_data=user_data_content,
+ instance_profile_name=opts.instance_profile_name)
my_req_ids += [req.id for req in slave_reqs]
i += 1
@@ -647,16 +655,19 @@ def launch_cluster(conn, opts, cluster_name):
for zone in zones:
num_slaves_this_zone = get_partition(opts.slaves, num_zones, i)
if num_slaves_this_zone > 0:
- slave_res = image.run(key_name=opts.key_pair,
- security_group_ids=[slave_group.id] + additional_group_ids,
- instance_type=opts.instance_type,
- placement=zone,
- min_count=num_slaves_this_zone,
- max_count=num_slaves_this_zone,
- block_device_map=block_map,
- subnet_id=opts.subnet_id,
- placement_group=opts.placement_group,
- user_data=user_data_content)
+ slave_res = image.run(
+ key_name=opts.key_pair,
+ security_group_ids=[slave_group.id] + additional_group_ids,
+ instance_type=opts.instance_type,
+ placement=zone,
+ min_count=num_slaves_this_zone,
+ max_count=num_slaves_this_zone,
+ block_device_map=block_map,
+ subnet_id=opts.subnet_id,
+ placement_group=opts.placement_group,
+ user_data=user_data_content,
+ instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior,
+ instance_profile_name=opts.instance_profile_name)
slave_nodes += slave_res.instances
print("Launched {s} slave{plural_s} in {z}, regid = {r}".format(
s=num_slaves_this_zone,
@@ -678,16 +689,19 @@ def launch_cluster(conn, opts, cluster_name):
master_type = opts.instance_type
if opts.zone == 'all':
opts.zone = random.choice(conn.get_all_zones()).name
- master_res = image.run(key_name=opts.key_pair,
- security_group_ids=[master_group.id] + additional_group_ids,
- instance_type=master_type,
- placement=opts.zone,
- min_count=1,
- max_count=1,
- block_device_map=block_map,
- subnet_id=opts.subnet_id,
- placement_group=opts.placement_group,
- user_data=user_data_content)
+ master_res = image.run(
+ key_name=opts.key_pair,
+ security_group_ids=[master_group.id] + additional_group_ids,
+ instance_type=master_type,
+ placement=opts.zone,
+ min_count=1,
+ max_count=1,
+ block_device_map=block_map,
+ subnet_id=opts.subnet_id,
+ placement_group=opts.placement_group,
+ user_data=user_data_content,
+ instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior,
+ instance_profile_name=opts.instance_profile_name)
master_nodes = master_res.instances
print("Launched master in %s, regid = %s" % (zone, master_res.id))
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index f678c69a6dfa9..6f86a505b3ae4 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -69,7 +69,12 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.parquet.CatalystTimestampConverter"),
ProblemFilters.exclude[MissingClassProblem](
- "org.apache.spark.sql.parquet.CatalystTimestampConverter$")
+ "org.apache.spark.sql.parquet.CatalystTimestampConverter$"),
+ // SPARK-6777 Implements backwards compatibility rules in CatalystSchemaConverter
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.sql.parquet.ParquetTypeInfo"),
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.sql.parquet.ParquetTypeInfo$")
)
case v if v.startsWith("1.4") =>
Seq(
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index 758accf4b41eb..2698f10d06883 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -21,6 +21,7 @@
from numpy import array
from pyspark import RDD
+from pyspark.streaming import DStream
from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py
from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
@@ -28,7 +29,8 @@
__all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegressionWithLBFGS',
- 'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes']
+ 'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes',
+ 'StreamingLogisticRegressionWithSGD']
class LinearClassificationModel(LinearModel):
@@ -583,6 +585,98 @@ def train(cls, data, lambda_=1.0):
return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta))
+class StreamingLinearAlgorithm(object):
+ """
+ Base class that has to be inherited by any streaming linear algorithm.
+
+ Prevents reimplementation of methods predictOn and predictOnValues.
+ """
+ def __init__(self, model):
+ self._model = model
+
+ def latestModel(self):
+ """
+ Returns the latest model.
+ """
+ return self._model
+
+ def _validate(self, dstream):
+ if not isinstance(dstream, DStream):
+ raise TypeError(
+ "dstream should be a DStream object, got %s" % type(dstream))
+ if not self._model:
+ raise ValueError(
+ "Model must be intialized using setInitialWeights")
+
+ def predictOn(self, dstream):
+ """
+ Make predictions on a dstream.
+
+ :return: Transformed dstream object.
+ """
+ self._validate(dstream)
+ return dstream.map(lambda x: self._model.predict(x))
+
+ def predictOnValues(self, dstream):
+ """
+ Make predictions on a keyed dstream.
+
+ :return: Transformed dstream object.
+ """
+ self._validate(dstream)
+ return dstream.mapValues(lambda x: self._model.predict(x))
+
+
+@inherit_doc
+class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm):
+ """
+ Run LogisticRegression with SGD on a stream of data.
+
+ The weights obtained at the end of training a stream are used as initial
+ weights for the next stream.
+
+ :param stepSize: Step size for each iteration of gradient descent.
+ :param numIterations: Number of iterations run for each batch of data.
+ :param miniBatchFraction: Fraction of data on which SGD is run for each
+ iteration.
+ :param regParam: L2 Regularization parameter.
+ """
+ def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.01):
+ self.stepSize = stepSize
+ self.numIterations = numIterations
+ self.regParam = regParam
+ self.miniBatchFraction = miniBatchFraction
+ self._model = None
+ super(StreamingLogisticRegressionWithSGD, self).__init__(
+ model=self._model)
+
+ def setInitialWeights(self, initialWeights):
+ """
+ Set the initial value of weights.
+
+ This must be set before running trainOn and predictOn.
+ """
+ initialWeights = _convert_to_vector(initialWeights)
+
+ # LogisticRegressionWithSGD does only binary classification.
+ self._model = LogisticRegressionModel(
+ initialWeights, 0, initialWeights.size, 2)
+ return self
+
+ def trainOn(self, dstream):
+ """Train the model on the incoming dstream."""
+ self._validate(dstream)
+
+ def update(rdd):
+ # LogisticRegressionWithSGD.train raises an error for an empty RDD.
+ if not rdd.isEmpty():
+ self._model = LogisticRegressionWithSGD.train(
+ rdd, self.numIterations, self.stepSize,
+ self.miniBatchFraction, self._model.weights)
+
+ dstream.foreachRDD(update)
+
+
def _test():
import doctest
from pyspark import SparkContext
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 577ecc947174f..d36887392cbb1 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -26,7 +26,8 @@
from time import time, sleep
from shutil import rmtree
-from numpy import array, array_equal, zeros, inf, all, random
+from numpy import (
+ array, array_equal, zeros, inf, random, exp, dot, all, mean)
from numpy import sum as array_sum
from py4j.protocol import Py4JJavaError
@@ -45,6 +46,7 @@
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
from pyspark.mllib.feature import Word2Vec
@@ -1038,6 +1040,137 @@ def test_dim(self):
self.assertEqual(len(point.features), 2)
+class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase):
+
+ @staticmethod
+ def generateLogisticInput(offset, scale, nPoints, seed):
+ """
+ Generate points whose label is 1 with probability
+ 1 / (1 + exp(-(x * scale + offset)))
+
+ where x is drawn from a standard normal distribution and each label is
+ obtained by comparing a uniform random draw against that sigmoid value.
+ """
+ rng = random.RandomState(seed)
+ x = rng.randn(nPoints)
+ sigmoid = 1. / (1 + exp(-(dot(x, scale) + offset)))
+ y_p = rng.rand(nPoints)
+ cut_off = y_p <= sigmoid
+ y_p[cut_off] = 1.0
+ y_p[~cut_off] = 0.0
+ return [
+ LabeledPoint(y_p[i], Vectors.dense([x[i]]))
+ for i in range(nPoints)]
+
+ def test_parameter_accuracy(self):
+ """
+ Test that the final value of weights is close to the desired value.
+ """
+ input_batches = [
+ self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i))
+ for i in range(20)]
+ input_stream = self.ssc.queueStream(input_batches)
+
+ slr = StreamingLogisticRegressionWithSGD(
+ stepSize=0.2, numIterations=25)
+ slr.setInitialWeights([0.0])
+ slr.trainOn(input_stream)
+
+ t = time()
+ self.ssc.start()
+ self._ssc_wait(t, 20.0, 0.01)
+ rel = (1.5 - slr.latestModel().weights.array[0]) / 1.5
+ self.assertAlmostEqual(rel, 0.1, 1)
+
+ def test_convergence(self):
+ """
+ Test that weights converge to the required value on toy data.
+ """
+ input_batches = [
+ self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i))
+ for i in range(20)]
+ input_stream = self.ssc.queueStream(input_batches)
+ models = []
+
+ slr = StreamingLogisticRegressionWithSGD(
+ stepSize=0.2, numIterations=25)
+ slr.setInitialWeights([0.0])
+ slr.trainOn(input_stream)
+ input_stream.foreachRDD(
+ lambda x: models.append(slr.latestModel().weights[0]))
+
+ t = time()
+ self.ssc.start()
+ self._ssc_wait(t, 15.0, 0.01)
+ t_models = array(models)
+ diff = t_models[1:] - t_models[:-1]
+
+ # Test that weights improve with a small tolerance,
+ self.assertTrue(all(diff >= -0.1))
+ self.assertTrue(array_sum(diff > 0) > 1)
+
+ @staticmethod
+ def calculate_accuracy_error(true, predicted):
+ return sum(abs(array(true) - array(predicted))) / len(true)
+
+ def test_predictions(self):
+ """Test predicted values on a toy model."""
+ input_batches = []
+ for i in range(20):
+ batch = self.sc.parallelize(
+ self.generateLogisticInput(0, 1.5, 100, 42 + i))
+ input_batches.append(batch.map(lambda x: (x.label, x.features)))
+ input_stream = self.ssc.queueStream(input_batches)
+
+ slr = StreamingLogisticRegressionWithSGD(
+ stepSize=0.2, numIterations=25)
+ slr.setInitialWeights([1.5])
+ predict_stream = slr.predictOnValues(input_stream)
+ true_predicted = []
+ predict_stream.foreachRDD(lambda x: true_predicted.append(x.collect()))
+ t = time()
+ self.ssc.start()
+ self._ssc_wait(t, 5.0, 0.01)
+
+ # Test that the accuracy error is no more than 0.4 on each batch.
+ for batch in true_predicted:
+ true, predicted = zip(*batch)
+ self.assertTrue(
+ self.calculate_accuracy_error(true, predicted) < 0.4)
+
+ def test_training_and_prediction(self):
+ """Test that the model improves on toy data with no. of batches"""
+ input_batches = [
+ self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i))
+ for i in range(20)]
+ predict_batches = [
+ b.map(lambda lp: (lp.label, lp.features)) for b in input_batches]
+
+ slr = StreamingLogisticRegressionWithSGD(
+ stepSize=0.01, numIterations=25)
+ slr.setInitialWeights([-0.1])
+ errors = []
+
+ def collect_errors(rdd):
+ true, predicted = zip(*rdd.collect())
+ errors.append(self.calculate_accuracy_error(true, predicted))
+
+ true_predicted = []
+ input_stream = self.ssc.queueStream(input_batches)
+ predict_stream = self.ssc.queueStream(predict_batches)
+ slr.trainOn(input_stream)
+ ps = slr.predictOnValues(predict_stream)
+ ps.foreachRDD(lambda x: collect_errors(x))
+
+ t = time()
+ self.ssc.start()
+ self._ssc_wait(t, 20.0, 0.01)
+
+ # Test that the improvement in error is at least 0.3
+ self.assertTrue(errors[1] - errors[-1] > 0.3)
+
+
class MLUtilsTests(MLlibTestCase):
def test_append_bias(self):
data = [2.0, 2.0, 2.0]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 0a3f5a7b5cade..117c87a785fdb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -283,7 +283,7 @@ class Analyzer(
val conflictingAttributes = left.outputSet.intersect(right.outputSet)
logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j")
- val (oldRelation, newRelation) = right.collect {
+ right.collect {
// Handle base relations that might appear more than once.
case oldVersion: MultiInstanceRelation
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
@@ -308,25 +308,27 @@ class Analyzer(
if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
.nonEmpty =>
(oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions)))
- }.headOption.getOrElse { // Only handle first case, others will be fixed on the next pass.
- sys.error(
- s"""
- |Failure when resolving conflicting references in Join:
- |$plan
- |
- |Conflicting attributes: ${conflictingAttributes.mkString(",")}
- """.stripMargin)
}
-
- val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
- val newRight = right transformUp {
- case r if r == oldRelation => newRelation
- } transformUp {
- case other => other transformExpressions {
- case a: Attribute => attributeRewrites.get(a).getOrElse(a)
- }
+ // Only handle first case, others will be fixed on the next pass.
+ .headOption match {
+ case None =>
+ /*
+ * No result implies that there is a logical plan node that produces new references
+ * that this rule cannot handle. When that is the case, there must be another rule
+ * that resolves these conflicts. Otherwise, the analysis will fail.
+ */
+ j
+ case Some((oldRelation, newRelation)) =>
+ val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
+ val newRight = right transformUp {
+ case r if r == oldRelation => newRelation
+ } transformUp {
+ case other => other transformExpressions {
+ case a: Attribute => attributeRewrites.get(a).getOrElse(a)
+ }
+ }
+ j.copy(right = newRight)
}
- j.copy(right = newRight)
// When resolve `SortOrder`s in Sort based on child, don't report errors as
// we still have chance to resolve it based on grandchild
@@ -585,8 +587,8 @@ class Analyzer(
failAnalysis(
s"""Expect multiple names given for ${g.getClass.getName},
|but only single name '${name}' specified""".stripMargin)
- case Alias(g: Generator, name) => Some((g, name :: Nil))
- case MultiAlias(g: Generator, names) => Some(g, names)
+ case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil))
+ case MultiAlias(g: Generator, names) if g.resolved => Some(g, names)
case _ => None
}
}
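The common trigger for this rule is a self-join, where both sides of the join initially carry the same expression ids. The analyzer now rewrites the right-hand relation with fresh ids when one of the listed cases matches, and otherwise leaves the join untouched so that either another rule or the new CheckAnalysis case below reports the conflict. A hedged illustration using the 1.4-era DataFrame API (sqlContext and the sample JSON path are assumptions):

    import org.apache.spark.sql.functions.col

    // Joining a DataFrame with itself: both sides start with identical attribute ids,
    // so the analyzer generates new ids for the right-hand side.
    val people = sqlContext.read.json("examples/src/main/resources/people.json")
    val selfJoined = people.as("a").join(people.as("b"), col("a.name") === col("b.name"))
    selfJoined.explain()  // resolves instead of failing with "conflicting references"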
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index c5a1437be6d05..a069b4710f38c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -48,6 +48,7 @@ trait CheckAnalysis {
// We transform up and order the rules so as to catch the first possible failure instead
// of the result of cascading resolution failures.
plan.foreachUp {
+
case operator: LogicalPlan =>
operator transformExpressionsUp {
case a: Attribute if !a.resolved =>
@@ -121,6 +122,17 @@ trait CheckAnalysis {
case _ => // Analysis successful!
}
+
+ // Special handling for cases when a self-join introduces duplicate expression ids.
+ case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty =>
+ val conflictingAttributes = left.outputSet.intersect(right.outputSet)
+ failAnalysis(
+ s"""
+ |Failure when resolving conflicting references in Join:
+ |$plan
+ |Conflicting attributes: ${conflictingAttributes.mkString(",")}
+ |""".stripMargin)
+
}
extendedCheckRules.foreach(_(plan))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index d4ab1fc643c33..976fa57cb98d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -317,6 +317,7 @@ trait HiveTypeCoercion {
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
+ case SumDistinct(e @ StringType()) => SumDistinct(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
}
}
@@ -590,11 +591,12 @@ trait HiveTypeCoercion {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
- case a @ CreateArray(children) if !a.resolved =>
- val commonType = a.childTypes.reduce(
- (a, b) => findTightestCommonTypeOfTwo(a, b).getOrElse(StringType))
- CreateArray(
- children.map(c => if (c.dataType == commonType) c else Cast(c, commonType)))
+ case a @ CreateArray(children) if children.map(_.dataType).distinct.size > 1 =>
+ val types = children.map(_.dataType)
+ findTightestCommonTypeAndPromoteToString(types) match {
+ case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType)))
+ case None => a
+ }
// Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows.
case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest.
@@ -620,12 +622,11 @@ trait HiveTypeCoercion {
// Coalesce should return the first non-null value, which could be any column
// from the list. So we need to make sure the return type is deterministic and
// compatible with every child column.
- case Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
+ case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
val types = es.map(_.dataType)
findTightestCommonTypeAndPromoteToString(types) match {
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
- case None =>
- sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}")
+ case None => c
}
}
}
@@ -677,8 +678,8 @@ trait HiveTypeCoercion {
findTightestCommonTypeAndPromoteToString((c.key +: c.whenList).map(_.dataType))
maybeCommonType.map { commonType =>
val castedBranches = c.branches.grouped(2).map {
- case Seq(when, then) if when.dataType != commonType =>
- Seq(Cast(when, commonType), then)
+ case Seq(whenExpr, thenExpr) if whenExpr.dataType != commonType =>
+ Seq(Cast(whenExpr, commonType), thenExpr)
case other => other
}.reduce(_ ++ _)
CaseKeyWhen(Cast(c.key, commonType), castedBranches)
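Both the CreateArray and Coalesce rules now defer to findTightestCommonTypeAndPromoteToString, so mixed-type children are widened to a common type (promoting to string as a last resort) instead of raising an error during analysis. A hedged illustration in SQL, assuming a context whose dialect supports array() and coalesce() (for example a HiveContext named sqlContext):

    // Mixed-type children are cast to a common type before evaluation.
    sqlContext.sql("SELECT array(1, 2.5)")            // int and double widen to double
    sqlContext.sql("SELECT array(1, '2')")            // falls back to string promotion
    sqlContext.sql("SELECT coalesce(NULL, 1, 2.0)")   // common type is double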
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index d271434a306dd..8bd7fc18a8dd4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}
import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
@@ -31,7 +31,14 @@ import org.apache.spark.unsafe.types.UTF8String
/** Cast the child expression to the target data type. */
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {
- override lazy val resolved = childrenResolved && resolve(child.dataType, dataType)
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (resolve(child.dataType, dataType)) {
+ TypeCheckResult.TypeCheckSuccess
+ } else {
+ TypeCheckResult.TypeCheckFailure(
+ s"cannot cast ${child.dataType} to $dataType")
+ }
+ }
override def foldable: Boolean = child.foldable
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index a10a959ae766f..f59db3d5dfc23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -162,9 +162,7 @@ abstract class Expression extends TreeNode[Expression] {
/**
* Checks the input data types, returns `TypeCheckResult.success` if it's valid,
* or returns a `TypeCheckResult` with an error message if invalid.
- * Note: it's not valid to call this method until `childrenResolved == true`
- * TODO: we should remove the default implementation and implement it for all
- * expressions with proper error message.
+ * Note: it's not valid to call this method until `childrenResolved == true`.
*/
def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
index 4d6c1c265150d..4d7c95ffd1850 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
@@ -96,6 +96,11 @@ object ExtractValue {
}
}
+/**
+ * A common interface of all kinds of extract value expressions.
+ * Note: concrete extract value expressions are created only by `ExtractValue.apply`,
+ * so we don't need to do type checks for them.
+ */
trait ExtractValue extends UnaryExpression {
self: Product =>
}
@@ -179,9 +184,6 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType
- override lazy val resolved = childrenResolved &&
- child.dataType.isInstanceOf[ArrayType] && ordinal.dataType.isInstanceOf[IntegralType]
-
protected def evalNotNull(value: Any, ordinal: Any) = {
// TODO: consider using Array[_] for ArrayType child to avoid
// boxing of primitives
@@ -203,8 +205,6 @@ case class GetMapValue(child: Expression, ordinal: Expression)
override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType
- override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[MapType]
-
protected def evalNotNull(value: Any, ordinal: Any) = {
val baseValue = value.asInstanceOf[Map[Any, _]]
baseValue.get(ordinal).orNull
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 00d2e499c5890..a9fc54c548f49 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst.expressions
import com.clearspring.analytics.stream.cardinality.HyperLogLog
-import org.apache.spark.sql.catalyst
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
@@ -101,6 +102,9 @@ case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[
}
override def newInstance(): MinFunction = new MinFunction(child, this)
+
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForOrderingExpr(child.dataType, "function min")
}
case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
@@ -132,6 +136,9 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[
}
override def newInstance(): MaxFunction = new MaxFunction(child, this)
+
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForOrderingExpr(child.dataType, "function max")
}
case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
@@ -165,6 +172,21 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod
override def newInstance(): CountFunction = new CountFunction(child, this)
}
+case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+ def this() = this(null, null) // Required for serialization.
+
+ var count: Long = _
+
+ override def update(input: InternalRow): Unit = {
+ val evaluatedExpr = expr.eval(input)
+ if (evaluatedExpr != null) {
+ count += 1L
+ }
+ }
+
+ override def eval(input: InternalRow): Any = count
+}
+
case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate {
def this() = this(null)
@@ -183,6 +205,28 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate
}
}
+case class CountDistinctFunction(
+ @transient expr: Seq[Expression],
+ @transient base: AggregateExpression)
+ extends AggregateFunction {
+
+ def this() = this(null, null) // Required for serialization.
+
+ val seen = new OpenHashSet[Any]()
+
+ @transient
+ val distinctValue = new InterpretedProjection(expr)
+
+ override def update(input: InternalRow): Unit = {
+ val evaluatedExpr = distinctValue(input)
+ if (!evaluatedExpr.anyNull) {
+ seen.add(evaluatedExpr)
+ }
+ }
+
+ override def eval(input: InternalRow): Any = seen.size.toLong
+}
+
case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression {
def this() = this(null)
@@ -278,6 +322,25 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
}
}
+case class ApproxCountDistinctPartitionFunction(
+ expr: Expression,
+ base: AggregateExpression,
+ relativeSD: Double)
+ extends AggregateFunction {
+ def this() = this(null, null, 0) // Required for serialization.
+
+ private val hyperLogLog = new HyperLogLog(relativeSD)
+
+ override def update(input: InternalRow): Unit = {
+ val evaluatedExpr = expr.eval(input)
+ if (evaluatedExpr != null) {
+ hyperLogLog.offer(evaluatedExpr)
+ }
+ }
+
+ override def eval(input: InternalRow): Any = hyperLogLog
+}
+
case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
extends AggregateExpression with trees.UnaryNode[Expression] {
@@ -289,6 +352,23 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
}
}
+case class ApproxCountDistinctMergeFunction(
+ expr: Expression,
+ base: AggregateExpression,
+ relativeSD: Double)
+ extends AggregateFunction {
+ def this() = this(null, null, 0) // Required for serialization.
+
+ private val hyperLogLog = new HyperLogLog(relativeSD)
+
+ override def update(input: InternalRow): Unit = {
+ val evaluatedExpr = expr.eval(input)
+ hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog])
+ }
+
+ override def eval(input: InternalRow): Any = hyperLogLog.cardinality()
+}
+
case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
extends PartialAggregate with trees.UnaryNode[Expression] {
@@ -349,159 +429,9 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
}
override def newInstance(): AverageFunction = new AverageFunction(child, this)
-}
-
-case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
-
- override def nullable: Boolean = true
-
- override def dataType: DataType = child.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive
- case DecimalType.Unlimited =>
- DecimalType.Unlimited
- case _ =>
- child.dataType
- }
-
- override def toString: String = s"SUM($child)"
-
- override def asPartial: SplitEvaluation = {
- child.dataType match {
- case DecimalType.Fixed(_, _) =>
- val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
- SplitEvaluation(
- Cast(CombineSum(partialSum.toAttribute), dataType),
- partialSum :: Nil)
-
- case _ =>
- val partialSum = Alias(Sum(child), "PartialSum")()
- SplitEvaluation(
- CombineSum(partialSum.toAttribute),
- partialSum :: Nil)
- }
- }
-
- override def newInstance(): SumFunction = new SumFunction(child, this)
-}
-
-/**
- * Sum should satisfy 3 cases:
- * 1) sum of all null values = zero
- * 2) sum for table column with no data = null
- * 3) sum of column with null and not null values = sum of not null values
- * Require separate CombineSum Expression and function as it has to distinguish "No data" case
- * versus "data equals null" case, while aggregating results and at each partial expression.i.e.,
- * Combining PartitionLevel InputData
- * <-- null
- * Zero <-- Zero <-- null
- *
- * <-- null <-- no data
- * null <-- null <-- no data
- */
-case class CombineSum(child: Expression) extends AggregateExpression {
- def this() = this(null)
-
- override def children: Seq[Expression] = child :: Nil
- override def nullable: Boolean = true
- override def dataType: DataType = child.dataType
- override def toString: String = s"CombineSum($child)"
- override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this)
-}
-
-case class SumDistinct(child: Expression)
- extends PartialAggregate with trees.UnaryNode[Expression] {
-
- def this() = this(null)
- override def nullable: Boolean = true
- override def dataType: DataType = child.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive
- case DecimalType.Unlimited =>
- DecimalType.Unlimited
- case _ =>
- child.dataType
- }
- override def toString: String = s"SUM(DISTINCT $child)"
- override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this)
-
- override def asPartial: SplitEvaluation = {
- val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")()
- SplitEvaluation(
- CombineSetsAndSum(partialSet.toAttribute, this),
- partialSet :: Nil)
- }
-}
-case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression {
- def this() = this(null, null)
-
- override def children: Seq[Expression] = inputSet :: Nil
- override def nullable: Boolean = true
- override def dataType: DataType = base.dataType
- override def toString: String = s"CombineAndSum($inputSet)"
- override def newInstance(): CombineSetsAndSumFunction = {
- new CombineSetsAndSumFunction(inputSet, this)
- }
-}
-
-case class CombineSetsAndSumFunction(
- @transient inputSet: Expression,
- @transient base: AggregateExpression)
- extends AggregateFunction {
-
- def this() = this(null, null) // Required for serialization.
-
- val seen = new OpenHashSet[Any]()
-
- override def update(input: InternalRow): Unit = {
- val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]]
- val inputIterator = inputSetEval.iterator
- while (inputIterator.hasNext) {
- seen.add(inputIterator.next)
- }
- }
-
- override def eval(input: InternalRow): Any = {
- val casted = seen.asInstanceOf[OpenHashSet[InternalRow]]
- if (casted.size == 0) {
- null
- } else {
- Cast(Literal(
- casted.iterator.map(f => f.apply(0)).reduceLeft(
- base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
- base.dataType).eval(null)
- }
- }
-}
-
-case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def nullable: Boolean = true
- override def dataType: DataType = child.dataType
- override def toString: String = s"FIRST($child)"
-
- override def asPartial: SplitEvaluation = {
- val partialFirst = Alias(First(child), "PartialFirst")()
- SplitEvaluation(
- First(partialFirst.toAttribute),
- partialFirst :: Nil)
- }
- override def newInstance(): FirstFunction = new FirstFunction(child, this)
-}
-
-case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def references: AttributeSet = child.references
- override def nullable: Boolean = true
- override def dataType: DataType = child.dataType
- override def toString: String = s"LAST($child)"
-
- override def asPartial: SplitEvaluation = {
- val partialLast = Alias(Last(child), "PartialLast")()
- SplitEvaluation(
- Last(partialLast.toAttribute),
- partialLast :: Nil)
- }
- override def newInstance(): LastFunction = new LastFunction(child, this)
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForNumericExpr(child.dataType, "function average")
}
case class AverageFunction(expr: Expression, base: AggregateExpression)
@@ -551,55 +481,41 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
}
}
-case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
- def this() = this(null, null) // Required for serialization.
+case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- var count: Long = _
+ override def nullable: Boolean = true
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = expr.eval(input)
- if (evaluatedExpr != null) {
- count += 1L
- }
+ override def dataType: DataType = child.dataType match {
+ case DecimalType.Fixed(precision, scale) =>
+ DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive
+ case DecimalType.Unlimited =>
+ DecimalType.Unlimited
+ case _ =>
+ child.dataType
}
- override def eval(input: InternalRow): Any = count
-}
-
-case class ApproxCountDistinctPartitionFunction(
- expr: Expression,
- base: AggregateExpression,
- relativeSD: Double)
- extends AggregateFunction {
- def this() = this(null, null, 0) // Required for serialization.
+ override def toString: String = s"SUM($child)"
- private val hyperLogLog = new HyperLogLog(relativeSD)
+ override def asPartial: SplitEvaluation = {
+ child.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
+ SplitEvaluation(
+ Cast(CombineSum(partialSum.toAttribute), dataType),
+ partialSum :: Nil)
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = expr.eval(input)
- if (evaluatedExpr != null) {
- hyperLogLog.offer(evaluatedExpr)
+ case _ =>
+ val partialSum = Alias(Sum(child), "PartialSum")()
+ SplitEvaluation(
+ CombineSum(partialSum.toAttribute),
+ partialSum :: Nil)
}
}
- override def eval(input: InternalRow): Any = hyperLogLog
-}
-
-case class ApproxCountDistinctMergeFunction(
- expr: Expression,
- base: AggregateExpression,
- relativeSD: Double)
- extends AggregateFunction {
- def this() = this(null, null, 0) // Required for serialization.
-
- private val hyperLogLog = new HyperLogLog(relativeSD)
-
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = expr.eval(input)
- hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog])
- }
+ override def newInstance(): SumFunction = new SumFunction(child, this)
- override def eval(input: InternalRow): Any = hyperLogLog.cardinality()
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForNumericExpr(child.dataType, "function sum")
}
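+
+// Sketch of the split produced by Sum.asPartial above (attribute names are illustrative only):
+//   final:   CombineSum(PartialSum)          -- or Cast(CombineSum(PartialSum), resultType)
+//   partial: Sum(child) AS PartialSum        -- Sum(Cast(child, DecimalType.Unlimited)) when the
+//                                            -- child is a fixed-precision decimal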
case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
@@ -633,6 +549,30 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr
}
}
+/**
+ * Sum should satisfy 3 cases:
+ * 1) sum of all null values = zero
+ * 2) sum for table column with no data = null
+ * 3) sum of column with null and not null values = sum of not null values
+ * A separate CombineSum expression and function are required because they have to distinguish the
+ * "no data" case from the "data equals null" case, both while aggregating results and at each
+ * partial expression, i.e.:
+ * Combining      PartitionLevel      InputData
+ *                                 <-- null
+ * Zero       <-- Zero            <-- null
+ *
+ *                                 <-- null <-- no data
+ * null       <-- null            <-- no data
+ */
+case class CombineSum(child: Expression) extends AggregateExpression {
+ def this() = this(null)
+
+ override def children: Seq[Expression] = child :: Nil
+ override def nullable: Boolean = true
+ override def dataType: DataType = child.dataType
+ override def toString: String = s"CombineSum($child)"
+ override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this)
+}
+
case class CombineSumFunction(expr: Expression, base: AggregateExpression)
extends AggregateFunction {
@@ -670,6 +610,33 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression)
}
}
+case class SumDistinct(child: Expression)
+ extends PartialAggregate with trees.UnaryNode[Expression] {
+
+ def this() = this(null)
+ override def nullable: Boolean = true
+ override def dataType: DataType = child.dataType match {
+ case DecimalType.Fixed(precision, scale) =>
+ DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive
+ case DecimalType.Unlimited =>
+ DecimalType.Unlimited
+ case _ =>
+ child.dataType
+ }
+ override def toString: String = s"SUM(DISTINCT $child)"
+ override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this)
+
+ override def asPartial: SplitEvaluation = {
+ val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")()
+ SplitEvaluation(
+ CombineSetsAndSum(partialSet.toAttribute, this),
+ partialSet :: Nil)
+ }
+
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct")
+}
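+
+// Sketch of the split produced by SumDistinct.asPartial above: each partition collects its distinct
+// values with CollectHashSet(child) (aliased "partialSets"), and the final CombineSetsAndSum merges
+// those hash sets and sums the merged elements, so values duplicated across partitions are added
+// only once.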
+
case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
extends AggregateFunction {
@@ -696,8 +663,20 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
}
}
-case class CountDistinctFunction(
- @transient expr: Seq[Expression],
+case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression {
+ def this() = this(null, null)
+
+ override def children: Seq[Expression] = inputSet :: Nil
+ override def nullable: Boolean = true
+ override def dataType: DataType = base.dataType
+ override def toString: String = s"CombineAndSum($inputSet)"
+ override def newInstance(): CombineSetsAndSumFunction = {
+ new CombineSetsAndSumFunction(inputSet, this)
+ }
+}
+
+case class CombineSetsAndSumFunction(
+ @transient inputSet: Expression,
@transient base: AggregateExpression)
extends AggregateFunction {
@@ -705,17 +684,39 @@ case class CountDistinctFunction(
val seen = new OpenHashSet[Any]()
- @transient
- val distinctValue = new InterpretedProjection(expr)
-
override def update(input: InternalRow): Unit = {
- val evaluatedExpr = distinctValue(input)
- if (!evaluatedExpr.anyNull) {
- seen.add(evaluatedExpr)
+ val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]]
+ val inputIterator = inputSetEval.iterator
+ while (inputIterator.hasNext) {
+ seen.add(inputIterator.next)
}
}
- override def eval(input: InternalRow): Any = seen.size.toLong
+ override def eval(input: InternalRow): Any = {
+ val casted = seen.asInstanceOf[OpenHashSet[InternalRow]]
+ if (casted.size == 0) {
+ null
+ } else {
+ Cast(Literal(
+ casted.iterator.map(f => f.apply(0)).reduceLeft(
+ base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
+ base.dataType).eval(null)
+ }
+ }
+}
+
+case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
+ override def nullable: Boolean = true
+ override def dataType: DataType = child.dataType
+ override def toString: String = s"FIRST($child)"
+
+ override def asPartial: SplitEvaluation = {
+ val partialFirst = Alias(First(child), "PartialFirst")()
+ SplitEvaluation(
+ First(partialFirst.toAttribute),
+ partialFirst :: Nil)
+ }
+ override def newInstance(): FirstFunction = new FirstFunction(child, this)
}
case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
@@ -732,6 +733,21 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag
override def eval(input: InternalRow): Any = result
}
+case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
+ override def references: AttributeSet = child.references
+ override def nullable: Boolean = true
+ override def dataType: DataType = child.dataType
+ override def toString: String = s"LAST($child)"
+
+ override def asPartial: SplitEvaluation = {
+ val partialLast = Alias(Last(child), "PartialLast")()
+ SplitEvaluation(
+ Last(partialLast.toAttribute),
+ partialLast :: Nil)
+ }
+ override def newInstance(): LastFunction = new LastFunction(child, this)
+}
+
case class LastFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index ace8427c8ddaf..3d4d9e2d798f0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -25,8 +25,6 @@ import org.apache.spark.sql.types._
abstract class UnaryArithmetic extends UnaryExpression {
self: Product =>
- override def foldable: Boolean = child.foldable
- override def nullable: Boolean = child.nullable
override def dataType: DataType = child.dataType
override def eval(input: InternalRow): Any = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index e0bf07ed182f3..5def57b067424 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -17,9 +17,10 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
-
/**
* Returns an Array containing the evaluation of all children expressions.
*/
@@ -27,15 +28,12 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
- lazy val childTypes = children.map(_.dataType).distinct
-
- override lazy val resolved =
- childrenResolved && childTypes.size <= 1
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array")
override def dataType: DataType = {
- assert(resolved, s"Invalid dataType of mixed ArrayType ${childTypes.mkString(",")}")
ArrayType(
- childTypes.headOption.getOrElse(NullType),
+ children.headOption.map(_.dataType).getOrElse(NullType),
containsNull = children.exists(_.nullable))
}
@@ -56,19 +54,15 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
- override lazy val resolved: Boolean = childrenResolved
-
override lazy val dataType: StructType = {
- assert(resolved,
- s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.")
- val fields = children.zipWithIndex.map { case (child, idx) =>
- child match {
- case ne: NamedExpression =>
- StructField(ne.name, ne.dataType, ne.nullable, ne.metadata)
- case _ =>
- StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty)
- }
+ val fields = children.zipWithIndex.map { case (child, idx) =>
+ child match {
+ case ne: NamedExpression =>
+ StructField(ne.name, ne.dataType, ne.nullable, ne.metadata)
+ case _ =>
+ StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty)
}
+ }
StructType(fields)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
index 2bc893af02641..f5c2dde191cf3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
@@ -17,16 +17,17 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.types._
-/** Return the unscaled Long value of a Decimal, assuming it fits in a Long */
+/**
+ * Return the unscaled Long value of a Decimal, assuming it fits in a Long.
+ * Note: this expression is internal and created only by the optimizer,
+ * we don't need to do type check for it.
+ */
case class UnscaledValue(child: Expression) extends UnaryExpression {
override def dataType: DataType = LongType
- override def foldable: Boolean = child.foldable
- override def nullable: Boolean = child.nullable
override def toString: String = s"UnscaledValue($child)"
override def eval(input: InternalRow): Any = {
@@ -43,12 +44,14 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
}
}
-/** Create a Decimal from an unscaled Long value */
+/**
+ * Create a Decimal from an unscaled Long value.
+ * Note: this expression is internal and created only by the optimizer,
+ * we don't need to do type check for it.
+ */
case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression {
override def dataType: DataType = DecimalType(precision, scale)
- override def foldable: Boolean = child.foldable
- override def nullable: Boolean = child.nullable
override def toString: String = s"MakeDecimal($child,$precision,$scale)"
override def eval(input: InternalRow): Decimal = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index f30cb42d12b83..356560e54cae3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.Map
-import org.apache.spark.sql.catalyst
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees}
import org.apache.spark.sql.types._
@@ -100,9 +100,14 @@ case class UserDefinedGenerator(
case class Explode(child: Expression)
extends Generator with trees.UnaryNode[Expression] {
- override lazy val resolved =
- child.resolved &&
- (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) {
+ TypeCheckResult.TypeCheckSuccess
+ } else {
+ TypeCheckResult.TypeCheckFailure(
+ s"input to function explode should be array or map type, not ${child.dataType}")
+ }
+ }
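+ // For example, an array- or map-typed child passes this check, while Explode('intField) over an
+ // integer column fails with a message like
+ // "input to function explode should be array or map type, not IntegerType".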
override def elementTypes: Seq[(DataType, Boolean)] = child.dataType match {
case ArrayType(et, containsNull) => (et, containsNull) :: Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 250564dc4b818..5694afc61be05 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions
import java.lang.{Long => JLong}
-import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String
@@ -60,7 +59,6 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
override def dataType: DataType = DoubleType
- override def foldable: Boolean = child.foldable
override def nullable: Boolean = true
override def toString: String = s"$name($child)"
@@ -224,7 +222,7 @@ case class Bin(child: Expression)
def funcName: String = name.toLowerCase
- override def eval(input: catalyst.InternalRow): Any = {
+ override def eval(input: InternalRow): Any = {
val evalE = child.eval(input)
if (evalE == null) {
null
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 9cacdceb13837..6f56a9ec7beb5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
@@ -113,7 +112,8 @@ case class Alias(child: Expression, name: String)(
extends NamedExpression with trees.UnaryNode[Expression] {
// Alias(Generator, xx) need to be transformed into Generate(generator, ...)
- override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator]
+ override lazy val resolved =
+ childrenResolved && checkInputDataTypes().isSuccess && !child.isInstanceOf[Generator]
override def eval(input: InternalRow): Any = child.eval(input)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
index 98acaf23c44c1..5d5911403ece1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -17,33 +17,32 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.analysis.UnresolvedException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
+import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.DataType
case class Coalesce(children: Seq[Expression]) extends Expression {
/** Coalesce is nullable if all of its children are nullable, or if it has no children. */
- override def nullable: Boolean = !children.exists(!_.nullable)
+ override def nullable: Boolean = children.forall(_.nullable)
// Coalesce is foldable if all children are foldable.
- override def foldable: Boolean = !children.exists(!_.foldable)
+ override def foldable: Boolean = children.forall(_.foldable)
- // Only resolved if all the children are of the same type.
- override lazy val resolved = childrenResolved && (children.map(_.dataType).distinct.size == 1)
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (children == Nil) {
+ TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty")
+ } else {
+ TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function coalesce")
+ }
+ }
override def toString: String = s"Coalesce(${children.mkString(",")})"
- override def dataType: DataType = if (resolved) {
- children.head.dataType
- } else {
- val childTypes = children.map(c => s"$c: ${c.dataType}").mkString(", ")
- throw new UnresolvedException(
- this, s"Coalesce cannot have children of different types. $childTypes")
- }
+ override def dataType: DataType = children.head.dataType
override def eval(input: InternalRow): Any = {
- var i = 0
var result: Any = null
val childIterator = children.iterator
while (childIterator.hasNext && result == null) {
@@ -75,7 +74,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}
case class IsNull(child: Expression) extends UnaryExpression with Predicate {
- override def foldable: Boolean = child.foldable
override def nullable: Boolean = false
override def eval(input: InternalRow): Any = {
@@ -93,7 +91,6 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
}
case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
- override def foldable: Boolean = child.foldable
override def nullable: Boolean = false
override def toString: String = s"IS NOT NULL $child"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
index 30e41677b774b..efc6f50b78943 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
@@ -78,6 +78,8 @@ case class NewSet(elementType: DataType) extends LeafExpression {
/**
* Adds an item to a set.
* For performance, this expression mutates its input during evaluation.
+ * Note: this expression is internal and created only by the GeneratedAggregate,
+ * we don't need to do type check for it.
*/
case class AddItemToSet(item: Expression, set: Expression) extends Expression {
@@ -85,7 +87,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
override def nullable: Boolean = set.nullable
- override def dataType: OpenHashSetUDT = set.dataType.asInstanceOf[OpenHashSetUDT]
+ override def dataType: DataType = set.dataType
override def eval(input: InternalRow): Any = {
val itemEval = item.eval(input)
@@ -128,12 +130,14 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
/**
* Combines the elements of two sets.
* For performance, this expression mutates its left input set during evaluation.
+ * Note: this expression is internal and created only by the GeneratedAggregate,
+ * we don't need to do type check for it.
*/
case class CombineSets(left: Expression, right: Expression) extends BinaryExpression {
override def nullable: Boolean = left.nullable || right.nullable
- override def dataType: OpenHashSetUDT = left.dataType.asInstanceOf[OpenHashSetUDT]
+ override def dataType: DataType = left.dataType
override def symbol: String = "++="
@@ -176,6 +180,8 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
/**
* Returns the number of elements in the input set.
+ * Note: this expression is internal and created only by the GeneratedAggregate,
+ * we don't need to do type check for it.
*/
case class CountSet(child: Expression) extends UnaryExpression {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 315c63e63c635..44416e79cd7aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -117,8 +117,6 @@ trait CaseConversionExpression extends ExpectsInputTypes {
def convert(v: UTF8String): UTF8String
- override def foldable: Boolean = child.foldable
- override def nullable: Boolean = child.nullable
override def dataType: DataType = StringType
override def expectedChildTypes: Seq[DataType] = Seq(StringType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index 896e383f50eac..12023ad311dc8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -68,7 +68,8 @@ case class WindowSpecDefinition(
override def children: Seq[Expression] = partitionSpec ++ orderSpec
override lazy val resolved: Boolean =
- childrenResolved && frameSpecification.isInstanceOf[SpecifiedWindowFrame]
+ childrenResolved && checkInputDataTypes().isSuccess &&
+ frameSpecification.isInstanceOf[SpecifiedWindowFrame]
override def toString: String = simpleString
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 98b4476076854..bfd24287c9645 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -39,19 +39,22 @@ object DefaultOptimizer extends Optimizer {
Batch("Distinct", FixedPoint(100),
ReplaceDistinctWithAggregate) ::
Batch("Operator Optimizations", FixedPoint(100),
- UnionPushdown,
- CombineFilters,
+ // Operator push down
+ UnionPushDown,
+ PushPredicateThroughJoin,
PushPredicateThroughProject,
PushPredicateThroughGenerate,
ColumnPruning,
+ // Operator combine
ProjectCollapsing,
+ CombineFilters,
CombineLimits,
+ // Constant folding
NullPropagation,
OptimizeIn,
ConstantFolding,
LikeSimplification,
BooleanSimplification,
- PushPredicateThroughJoin,
RemovePositive,
SimplifyFilters,
SimplifyCasts,
@@ -63,25 +66,25 @@ object DefaultOptimizer extends Optimizer {
}
/**
- * Pushes operations to either side of a Union.
- */
-object UnionPushdown extends Rule[LogicalPlan] {
+ * Pushes operations to either side of a Union.
+ */
+object UnionPushDown extends Rule[LogicalPlan] {
/**
- * Maps Attributes from the left side to the corresponding Attribute on the right side.
- */
- def buildRewrites(union: Union): AttributeMap[Attribute] = {
+ * Maps Attributes from the left side to the corresponding Attribute on the right side.
+ */
+ private def buildRewrites(union: Union): AttributeMap[Attribute] = {
assert(union.left.output.size == union.right.output.size)
AttributeMap(union.left.output.zip(union.right.output))
}
/**
- * Rewrites an expression so that it can be pushed to the right side of a Union operator.
- * This method relies on the fact that the output attributes of a union are always equal
- * to the left child's output.
- */
- def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]): A = {
+ * Rewrites an expression so that it can be pushed to the right side of a Union operator.
+ * This method relies on the fact that the output attributes of a union are always equal
+ * to the left child's output.
+ */
+ private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = {
val result = e transform {
case a: Attribute => rewrites(a)
}
@@ -108,7 +111,6 @@ object UnionPushdown extends Rule[LogicalPlan] {
}
}
-
/**
* Attempts to eliminate the reading of unneeded columns from the query plan using the following
* transformations:
@@ -117,7 +119,6 @@ object UnionPushdown extends Rule[LogicalPlan] {
* - Aggregate
* - Project <- Join
* - LeftSemiJoin
- * - Performing alias substitution.
*/
object ColumnPruning extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -159,10 +160,11 @@ object ColumnPruning extends Rule[LogicalPlan] {
Join(left, prunedChild(right, allReferences), LeftSemi, condition)
+ // Push down project through limit, so that we may have chance to push it further.
case Project(projectList, Limit(exp, child)) =>
Limit(exp, Project(projectList, child))
- // push down project if possible when the child is sort
+ // Push down project if possible when the child is sort
case p @ Project(projectList, s @ Sort(_, _, grandChild))
if s.references.subsetOf(p.outputSet) =>
s.copy(child = Project(projectList, grandChild))
@@ -181,8 +183,8 @@ object ColumnPruning extends Rule[LogicalPlan] {
}
/**
- * Combines two adjacent [[Project]] operators into one, merging the
- * expressions into one single expression.
+ * Combines two adjacent [[Project]] operators into one and perform alias substitution,
+ * merging the expressions into one single expression.
*/
object ProjectCollapsing extends Rule[LogicalPlan] {
@@ -222,10 +224,10 @@ object ProjectCollapsing extends Rule[LogicalPlan] {
object LikeSimplification extends Rule[LogicalPlan] {
// if guards below protect from escapes on trailing %.
// Cases like "something\%" are not optimized, but this does not affect correctness.
- val startsWith = "([^_%]+)%".r
- val endsWith = "%([^_%]+)".r
- val contains = "%([^_%]+)%".r
- val equalTo = "([^_%]*)".r
+ private val startsWith = "([^_%]+)%".r
+ private val endsWith = "%([^_%]+)".r
+ private val contains = "%([^_%]+)%".r
+ private val equalTo = "([^_%]*)".r
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case Like(l, Literal(utf, StringType)) =>
@@ -497,7 +499,7 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
grandChild))
}
- def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]): Expression = {
+ private def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]) = {
condition transform {
case a: AttributeReference => sourceAliases.getOrElse(a, a)
}
@@ -682,7 +684,7 @@ object DecimalAggregates extends Rule[LogicalPlan] {
import Decimal.MAX_LONG_DIGITS
/** Maximum number of decimal digits representable precisely in a Double */
- val MAX_DOUBLE_DIGITS = 15
+ private val MAX_DOUBLE_DIGITS = 15
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 04857a23f4c1e..8656cc334d09f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -48,6 +48,15 @@ object TypeUtils {
}
}
+ def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = {
+ if (types.distinct.size > 1) {
+ TypeCheckResult.TypeCheckFailure(
+ s"input to $caller should all be the same type, but it's ${types.mkString("[", ", ", "]")}")
+ } else {
+ TypeCheckResult.TypeCheckSuccess
+ }
+ }
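+ // For example (sketch): checkForSameTypeInputExpr(Seq(IntegerType, BooleanType), "function coalesce")
+ // returns a TypeCheckFailure whose message is "input to function coalesce should all be the same
+ // type, but it's [IntegerType, BooleanType]", while a homogeneous type list returns TypeCheckSuccess.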
+
def getNumeric(t: DataType): Numeric[Any] =
t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 407dc27326c2e..18cdfa7238f39 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -20,13 +20,18 @@ package org.apache.spark.sql.types
import scala.reflect.runtime.universe.typeTag
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.sql.catalyst.expressions.Expression
/** Precision parameters for a Decimal */
-case class PrecisionInfo(precision: Int, scale: Int)
-
+case class PrecisionInfo(precision: Int, scale: Int) {
+ if (scale > precision) {
+ throw new AnalysisException(
+ s"Decimal scale ($scale) cannot be greater than precision ($precision).")
+ }
+}
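+// For example, DecimalType(10, 2) still constructs PrecisionInfo(10, 2) as before, while
+// DecimalType(2, 5) now fails analysis with "Decimal scale (5) cannot be greater than precision (2)."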
/**
* :: DeveloperApi ::
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index e09cd790a7187..77ca080f366cd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -193,7 +193,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
errorTest(
"bad casts",
testRelation.select(Literal(1).cast(BinaryType).as('badCast)),
- "invalid cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)
+ "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)
errorTest(
"non-boolean filters",
@@ -264,9 +264,9 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
val plan =
Aggregate(
Nil,
- Alias(Sum(AttributeReference("a", StringType)(exprId = ExprId(1))), "b")() :: Nil,
+ Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil,
LocalRelation(
- AttributeReference("a", StringType)(exprId = ExprId(2))))
+ AttributeReference("a", IntegerType)(exprId = ExprId(2))))
assert(plan.resolved)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
similarity index 84%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 49b111989799b..bc1537b0715b5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -15,13 +15,13 @@
* limitations under the License.
*/
-package org.apache.spark.sql.catalyst.expressions
+package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.StringType
@@ -136,6 +136,28 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertError(
CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)),
"WHEN expressions in CaseWhen should all be boolean type")
+ }
+
+ test("check types for aggregates") {
+ // We will cast String to Double for sum and average
+ assertSuccess(Sum('stringField))
+ assertSuccess(SumDistinct('stringField))
+ assertSuccess(Average('stringField))
+
+ assertError(Min('complexField), "function min accepts non-complex type")
+ assertError(Max('complexField), "function max accepts non-complex type")
+ assertError(Sum('booleanField), "function sum accepts numeric type")
+ assertError(SumDistinct('booleanField), "function sumDistinct accepts numeric type")
+ assertError(Average('booleanField), "function average accepts numeric type")
+ }
+ test("check types for others") {
+ assertError(CreateArray(Seq('intField, 'booleanField)),
+ "input to function array should all be the same type")
+ assertError(Coalesce(Seq('intField, 'booleanField)),
+ "input to function coalesce should all be the same type")
+ assertError(Coalesce(Nil), "input to function coalesce cannot be empty")
+ assertError(Explode('intField),
+ "input to function explode should be array or map type")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala
index 35f50be46b76f..ec379489a6d1e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala
@@ -24,13 +24,13 @@ import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
-class UnionPushdownSuite extends PlanTest {
+class UnionPushDownSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Subqueries", Once,
EliminateSubQueries) ::
Batch("Union Pushdown", Once,
- UnionPushdown) :: Nil
+ UnionPushDown) :: Nil
}
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 265352647fa9f..9a10a23937fbb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -264,6 +264,14 @@ private[spark] object SQLConf {
defaultValue = Some(true),
doc = "
")
+ val PARQUET_FOLLOW_PARQUET_FORMAT_SPEC = booleanConf(
+ key = "spark.sql.parquet.followParquetFormatSpec",
+ defaultValue = Some(false),
+ doc = "Wether to stick to Parquet format specification when converting Parquet schema to " +
+ "Spark SQL schema and vice versa. Sticks to the specification if set to true; falls back " +
+ "to compatible mode if set to false.",
+ isPublic = false)
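+
+ // Usage sketch (this conf is non-public, but can still be set explicitly), e.g.:
+ //   sqlContext.setConf("spark.sql.parquet.followParquetFormatSpec", "true")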
+
val PARQUET_OUTPUT_COMMITTER_CLASS = stringConf(
key = "spark.sql.parquet.output.committer.class",
defaultValue = Some(classOf[ParquetOutputCommitter].getName),
@@ -498,6 +506,12 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
*/
private[spark] def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP)
+ /**
+ * When set to true, sticks to Parquet format spec when converting Parquet schema to Spark SQL
+ * schema and vice versa. Otherwise, falls back to compatible mode.
+ */
+ private[spark] def followParquetFormatSpec: Boolean = getConf(PARQUET_FOLLOW_PARQUET_FORMAT_SPEC)
+
/**
* When set to true, partition pruning for in-memory columnar tables is enabled.
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 04fc798bf3738..5708df82de12f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -858,7 +858,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
experimental.extraStrategies ++ (
DataSourceStrategy ::
DDLStrategy ::
- TakeOrdered ::
+ TakeOrderedAndProject ::
HashAggregation ::
LeftSemiJoin ::
HashJoin ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 2b8d30294293c..47f56b2b7ebe6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -169,7 +169,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
log.debug(
s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
if(codegenEnabled && expressions.forall(_.isThreadSafe)) {
-
GenerateMutableProjection.generate(expressions, inputSchema)
} else {
() => new InterpretedMutableProjection(expressions, inputSchema)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 1ff1cc224de8c..21912cf24933e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -213,10 +213,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
protected lazy val singleRowRdd =
sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): InternalRow), 1)
- object TakeOrdered extends Strategy {
+ object TakeOrderedAndProject extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
- execution.TakeOrdered(limit, order, planLater(child)) :: Nil
+ execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil
+ case logical.Limit(
+ IntegerLiteral(limit),
+ logical.Project(projectList, logical.Sort(order, true, child))) =>
+ execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil
case _ => Nil
}
}
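+
+ // Planning sketch: a query shaped like SELECT a FROM t ORDER BY b LIMIT 10, i.e.
+ // Limit(10, Project(a, Sort(b, global = true, t))), is now planned as a single
+ // TakeOrderedAndProject(10, b, Some(a :: Nil), t) instead of separate Limit, Project and Sort
+ // physical operators.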
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 7aedd630e3871..647c4ab5cb651 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -39,8 +39,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
@transient lazy val buildProjection = newMutableProjection(projectList, child.output)
protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
- val resuableProjection = buildProjection()
- iter.map(resuableProjection)
+ val reusableProjection = buildProjection()
+ iter.map(reusableProjection)
}
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
@@ -147,12 +147,18 @@ case class Limit(limit: Int, child: SparkPlan)
/**
* :: DeveloperApi ::
- * Take the first limit elements as defined by the sortOrder. This is logically equivalent to
- * having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but
- * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion.
+ * Take the first limit elements as defined by the sortOrder, and do projection if needed.
+ * This is logically equivalent to having a [[Limit]] operator after a [[Sort]] operator,
+ * or having a [[Project]] operator between them.
+ * This could have been named TopK, but Spark's top operator does the opposite in ordering
+ * so we name it TakeOrderedAndProject to avoid confusion.
*/
@DeveloperApi
-case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode {
+case class TakeOrderedAndProject(
+ limit: Int,
+ sortOrder: Seq[SortOrder],
+ projectList: Option[Seq[NamedExpression]],
+ child: SparkPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
@@ -160,8 +166,13 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
private val ord: RowOrdering = new RowOrdering(sortOrder, child.output)
- private def collectData(): Array[InternalRow] =
- child.execute().map(_.copy()).takeOrdered(limit)(ord)
+ // TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable.
+ @transient private val projection = projectList.map(new InterpretedProjection(_, child.output))
+
+ private def collectData(): Array[InternalRow] = {
+ val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
+ projection.map(data.map(_)).getOrElse(data)
+ }
override def executeCollect(): Array[Row] = {
val converter = CatalystTypeConverters.createToScalaConverter(schema)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 6db551c543a9c..f9c3fe92c2670 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -55,7 +55,7 @@ private[spark] case class PythonUDF(
override def toString: String = s"PythonUDF#$name(${children.mkString(",")})"
- def nullable: Boolean = true
+ override def nullable: Boolean = true
override def eval(input: InternalRow): Any = {
throw new UnsupportedOperationException("PythonUDFs can not be directly evaluated.")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
new file mode 100644
index 0000000000000..4fd3e93b70311
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
@@ -0,0 +1,565 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.parquet
+
+import scala.collection.JavaConversions._
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.parquet.schema.OriginalType._
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
+import org.apache.parquet.schema.Type.Repetition._
+import org.apache.parquet.schema._
+
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{AnalysisException, SQLConf}
+
+/**
+ * This converter class is used to convert Parquet [[MessageType]] to Spark SQL [[StructType]] and
+ * vice versa.
+ *
+ * Parquet format backwards-compatibility rules are respected when converting Parquet
+ * [[MessageType]] schemas.
+ *
+ * @see https://github.com/apache/parquet-format/blob/master/LogicalTypes.md
+ *
+ * @constructor
+ * @param assumeBinaryIsString Whether unannotated BINARY fields should be assumed to be Spark SQL
+ * [[StringType]] fields when converting a Parquet [[MessageType]] to a Spark SQL
+ * [[StructType]].
+ * @param assumeInt96IsTimestamp Whether unannotated INT96 fields should be assumed to be Spark SQL
+ * [[TimestampType]] fields when converting a Parquet [[MessageType]] to a Spark SQL
+ * [[StructType]]. Note that Spark SQL [[TimestampType]] is similar to Hive timestamp, which
+ * has optional nanosecond precision, but is different from `TIME_MILLIS` and `TIMESTAMP_MILLIS`
+ * described in the Parquet format spec.
+ * @param followParquetFormatSpec Whether to generate standard DECIMAL, LIST, and MAP structure when
+ * converting Spark SQL [[StructType]] to Parquet [[MessageType]]. For Spark 1.4.x and
+ * prior versions, Spark SQL only supports decimals with a max precision of 18 digits, and
+ * uses non-standard LIST and MAP structure. Note that the current Parquet format spec is
+ * backwards-compatible with these settings. If this argument is set to `false`, we fall back
+ * to the old-style non-standard behaviors.
+ */
+private[parquet] class CatalystSchemaConverter(
+ private val assumeBinaryIsString: Boolean,
+ private val assumeInt96IsTimestamp: Boolean,
+ private val followParquetFormatSpec: Boolean) {
+
+ // Only used when constructing converter for converting Spark SQL schema to Parquet schema, in
+ // which case `assumeInt96IsTimestamp` and `assumeBinaryIsString` are irrelevant.
+ def this() = this(
+ assumeBinaryIsString = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get,
+ assumeInt96IsTimestamp = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get,
+ followParquetFormatSpec = SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.defaultValue.get)
+
+ def this(conf: SQLConf) = this(
+ assumeBinaryIsString = conf.isParquetBinaryAsString,
+ assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp,
+ followParquetFormatSpec = conf.followParquetFormatSpec)
+
+ def this(conf: Configuration) = this(
+ assumeBinaryIsString =
+ conf.getBoolean(
+ SQLConf.PARQUET_BINARY_AS_STRING.key,
+ SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get),
+ assumeInt96IsTimestamp =
+ conf.getBoolean(
+ SQLConf.PARQUET_INT96_AS_TIMESTAMP.key,
+ SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get),
+ followParquetFormatSpec =
+ conf.getBoolean(
+ SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key,
+ SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.defaultValue.get))
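+
+ // Usage sketch: `sqlConf` and `parquetFileSchema` below are hypothetical values, not members of
+ // this class:
+ //   val sparkSchema: StructType = new CatalystSchemaConverter(sqlConf).convert(parquetFileSchema)
+ //   val parquetSchema: MessageType = new CatalystSchemaConverter().convert(sparkSchema)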
+
+ /**
+ * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]].
+ */
+ def convert(parquetSchema: MessageType): StructType = convert(parquetSchema.asGroupType())
+
+ private def convert(parquetSchema: GroupType): StructType = {
+ val fields = parquetSchema.getFields.map { field =>
+ field.getRepetition match {
+ case OPTIONAL =>
+ StructField(field.getName, convertField(field), nullable = true)
+
+ case REQUIRED =>
+ StructField(field.getName, convertField(field), nullable = false)
+
+ case REPEATED =>
+ throw new AnalysisException(
+ s"REPEATED not supported outside LIST or MAP. Type: $field")
+ }
+ }
+
+ StructType(fields)
+ }
+
+ /**
+ * Converts a Parquet [[Type]] to a Spark SQL [[DataType]].
+ */
+ def convertField(parquetType: Type): DataType = parquetType match {
+ case t: PrimitiveType => convertPrimitiveField(t)
+ case t: GroupType => convertGroupField(t.asGroupType())
+ }
+
+ private def convertPrimitiveField(field: PrimitiveType): DataType = {
+ val typeName = field.getPrimitiveTypeName
+ val originalType = field.getOriginalType
+
+ def typeString =
+ if (originalType == null) s"$typeName" else s"$typeName ($originalType)"
+
+ def typeNotImplemented() =
+ throw new AnalysisException(s"Parquet type not yet supported: $typeString")
+
+ def illegalType() =
+ throw new AnalysisException(s"Illegal Parquet type: $typeString")
+
+ // When maxPrecision = -1, we skip the precision range check and always respect the precision
+ // specified in field.getDecimalMetadata. This is useful when interpreting decimal types stored
+ // as binaries with variable lengths.
+ def makeDecimalType(maxPrecision: Int = -1): DecimalType = {
+ val precision = field.getDecimalMetadata.getPrecision
+ val scale = field.getDecimalMetadata.getScale
+
+ CatalystSchemaConverter.analysisRequire(
+ maxPrecision == -1 || 1 <= precision && precision <= maxPrecision,
+ s"Invalid decimal precision: $typeName cannot store $precision digits (max $maxPrecision)")
+
+ DecimalType(precision, scale)
+ }
+
+ field.getPrimitiveTypeName match {
+ case BOOLEAN => BooleanType
+
+ case FLOAT => FloatType
+
+ case DOUBLE => DoubleType
+
+ case INT32 =>
+ field.getOriginalType match {
+ case INT_8 => ByteType
+ case INT_16 => ShortType
+ case INT_32 | null => IntegerType
+ case DATE => DateType
+ case DECIMAL => makeDecimalType(maxPrecisionForBytes(4))
+ case TIME_MILLIS => typeNotImplemented()
+ case _ => illegalType()
+ }
+
+ case INT64 =>
+ field.getOriginalType match {
+ case INT_64 | null => LongType
+ case DECIMAL => makeDecimalType(maxPrecisionForBytes(8))
+ case TIMESTAMP_MILLIS => typeNotImplemented()
+ case _ => illegalType()
+ }
+
+ case INT96 =>
+ CatalystSchemaConverter.analysisRequire(
+ assumeInt96IsTimestamp,
+ "INT96 is not supported unless it's interpreted as timestamp. " +
+ s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.")
+ TimestampType
+
+ case BINARY =>
+ field.getOriginalType match {
+ case UTF8 => StringType
+ case null if assumeBinaryIsString => StringType
+ case null => BinaryType
+ case DECIMAL => makeDecimalType()
+ case _ => illegalType()
+ }
+
+ case FIXED_LEN_BYTE_ARRAY =>
+ field.getOriginalType match {
+ case DECIMAL => makeDecimalType(maxPrecisionForBytes(field.getTypeLength))
+ case INTERVAL => typeNotImplemented()
+ case _ => illegalType()
+ }
+
+ case _ => illegalType()
+ }
+ }
+
+ private def convertGroupField(field: GroupType): DataType = {
+ Option(field.getOriginalType).fold(convert(field): DataType) {
+ // A Parquet list is represented as a 3-level structure:
+ //
+ //   <list-repetition> group <name> (LIST) {
+ //     repeated group list {
+ //       <element-repetition> <element-type> element;
+ //     }
+ //   }
+ //
+ // However, according to the most recent Parquet format spec (not yet released at the time of
+ // writing), some 2-level structures are also recognized for backwards-compatibility. Thus, we
+ // need to check whether the 2nd level or the 3rd level refers to the list element type.
+ //
+ // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists
+ case LIST =>
+ CatalystSchemaConverter.analysisRequire(
+ field.getFieldCount == 1, s"Invalid list type $field")
+
+ val repeatedType = field.getType(0)
+ CatalystSchemaConverter.analysisRequire(
+ repeatedType.isRepetition(REPEATED), s"Invalid list type $field")
+
+ if (isElementType(repeatedType, field.getName)) {
+ ArrayType(convertField(repeatedType), containsNull = false)
+ } else {
+ val elementType = repeatedType.asGroupType().getType(0)
+ val optional = elementType.isRepetition(OPTIONAL)
+ ArrayType(convertField(elementType), containsNull = optional)
+ }
+
+ // scalastyle:off
+ // `MAP_KEY_VALUE` is for backwards-compatibility
+ // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules-1
+ // scalastyle:on
+ case MAP | MAP_KEY_VALUE =>
+ CatalystSchemaConverter.analysisRequire(
+ field.getFieldCount == 1 && !field.getType(0).isPrimitive,
+ s"Invalid map type: $field")
+
+ val keyValueType = field.getType(0).asGroupType()
+ CatalystSchemaConverter.analysisRequire(
+ keyValueType.isRepetition(REPEATED) && keyValueType.getFieldCount == 2,
+ s"Invalid map type: $field")
+
+ val keyType = keyValueType.getType(0)
+ CatalystSchemaConverter.analysisRequire(
+ keyType.isPrimitive,
+ s"Map key type is expected to be a primitive type, but found: $keyType")
+
+ val valueType = keyValueType.getType(1)
+ val valueOptional = valueType.isRepetition(OPTIONAL)
+ MapType(
+ convertField(keyType),
+ convertField(valueType),
+ valueContainsNull = valueOptional)
+
+ case _ =>
+ throw new AnalysisException(s"Unrecognized Parquet type: $field")
+ }
+ }
+
+ // scalastyle:off
+ // Here we implement Parquet LIST backwards-compatibility rules.
+ // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules
+ // scalastyle:on
+ private def isElementType(repeatedType: Type, parentName: String) = {
+ {
+ // For legacy 2-level list types with primitive element type, e.g.:
+ //
+ // // List<Integer> (nullable list, non-null elements)
+ // optional group my_list (LIST) {
+ // repeated int32 element;
+ // }
+ //
+ repeatedType.isPrimitive
+ } || {
+ // For legacy 2-level list types whose element type is a group type with 2 or more fields,
+ // e.g.:
+ //
+ // // List<Tuple<String, Integer>> (nullable list, non-null elements)
+ // optional group my_list (LIST) {
+ // repeated group element {
+ // required binary str (UTF8);
+ // required int32 num;
+ // };
+ // }
+ //
+ repeatedType.asGroupType().getFieldCount > 1
+ } || {
+ // For legacy 2-level list types generated by parquet-avro (Parquet version < 1.6.0), e.g.:
+ //
+ // // List<OneTuple<String>> (nullable list, non-null elements)
+ // optional group my_list (LIST) {
+ // repeated group array {
+ // required binary str (UTF8);
+ // };
+ // }
+ //
+ repeatedType.getName == "array"
+ } || {
+ // For Parquet data generated by parquet-thrift, e.g.:
+ //
+ // // List<OneTuple<String>> (nullable list, non-null elements)
+ // optional group my_list (LIST) {
+ // repeated group my_list_tuple {
+ // required binary str (UTF8);
+ // };
+ // }
+ //
+ repeatedType.getName == s"${parentName}_tuple"
+ }
+ }
+
+ /**
+ * Converts a Spark SQL [[StructType]] to a Parquet [[MessageType]].
+ */
+ def convert(catalystSchema: StructType): MessageType = {
+ Types.buildMessage().addFields(catalystSchema.map(convertField): _*).named("root")
+ }
+
+ /**
+ * Converts a Spark SQL [[StructField]] to a Parquet [[Type]].
+ */
+ def convertField(field: StructField): Type = {
+ convertField(field, if (field.nullable) OPTIONAL else REQUIRED)
+ }
+
+ private def convertField(field: StructField, repetition: Type.Repetition): Type = {
+ CatalystSchemaConverter.checkFieldName(field.name)
+
+ field.dataType match {
+ // ===================
+ // Simple atomic types
+ // ===================
+
+ case BooleanType =>
+ Types.primitive(BOOLEAN, repetition).named(field.name)
+
+ case ByteType =>
+ Types.primitive(INT32, repetition).as(INT_8).named(field.name)
+
+ case ShortType =>
+ Types.primitive(INT32, repetition).as(INT_16).named(field.name)
+
+ case IntegerType =>
+ Types.primitive(INT32, repetition).named(field.name)
+
+ case LongType =>
+ Types.primitive(INT64, repetition).named(field.name)
+
+ case FloatType =>
+ Types.primitive(FLOAT, repetition).named(field.name)
+
+ case DoubleType =>
+ Types.primitive(DOUBLE, repetition).named(field.name)
+
+ case StringType =>
+ Types.primitive(BINARY, repetition).as(UTF8).named(field.name)
+
+ case DateType =>
+ Types.primitive(INT32, repetition).as(DATE).named(field.name)
+
+ // NOTE: !! This timestamp type is not specified in Parquet format spec !!
+ // However, Impala and older versions of Spark SQL use INT96 to store timestamps with
+ // nanosecond precision (not TIME_MILLIS or TIMESTAMP_MILLIS described in the spec).
+ case TimestampType =>
+ Types.primitive(INT96, repetition).named(field.name)
+
+ case BinaryType =>
+ Types.primitive(BINARY, repetition).named(field.name)
+
+ // =====================================
+ // Decimals (for Spark version <= 1.4.x)
+ // =====================================
+
+ // Spark 1.4.x and prior versions only support decimals with a maximum precision of 18 and
+ // always store decimals in fixed-length byte arrays.
+ case DecimalType.Fixed(precision, scale)
+ if precision <= maxPrecisionForBytes(8) && !followParquetFormatSpec =>
+ Types
+ .primitive(FIXED_LEN_BYTE_ARRAY, repetition)
+ .as(DECIMAL)
+ .precision(precision)
+ .scale(scale)
+ .length(minBytesForPrecision(precision))
+ .named(field.name)
+
+ case dec @ DecimalType() if !followParquetFormatSpec =>
+ throw new AnalysisException(
+ s"Data type $dec is not supported. " +
+ s"When ${SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key} is set to false," +
+ "decimal precision and scale must be specified, " +
+ "and precision must be less than or equal to 18.")
+
+ // =====================================
+ // Decimals (follow Parquet format spec)
+ // =====================================
+
+ // Uses INT32 for 1 <= precision <= 9
+ case DecimalType.Fixed(precision, scale)
+ if precision <= maxPrecisionForBytes(4) && followParquetFormatSpec =>
+ Types
+ .primitive(INT32, repetition)
+ .as(DECIMAL)
+ .precision(precision)
+ .scale(scale)
+ .named(field.name)
+
+ // Uses INT64 for 1 <= precision <= 18
+ case DecimalType.Fixed(precision, scale)
+ if precision <= maxPrecisionForBytes(8) && followParquetFormatSpec =>
+ Types
+ .primitive(INT64, repetition)
+ .as(DECIMAL)
+ .precision(precision)
+ .scale(scale)
+ .named(field.name)
+
+ // Uses FIXED_LEN_BYTE_ARRAY for all other precisions
+ case DecimalType.Fixed(precision, scale) if followParquetFormatSpec =>
+ Types
+ .primitive(FIXED_LEN_BYTE_ARRAY, repetition)
+ .as(DECIMAL)
+ .precision(precision)
+ .scale(scale)
+ .length(minBytesForPrecision(precision))
+ .named(field.name)
+
+ case dec @ DecimalType.Unlimited if followParquetFormatSpec =>
+ throw new AnalysisException(
+ s"Data type $dec is not supported. Decimal precision and scale must be specified.")
+
+ // ===================================================
+ // ArrayType and MapType (for Spark versions <= 1.4.x)
+ // ===================================================
+
+ // Spark 1.4.x and prior versions convert ArrayType with nullable elements into a 3-level
+ // LIST structure. This behavior mimics parquet-hive (1.6.0rc3). Note that this case is
+ // covered by the backwards-compatibility rules implemented in `isElementType()`.
+ case ArrayType(elementType, nullable @ true) if !followParquetFormatSpec =>
+ // <list-repetition> group <name> (LIST) {
+ //   optional group bag {
+ //     repeated <element-type> element;
+ //   }
+ // }
+ ConversionPatterns.listType(
+ repetition,
+ field.name,
+ Types
+ .buildGroup(REPEATED)
+ .addField(convertField(StructField("element", elementType, nullable)))
+ .named(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME))
+
+ // Spark 1.4.x and prior versions convert ArrayType with non-nullable elements into a 2-level
+ // LIST structure. This behavior mimics parquet-avro (1.6.0rc3). Note that this case is
+ // covered by the backwards-compatibility rules implemented in `isElementType()`.
+ case ArrayType(elementType, nullable @ false) if !followParquetFormatSpec =>
+ // <list-repetition> group <name> (LIST) {
+ //   repeated <element-type> element;
+ // }
+ ConversionPatterns.listType(
+ repetition,
+ field.name,
+ convertField(StructField("element", elementType, nullable), REPEATED))
+
+ // Spark 1.4.x and prior versions convert MapType into a 3-level group annotated by
+ // MAP_KEY_VALUE. This is covered by `convertGroupField(field: GroupType): DataType`.
+ case MapType(keyType, valueType, valueContainsNull) if !followParquetFormatSpec =>
+ // <map-repetition> group <name> (MAP) {
+ //   repeated group map (MAP_KEY_VALUE) {
+ //     required <key-type> key;
+ //     <value-repetition> <value-type> value;
+ //   }
+ // }
+ ConversionPatterns.mapType(
+ repetition,
+ field.name,
+ convertField(StructField("key", keyType, nullable = false)),
+ convertField(StructField("value", valueType, valueContainsNull)))
+
+ // ==================================================
+ // ArrayType and MapType (follow Parquet format spec)
+ // ==================================================
+
+ case ArrayType(elementType, containsNull) if followParquetFormatSpec =>
+ // <list-repetition> group <name> (LIST) {
+ //   repeated group list {
+ //     <element-repetition> <element-type> element;
+ //   }
+ // }
+ Types
+ .buildGroup(repetition).as(LIST)
+ .addField(
+ Types.repeatedGroup()
+ .addField(convertField(StructField("element", elementType, containsNull)))
+ .named("list"))
+ .named(field.name)
+
+ case MapType(keyType, valueType, valueContainsNull) =>
+ // <map-repetition> group <name> (MAP) {
+ //   repeated group key_value {
+ //     required <key-type> key;
+ //     <value-repetition> <value-type> value;
+ //   }
+ // }
+ Types
+ .buildGroup(repetition).as(MAP)
+ .addField(
+ Types
+ .repeatedGroup()
+ .addField(convertField(StructField("key", keyType, nullable = false)))
+ .addField(convertField(StructField("value", valueType, valueContainsNull)))
+ .named("key_value"))
+ .named(field.name)
+
+ // ===========
+ // Other types
+ // ===========
+
+ case StructType(fields) =>
+ fields.foldLeft(Types.buildGroup(repetition)) { (builder, field) =>
+ builder.addField(convertField(field))
+ }.named(field.name)
+
+ case udt: UserDefinedType[_] =>
+ convertField(field.copy(dataType = udt.sqlType))
+
+ case _ =>
+ throw new AnalysisException(s"Unsupported data type ${field.dataType}")
+ }
+ }
+
+ // Max precision of a decimal value stored in `numBytes` bytes
+ private def maxPrecisionForBytes(numBytes: Int): Int = {
+ Math.round( // convert double to long
+ Math.floor(Math.log10( // number of base-10 digits
+ Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes
+ .asInstanceOf[Int]
+ }
+
+ // Min byte counts needed to store decimals with various precisions
+ private val minBytesForPrecision: Array[Int] = Array.tabulate(39) { precision => // covers precisions 0 to 38
+ var numBytes = 1
+ while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) {
+ numBytes += 1
+ }
+ numBytes
+ }
+}
+
+
+private[parquet] object CatalystSchemaConverter {
+ def checkFieldName(name: String): Unit = {
+ // ,;{}()\n\t= and space are special characters in Parquet schema
+ analysisRequire(
+ !name.matches(".*[ ,;{}()\n\t=].*"),
+ s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=".
+ |Please use alias to rename it.
+ """.stripMargin.split("\n").mkString(" "))
+ }
+
+ def analysisRequire(f: => Boolean, message: String): Unit = {
+ if (!f) {
+ throw new AnalysisException(message)
+ }
+ }
+}
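For orientation, here is a rough sketch of how the new converter is exercised (not part of the patch; the constructor parameter names are those used by ParquetSchemaSuite below, and since the class is private[parquet] a real caller would live under org.apache.spark.sql.parquet, e.g. a test):

    import org.apache.parquet.schema.MessageType
    import org.apache.spark.sql.types._

    val converter = new CatalystSchemaConverter(
      assumeBinaryIsString = true,
      assumeInt96IsTimestamp = true,
      followParquetFormatSpec = true)

    val catalystSchema = StructType(Seq(
      StructField("id", IntegerType, nullable = false),
      StructField("tags", ArrayType(StringType, containsNull = false), nullable = true)))

    // Catalyst -> Parquet: with followParquetFormatSpec = true, "tags" becomes the standard 3-level
    // layout: optional group tags (LIST) { repeated group list { required binary element (UTF8); } }
    val parquetSchema: MessageType = converter.convert(catalystSchema)

    // Parquet -> Catalyst: converting back recovers the original StructType, because the
    // single-field "list" group is recognized as the 3-level layout rather than as the element type.
    val roundTripped: StructType = converter.convert(parquetSchema)
    assert(roundTripped == catalystSchema)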
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index e65fa0030e179..0d96a1e8070b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -86,8 +86,7 @@ private[parquet] class RowReadSupport extends ReadSupport[InternalRow] with Logg
// TODO: Why it can be null?
if (schema == null) {
log.debug("falling back to Parquet read schema")
- schema = ParquetTypesConverter.convertToAttributes(
- parquetSchema, false, true)
+ schema = ParquetTypesConverter.convertToAttributes(parquetSchema, false, true)
}
log.debug(s"list of attributes that will be read: $schema")
new RowRecordMaterializer(parquetSchema, schema)
@@ -105,8 +104,7 @@ private[parquet] class RowReadSupport extends ReadSupport[InternalRow] with Logg
// If the parquet file is thrift derived, there is a good chance that
// it will have the thrift class in metadata.
val isThriftDerived = keyValueMetaData.keySet().contains("thrift.class")
- parquetSchema = ParquetTypesConverter
- .convertFromAttributes(requestedAttributes, isThriftDerived)
+ parquetSchema = ParquetTypesConverter.convertFromAttributes(requestedAttributes)
metadata.put(
RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA,
ParquetTypesConverter.convertToString(requestedAttributes))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
index ba2a35b74ef82..4d5199a140344 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
@@ -29,214 +29,19 @@ import org.apache.parquet.format.converter.ParquetMetadataConverter
import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata}
import org.apache.parquet.hadoop.util.ContextUtil
import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter}
-import org.apache.parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName}
-import org.apache.parquet.schema.Type.Repetition
-import org.apache.parquet.schema.{ConversionPatterns, DecimalMetadata, GroupType => ParquetGroupType, MessageType, OriginalType => ParquetOriginalType, PrimitiveType => ParquetPrimitiveType, Type => ParquetType, Types => ParquetTypes}
+import org.apache.parquet.schema.MessageType
import org.apache.spark.Logging
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.types._
-/** A class representing Parquet info fields we care about, for passing back to Parquet */
-private[parquet] case class ParquetTypeInfo(
- primitiveType: ParquetPrimitiveTypeName,
- originalType: Option[ParquetOriginalType] = None,
- decimalMetadata: Option[DecimalMetadata] = None,
- length: Option[Int] = None)
-
private[parquet] object ParquetTypesConverter extends Logging {
def isPrimitiveType(ctype: DataType): Boolean = ctype match {
case _: NumericType | BooleanType | StringType | BinaryType => true
case _: DataType => false
}
- def toPrimitiveDataType(
- parquetType: ParquetPrimitiveType,
- binaryAsString: Boolean,
- int96AsTimestamp: Boolean): DataType = {
- val originalType = parquetType.getOriginalType
- val decimalInfo = parquetType.getDecimalMetadata
- parquetType.getPrimitiveTypeName match {
- case ParquetPrimitiveTypeName.BINARY
- if (originalType == ParquetOriginalType.UTF8 || binaryAsString) => StringType
- case ParquetPrimitiveTypeName.BINARY => BinaryType
- case ParquetPrimitiveTypeName.BOOLEAN => BooleanType
- case ParquetPrimitiveTypeName.DOUBLE => DoubleType
- case ParquetPrimitiveTypeName.FLOAT => FloatType
- case ParquetPrimitiveTypeName.INT32
- if originalType == ParquetOriginalType.DATE => DateType
- case ParquetPrimitiveTypeName.INT32 => IntegerType
- case ParquetPrimitiveTypeName.INT64 => LongType
- case ParquetPrimitiveTypeName.INT96 if int96AsTimestamp => TimestampType
- case ParquetPrimitiveTypeName.INT96 =>
- // TODO: add BigInteger type? TODO(andre) use DecimalType instead????
- throw new AnalysisException("Potential loss of precision: cannot convert INT96")
- case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY
- if (originalType == ParquetOriginalType.DECIMAL && decimalInfo.getPrecision <= 18) =>
- // TODO: for now, our reader only supports decimals that fit in a Long
- DecimalType(decimalInfo.getPrecision, decimalInfo.getScale)
- case _ => throw new AnalysisException(s"Unsupported parquet datatype $parquetType")
- }
- }
-
- /**
- * Converts a given Parquet `Type` into the corresponding
- * [[org.apache.spark.sql.types.DataType]].
- *
- * We apply the following conversion rules:
- *
- * - Primitive types are converter to the corresponding primitive type.
- * - Group types that have a single field that is itself a group, which has repetition
- * level `REPEATED`, are treated as follows:
- * - If the nested group has name `values`, the surrounding group is converted
- * into an [[ArrayType]] with the corresponding field type (primitive or
- * complex) as element type.
- * - If the nested group has name `map` and two fields (named `key` and `value`),
- * the surrounding group is converted into a [[MapType]]
- * with the corresponding key and value (value possibly complex) types.
- * Note that we currently assume map values are not nullable.
- * - Other group types are converted into a [[StructType]] with the corresponding
- * field types.
- *
- * Note that fields are determined to be `nullable` if and only if their Parquet repetition
- * level is not `REQUIRED`.
- *
- * @param parquetType The type to convert.
- * @return The corresponding Catalyst type.
- */
- def toDataType(parquetType: ParquetType,
- isBinaryAsString: Boolean,
- isInt96AsTimestamp: Boolean): DataType = {
- def correspondsToMap(groupType: ParquetGroupType): Boolean = {
- if (groupType.getFieldCount != 1 || groupType.getFields.apply(0).isPrimitive) {
- false
- } else {
- // This mostly follows the convention in ``parquet.schema.ConversionPatterns``
- val keyValueGroup = groupType.getFields.apply(0).asGroupType()
- keyValueGroup.getRepetition == Repetition.REPEATED &&
- keyValueGroup.getName == CatalystConverter.MAP_SCHEMA_NAME &&
- keyValueGroup.getFieldCount == 2 &&
- keyValueGroup.getFields.apply(0).getName == CatalystConverter.MAP_KEY_SCHEMA_NAME &&
- keyValueGroup.getFields.apply(1).getName == CatalystConverter.MAP_VALUE_SCHEMA_NAME
- }
- }
-
- def correspondsToArray(groupType: ParquetGroupType): Boolean = {
- groupType.getFieldCount == 1 &&
- groupType.getFieldName(0) == CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME &&
- groupType.getFields.apply(0).getRepetition == Repetition.REPEATED
- }
-
- if (parquetType.isPrimitive) {
- toPrimitiveDataType(parquetType.asPrimitiveType, isBinaryAsString, isInt96AsTimestamp)
- } else {
- val groupType = parquetType.asGroupType()
- parquetType.getOriginalType match {
- // if the schema was constructed programmatically there may be hints how to convert
- // it inside the metadata via the OriginalType field
- case ParquetOriginalType.LIST => { // TODO: check enums!
- assert(groupType.getFieldCount == 1)
- val field = groupType.getFields.apply(0)
- if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) {
- val bag = field.asGroupType()
- assert(bag.getFieldCount == 1)
- ArrayType(
- toDataType(bag.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp),
- containsNull = true)
- } else {
- ArrayType(
- toDataType(field, isBinaryAsString, isInt96AsTimestamp), containsNull = false)
- }
- }
- case ParquetOriginalType.MAP => {
- assert(
- !groupType.getFields.apply(0).isPrimitive,
- "Parquet Map type malformatted: expected nested group for map!")
- val keyValueGroup = groupType.getFields.apply(0).asGroupType()
- assert(
- keyValueGroup.getFieldCount == 2,
- "Parquet Map type malformatted: nested group should have 2 (key, value) fields!")
- assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED)
-
- val keyType =
- toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp)
- val valueType =
- toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString, isInt96AsTimestamp)
- MapType(keyType, valueType,
- keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED)
- }
- case _ => {
- // Note: the order of these checks is important!
- if (correspondsToMap(groupType)) { // MapType
- val keyValueGroup = groupType.getFields.apply(0).asGroupType()
- assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED)
-
- val keyType =
- toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp)
- val valueType =
- toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString, isInt96AsTimestamp)
- MapType(keyType, valueType,
- keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED)
- } else if (correspondsToArray(groupType)) { // ArrayType
- val field = groupType.getFields.apply(0)
- if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) {
- val bag = field.asGroupType()
- assert(bag.getFieldCount == 1)
- ArrayType(
- toDataType(bag.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp),
- containsNull = true)
- } else {
- ArrayType(
- toDataType(field, isBinaryAsString, isInt96AsTimestamp), containsNull = false)
- }
- } else { // everything else: StructType
- val fields = groupType
- .getFields
- .map(ptype => new StructField(
- ptype.getName,
- toDataType(ptype, isBinaryAsString, isInt96AsTimestamp),
- ptype.getRepetition != Repetition.REQUIRED))
- StructType(fields)
- }
- }
- }
- }
- }
-
- /**
- * For a given Catalyst [[org.apache.spark.sql.types.DataType]] return
- * the name of the corresponding Parquet primitive type or None if the given type
- * is not primitive.
- *
- * @param ctype The type to convert
- * @return The name of the corresponding Parquet type properties
- */
- def fromPrimitiveDataType(ctype: DataType): Option[ParquetTypeInfo] = ctype match {
- case StringType => Some(ParquetTypeInfo(
- ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8)))
- case BinaryType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BINARY))
- case BooleanType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BOOLEAN))
- case DoubleType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.DOUBLE))
- case FloatType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FLOAT))
- case IntegerType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32))
- // There is no type for Byte or Short so we promote them to INT32.
- case ShortType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32))
- case ByteType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32))
- case DateType => Some(ParquetTypeInfo(
- ParquetPrimitiveTypeName.INT32, Some(ParquetOriginalType.DATE)))
- case LongType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT64))
- case TimestampType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT96))
- case DecimalType.Fixed(precision, scale) if precision <= 18 =>
- // TODO: for now, our writer only supports decimals that fit in a Long
- Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY,
- Some(ParquetOriginalType.DECIMAL),
- Some(new DecimalMetadata(precision, scale)),
- Some(BYTES_FOR_PRECISION(precision))))
- case _ => None
- }
-
/**
* Compute the FIXED_LEN_BYTE_ARRAY length needed to represent a given DECIMAL precision.
*/
@@ -248,177 +53,29 @@ private[parquet] object ParquetTypesConverter extends Logging {
length
}
- /**
- * Converts a given Catalyst [[org.apache.spark.sql.types.DataType]] into
- * the corresponding Parquet `Type`.
- *
- * The conversion follows the rules below:
- *
- * - Primitive types are converted into Parquet's primitive types.
- * - [[org.apache.spark.sql.types.StructType]]s are converted
- * into Parquet's `GroupType` with the corresponding field types.
- * - [[org.apache.spark.sql.types.ArrayType]]s are converted
- * into a 2-level nested group, where the outer group has the inner
- * group as sole field. The inner group has name `values` and
- * repetition level `REPEATED` and has the element type of
- * the array as schema. We use Parquet's `ConversionPatterns` for this
- * purpose.
- * - [[org.apache.spark.sql.types.MapType]]s are converted
- * into a nested (2-level) Parquet `GroupType` with two fields: a key
- * type and a value type. The nested group has repetition level
- * `REPEATED` and name `map`. We use Parquet's `ConversionPatterns`
- * for this purpose
- *
- * Parquet's repetition level is generally set according to the following rule:
- *
- * - If the call to `fromDataType` is recursive inside an enclosing `ArrayType` or
- * `MapType`, then the repetition level is set to `REPEATED`.
- * - Otherwise, if the attribute whose type is converted is `nullable`, the Parquet
- * type gets repetition level `OPTIONAL` and otherwise `REQUIRED`.
- *
- *
- *@param ctype The type to convert
- * @param name The name of the [[org.apache.spark.sql.catalyst.expressions.Attribute]]
- * whose type is converted
- * @param nullable When true indicates that the attribute is nullable
- * @param inArray When true indicates that this is a nested attribute inside an array.
- * @return The corresponding Parquet type.
- */
- def fromDataType(
- ctype: DataType,
- name: String,
- nullable: Boolean = true,
- inArray: Boolean = false,
- toThriftSchemaNames: Boolean = false): ParquetType = {
- val repetition =
- if (inArray) {
- Repetition.REPEATED
- } else {
- if (nullable) Repetition.OPTIONAL else Repetition.REQUIRED
- }
- val arraySchemaName = if (toThriftSchemaNames) {
- name + CatalystConverter.THRIFT_ARRAY_ELEMENTS_SCHEMA_NAME_SUFFIX
- } else {
- CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME
- }
- val typeInfo = fromPrimitiveDataType(ctype)
- typeInfo.map {
- case ParquetTypeInfo(primitiveType, originalType, decimalMetadata, length) =>
- val builder = ParquetTypes.primitive(primitiveType, repetition).as(originalType.orNull)
- for (len <- length) {
- builder.length(len)
- }
- for (metadata <- decimalMetadata) {
- builder.precision(metadata.getPrecision).scale(metadata.getScale)
- }
- builder.named(name)
- }.getOrElse {
- ctype match {
- case udt: UserDefinedType[_] => {
- fromDataType(udt.sqlType, name, nullable, inArray, toThriftSchemaNames)
- }
- case ArrayType(elementType, false) => {
- val parquetElementType = fromDataType(
- elementType,
- arraySchemaName,
- nullable = false,
- inArray = true,
- toThriftSchemaNames)
- ConversionPatterns.listType(repetition, name, parquetElementType)
- }
- case ArrayType(elementType, true) => {
- val parquetElementType = fromDataType(
- elementType,
- arraySchemaName,
- nullable = true,
- inArray = false,
- toThriftSchemaNames)
- ConversionPatterns.listType(
- repetition,
- name,
- new ParquetGroupType(
- Repetition.REPEATED,
- CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME,
- parquetElementType))
- }
- case StructType(structFields) => {
- val fields = structFields.map {
- field => fromDataType(field.dataType, field.name, field.nullable,
- inArray = false, toThriftSchemaNames)
- }
- new ParquetGroupType(repetition, name, fields.toSeq)
- }
- case MapType(keyType, valueType, valueContainsNull) => {
- val parquetKeyType =
- fromDataType(
- keyType,
- CatalystConverter.MAP_KEY_SCHEMA_NAME,
- nullable = false,
- inArray = false,
- toThriftSchemaNames)
- val parquetValueType =
- fromDataType(
- valueType,
- CatalystConverter.MAP_VALUE_SCHEMA_NAME,
- nullable = valueContainsNull,
- inArray = false,
- toThriftSchemaNames)
- ConversionPatterns.mapType(
- repetition,
- name,
- parquetKeyType,
- parquetValueType)
- }
- case _ => throw new AnalysisException(s"Unsupported datatype $ctype")
- }
- }
- }
-
- def convertToAttributes(parquetSchema: ParquetType,
- isBinaryAsString: Boolean,
- isInt96AsTimestamp: Boolean): Seq[Attribute] = {
- parquetSchema
- .asGroupType()
- .getFields
- .map(
- field =>
- new AttributeReference(
- field.getName,
- toDataType(field, isBinaryAsString, isInt96AsTimestamp),
- field.getRepetition != Repetition.REQUIRED)())
+ def convertToAttributes(
+ parquetSchema: MessageType,
+ isBinaryAsString: Boolean,
+ isInt96AsTimestamp: Boolean): Seq[Attribute] = {
+ val converter = new CatalystSchemaConverter(
+ isBinaryAsString, isInt96AsTimestamp, followParquetFormatSpec = false)
+ converter.convert(parquetSchema).toAttributes
}
- def convertFromAttributes(attributes: Seq[Attribute],
- toThriftSchemaNames: Boolean = false): MessageType = {
- checkSpecialCharacters(attributes)
- val fields = attributes.map(
- attribute =>
- fromDataType(attribute.dataType, attribute.name, attribute.nullable,
- toThriftSchemaNames = toThriftSchemaNames))
- new MessageType("root", fields)
+ def convertFromAttributes(attributes: Seq[Attribute]): MessageType = {
+ val converter = new CatalystSchemaConverter()
+ converter.convert(StructType.fromAttributes(attributes))
}
def convertFromString(string: String): Seq[Attribute] = {
Try(DataType.fromJson(string)).getOrElse(DataType.fromCaseClassString(string)) match {
case s: StructType => s.toAttributes
- case other => throw new AnalysisException(s"Can convert $string to row")
- }
- }
-
- private def checkSpecialCharacters(schema: Seq[Attribute]) = {
- // ,;{}()\n\t= and space character are special characters in Parquet schema
- schema.map(_.name).foreach { name =>
- if (name.matches(".*[ ,;{}()\n\t=].*")) {
- throw new AnalysisException(
- s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=".
- |Please use alias to rename it.
- """.stripMargin.split("\n").mkString(" "))
- }
+ case other => sys.error(s"Cannot convert $string to row")
}
}
def convertToString(schema: Seq[Attribute]): String = {
- checkSpecialCharacters(schema)
+ schema.map(_.name).foreach(CatalystSchemaConverter.checkFieldName)
StructType.fromAttributes(schema).json
}
@@ -450,8 +107,7 @@ private[parquet] object ParquetTypesConverter extends Logging {
ParquetTypesConverter.convertToString(attributes))
// TODO: add extra data, e.g., table name, date, etc.?
- val parquetSchema: MessageType =
- ParquetTypesConverter.convertFromAttributes(attributes)
+ val parquetSchema: MessageType = ParquetTypesConverter.convertFromAttributes(attributes)
val metaData: FileMetaData = new FileMetaData(
parquetSchema,
extraMetadata,
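Because the converter recognizes parquet-avro and parquet-thrift list layouts directly, the old toThriftSchemaNames flag disappears from convertFromAttributes and the read path. A small illustration of what that buys (sketch only, not part of the patch; assumes it runs inside org.apache.spark.sql.parquet):

    import org.apache.parquet.schema.MessageTypeParser

    // A parquet-thrift style LIST: the repeated field is named "<parent>_tuple".
    val thriftStyle = MessageTypeParser.parseMessageType(
      """message root {
        |  optional group prices (LIST) {
        |    repeated int32 prices_tuple;
        |  }
        |}
      """.stripMargin)

    // The repeated field is primitive, so it is taken as the element type itself:
    // prices: ArrayType(IntegerType, containsNull = false), nullable = true
    val schema = new CatalystSchemaConverter().convert(thriftStyle)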
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index 1d353bd8e1114..bc39fae2bcfde 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -194,6 +194,12 @@ private[sql] class ParquetRelation2(
committerClass,
classOf[ParquetOutputCommitter])
+ // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override
+ // it in `ParquetOutputWriter` to support appending and dynamic partitioning. The reason we set
+ // it here is to set the output committer class to `ParquetOutputCommitter`, which is bundled
+ // with `ParquetOutputFormat[Row]`.
+ job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]])
+
// TODO There's no need to use two kinds of WriteSupport
// We should unify them. `SpecificMutableRow` can process both atomic (primitive) types and
// complex types.
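Why registering the output format matters is sketched below: the writer containers (see commands.scala next) obtain their fallback committer from whatever output format class is set on the job, so with ParquetOutputFormat[Row] registered the fallback is its bundled ParquetOutputCommitter. Illustrative only, not part of the patch:

    import org.apache.hadoop.mapreduce.{Job, OutputCommitter, TaskAttemptContext}

    // Roughly the fallback path in BaseWriterContainer.newOutputCommitter: instantiate the
    // job's output format and ask it for its committer.
    def defaultCommitterFor(job: Job, context: TaskAttemptContext): OutputCommitter =
      job.getOutputFormatClass.newInstance().getOutputCommitter(context)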
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
index 215e53c020849..fb6173f58ece6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -96,7 +96,8 @@ private[sql] case class InsertIntoHadoopFsRelation(
val fs = outputPath.getFileSystem(hadoopConf)
val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
- val doInsertion = (mode, fs.exists(qualifiedOutputPath)) match {
+ val pathExists = fs.exists(qualifiedOutputPath)
+ val doInsertion = (mode, pathExists) match {
case (SaveMode.ErrorIfExists, true) =>
sys.error(s"path $qualifiedOutputPath already exists.")
case (SaveMode.Overwrite, true) =>
@@ -107,6 +108,8 @@ private[sql] case class InsertIntoHadoopFsRelation(
case (SaveMode.Ignore, exists) =>
!exists
}
+ // Whether we are appending data to an existing directory.
+ val isAppend = pathExists && (mode == SaveMode.Append)
if (doInsertion) {
val job = new Job(hadoopConf)
@@ -130,10 +133,10 @@ private[sql] case class InsertIntoHadoopFsRelation(
val partitionColumns = relation.partitionColumns.fieldNames
if (partitionColumns.isEmpty) {
- insert(new DefaultWriterContainer(relation, job), df)
+ insert(new DefaultWriterContainer(relation, job, isAppend), df)
} else {
val writerContainer = new DynamicPartitionWriterContainer(
- relation, job, partitionColumns, PartitioningUtils.DEFAULT_PARTITION_NAME)
+ relation, job, partitionColumns, PartitioningUtils.DEFAULT_PARTITION_NAME, isAppend)
insertWithDynamicPartitions(sqlContext, writerContainer, df, partitionColumns)
}
}
@@ -277,7 +280,8 @@ private[sql] case class InsertIntoHadoopFsRelation(
private[sql] abstract class BaseWriterContainer(
@transient val relation: HadoopFsRelation,
- @transient job: Job)
+ @transient job: Job,
+ isAppend: Boolean)
extends SparkHadoopMapReduceUtil
with Logging
with Serializable {
@@ -356,34 +360,47 @@ private[sql] abstract class BaseWriterContainer(
}
private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = {
- val committerClass = context.getConfiguration.getClass(
- SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter])
-
- Option(committerClass).map { clazz =>
- logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}")
-
- // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat
- // has an associated output committer. To override this output committer,
- // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS.
- // If a data source needs to override the output committer, it needs to set the
- // output committer in prepareForWrite method.
- if (classOf[MapReduceFileOutputCommitter].isAssignableFrom(clazz)) {
- // The specified output committer is a FileOutputCommitter.
- // So, we will use the FileOutputCommitter-specified constructor.
- val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext])
- ctor.newInstance(new Path(outputPath), context)
- } else {
- // The specified output committer is just a OutputCommitter.
- // So, we will use the no-argument constructor.
- val ctor = clazz.getDeclaredConstructor()
- ctor.newInstance()
+ val defaultOutputCommitter = outputFormatClass.newInstance().getOutputCommitter(context)
+
+ if (isAppend) {
+ // If we are appending data to an existing dir, we will only use the output committer
+ // associated with the file output format since it is not safe to use a custom
+ // committer for appending. For example, in S3, direct parquet output committer may
+ // leave partial data in the destination dir when the appending job fails.
+ logInfo(
+ s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName} " +
+ "for appending.")
+ defaultOutputCommitter
+ } else {
+ val committerClass = context.getConfiguration.getClass(
+ SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter])
+
+ Option(committerClass).map { clazz =>
+ logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}")
+
+ // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat
+ // has an associated output committer. To override this output committer,
+ // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS.
+ // If a data source needs to override the output committer, it needs to set the
+ // output committer in prepareForWrite method.
+ if (classOf[MapReduceFileOutputCommitter].isAssignableFrom(clazz)) {
+ // The specified output committer is a FileOutputCommitter.
+ // So, we will use the FileOutputCommitter-specified constructor.
+ val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext])
+ ctor.newInstance(new Path(outputPath), context)
+ } else {
+ // The specified output committer is just a OutputCommitter.
+ // So, we will use the no-argument constructor.
+ val ctor = clazz.getDeclaredConstructor()
+ ctor.newInstance()
+ }
+ }.getOrElse {
+ // If output committer class is not set, we will use the one associated with the
+ // file output format.
+ logInfo(
+ s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName}")
+ defaultOutputCommitter
}
- }.getOrElse {
- // If output committer class is not set, we will use the one associated with the
- // file output format.
- val outputCommitter = outputFormatClass.newInstance().getOutputCommitter(context)
- logInfo(s"Using output committer class ${outputCommitter.getClass.getCanonicalName}")
- outputCommitter
}
}
@@ -433,8 +450,9 @@ private[sql] abstract class BaseWriterContainer(
private[sql] class DefaultWriterContainer(
@transient relation: HadoopFsRelation,
- @transient job: Job)
- extends BaseWriterContainer(relation, job) {
+ @transient job: Job,
+ isAppend: Boolean)
+ extends BaseWriterContainer(relation, job, isAppend) {
@transient private var writer: OutputWriter = _
@@ -473,8 +491,9 @@ private[sql] class DynamicPartitionWriterContainer(
@transient relation: HadoopFsRelation,
@transient job: Job,
partitionColumns: Array[String],
- defaultPartitionName: String)
- extends BaseWriterContainer(relation, job) {
+ defaultPartitionName: String,
+ isAppend: Boolean)
+ extends BaseWriterContainer(relation, job, isAppend) {
// All output writers are created on executor side.
@transient protected var outputWriters: mutable.Map[String, OutputWriter] = _
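Condensed, the committer selection introduced above behaves like the sketch below (an illustrative restatement, not part of the patch; the real logic lives in newOutputCommitter): appends always take the output format's own committer, otherwise a user-configured committer wins, with the format's committer as the final fallback.

    import org.apache.hadoop.fs.Path
    import org.apache.hadoop.mapreduce.{OutputCommitter, TaskAttemptContext}
    import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter}

    def selectCommitter(
        isAppend: Boolean,
        userClass: Class[_ <: OutputCommitter],  // null when SQLConf.OUTPUT_COMMITTER_CLASS is unset
        default: OutputCommitter,
        outputPath: String,
        context: TaskAttemptContext): OutputCommitter = {
      if (isAppend || userClass == null) {
        // Appending always uses the format's committer; a custom (e.g. direct) committer could
        // leave partial data behind if the append job fails.
        default
      } else if (classOf[MapReduceFileOutputCommitter].isAssignableFrom(userClass)) {
        // FileOutputCommitter subclasses are built with the (Path, TaskAttemptContext) constructor.
        userClass.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext])
          .newInstance(new Path(outputPath), context)
      } else {
        // Plain OutputCommitter implementations use the no-argument constructor.
        userClass.getDeclaredConstructor().newInstance()
      }
    }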
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 5854ab48db552..3dd24130af81a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -141,4 +141,10 @@ class PlannerSuite extends SparkFunSuite {
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
}
+
+ test("efficient limit -> project -> sort") {
+ val query = testData.sort('key).select('value).limit(2).logicalPlan
+ val planned = planner.TakeOrderedAndProject(query)
+ assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
index 47a7be1c6a664..7b16eba00d6fb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
@@ -99,7 +99,6 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
}
test("fixed-length decimals") {
-
def makeDecimalRDD(decimal: DecimalType): DataFrame =
sqlContext.sparkContext
.parallelize(0 to 1000)
@@ -158,6 +157,11 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
checkParquetFile(data)
}
+ test("array and double") {
+ val data = (1 to 4).map(i => (i.toDouble, Seq(i.toDouble, (i + 1).toDouble)))
+ checkParquetFile(data)
+ }
+
test("struct") {
val data = (1 to 4).map(i => Tuple1((i, s"val_$i")))
withParquetDataFrame(data) { df =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala
index 171a656f0e01e..d0bfcde7e032b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala
@@ -24,26 +24,109 @@ import org.apache.parquet.schema.MessageTypeParser
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types._
-class ParquetSchemaSuite extends SparkFunSuite with ParquetTest {
- lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
+abstract class ParquetSchemaTest extends SparkFunSuite with ParquetTest {
+ val sqlContext = TestSQLContext
/**
* Checks whether the reflected Parquet message type for product type `T` conforms `messageType`.
*/
- private def testSchema[T <: Product: ClassTag: TypeTag](
- testName: String, messageType: String, isThriftDerived: Boolean = false): Unit = {
- test(testName) {
- val actual = ParquetTypesConverter.convertFromAttributes(
- ScalaReflection.attributesFor[T], isThriftDerived)
- val expected = MessageTypeParser.parseMessageType(messageType)
+ protected def testSchemaInference[T <: Product: ClassTag: TypeTag](
+ testName: String,
+ messageType: String,
+ binaryAsString: Boolean = true,
+ int96AsTimestamp: Boolean = true,
+ followParquetFormatSpec: Boolean = false,
+ isThriftDerived: Boolean = false): Unit = {
+ testSchema(
+ testName,
+ StructType.fromAttributes(ScalaReflection.attributesFor[T]),
+ messageType,
+ binaryAsString,
+ int96AsTimestamp,
+ followParquetFormatSpec,
+ isThriftDerived)
+ }
+
+ protected def testParquetToCatalyst(
+ testName: String,
+ sqlSchema: StructType,
+ parquetSchema: String,
+ binaryAsString: Boolean = true,
+ int96AsTimestamp: Boolean = true,
+ followParquetFormatSpec: Boolean = false,
+ isThriftDerived: Boolean = false): Unit = {
+ val converter = new CatalystSchemaConverter(
+ assumeBinaryIsString = binaryAsString,
+ assumeInt96IsTimestamp = int96AsTimestamp,
+ followParquetFormatSpec = followParquetFormatSpec)
+
+ test(s"sql <= parquet: $testName") {
+ val actual = converter.convert(MessageTypeParser.parseMessageType(parquetSchema))
+ val expected = sqlSchema
+ assert(
+ actual === expected,
+ s"""Schema mismatch.
+ |Expected schema: ${expected.json}
+ |Actual schema: ${actual.json}
+ """.stripMargin)
+ }
+ }
+
+ protected def testCatalystToParquet(
+ testName: String,
+ sqlSchema: StructType,
+ parquetSchema: String,
+ binaryAsString: Boolean = true,
+ int96AsTimestamp: Boolean = true,
+ followParquetFormatSpec: Boolean = false,
+ isThriftDerived: Boolean = false): Unit = {
+ val converter = new CatalystSchemaConverter(
+ assumeBinaryIsString = binaryAsString,
+ assumeInt96IsTimestamp = int96AsTimestamp,
+ followParquetFormatSpec = followParquetFormatSpec)
+
+ test(s"sql => parquet: $testName") {
+ val actual = converter.convert(sqlSchema)
+ val expected = MessageTypeParser.parseMessageType(parquetSchema)
actual.checkContains(expected)
expected.checkContains(actual)
}
}
- testSchema[(Boolean, Int, Long, Float, Double, Array[Byte])](
+ protected def testSchema(
+ testName: String,
+ sqlSchema: StructType,
+ parquetSchema: String,
+ binaryAsString: Boolean = true,
+ int96AsTimestamp: Boolean = true,
+ followParquetFormatSpec: Boolean = false,
+ isThriftDerived: Boolean = false): Unit = {
+
+ testCatalystToParquet(
+ testName,
+ sqlSchema,
+ parquetSchema,
+ binaryAsString,
+ int96AsTimestamp,
+ followParquetFormatSpec,
+ isThriftDerived)
+
+ testParquetToCatalyst(
+ testName,
+ sqlSchema,
+ parquetSchema,
+ binaryAsString,
+ int96AsTimestamp,
+ followParquetFormatSpec,
+ isThriftDerived)
+ }
+}
+
+class ParquetSchemaInferenceSuite extends ParquetSchemaTest {
+ testSchemaInference[(Boolean, Int, Long, Float, Double, Array[Byte])](
"basic types",
"""
|message root {
@@ -54,9 +137,10 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest {
| required double _5;
| optional binary _6;
|}
- """.stripMargin)
+ """.stripMargin,
+ binaryAsString = false)
- testSchema[(Byte, Short, Int, Long, java.sql.Date)](
+ testSchemaInference[(Byte, Short, Int, Long, java.sql.Date)](
"logical integral types",
"""
|message root {
@@ -68,27 +152,79 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest {
|}
""".stripMargin)
- // Currently String is the only supported logical binary type.
- testSchema[Tuple1[String]](
- "binary logical types",
+ testSchemaInference[Tuple1[String]](
+ "string",
"""
|message root {
| optional binary _1 (UTF8);
|}
+ """.stripMargin,
+ binaryAsString = true)
+
+ testSchemaInference[Tuple1[Seq[Int]]](
+ "non-nullable array - non-standard",
+ """
+ |message root {
+ | optional group _1 (LIST) {
+ | repeated int32 element;
+ | }
+ |}
""".stripMargin)
- testSchema[Tuple1[Seq[Int]]](
- "array",
+ testSchemaInference[Tuple1[Seq[Int]]](
+ "non-nullable array - standard",
+ """
+ |message root {
+ | optional group _1 (LIST) {
+ | repeated group list {
+ | required int32 element;
+ | }
+ | }
+ |}
+ """.stripMargin,
+ followParquetFormatSpec = true)
+
+ testSchemaInference[Tuple1[Seq[Integer]]](
+ "nullable array - non-standard",
"""
|message root {
| optional group _1 (LIST) {
- | repeated int32 array;
+ | repeated group bag {
+ | optional int32 element;
+ | }
| }
|}
""".stripMargin)
- testSchema[Tuple1[Map[Int, String]]](
- "map",
+ testSchemaInference[Tuple1[Seq[Integer]]](
+ "nullable array - standard",
+ """
+ |message root {
+ | optional group _1 (LIST) {
+ | repeated group list {
+ | optional int32 element;
+ | }
+ | }
+ |}
+ """.stripMargin,
+ followParquetFormatSpec = true)
+
+ testSchemaInference[Tuple1[Map[Int, String]]](
+ "map - standard",
+ """
+ |message root {
+ | optional group _1 (MAP) {
+ | repeated group key_value {
+ | required int32 key;
+ | optional binary value (UTF8);
+ | }
+ | }
+ |}
+ """.stripMargin,
+ followParquetFormatSpec = true)
+
+ testSchemaInference[Tuple1[Map[Int, String]]](
+ "map - non-standard",
"""
|message root {
| optional group _1 (MAP) {
@@ -100,7 +236,7 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest {
|}
""".stripMargin)
- testSchema[Tuple1[Pair[Int, String]]](
+ testSchemaInference[Tuple1[Pair[Int, String]]](
"struct",
"""
|message root {
@@ -109,20 +245,21 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest {
| optional binary _2 (UTF8);
| }
|}
- """.stripMargin)
+ """.stripMargin,
+ followParquetFormatSpec = true)
- testSchema[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]](
- "deeply nested type",
+ testSchemaInference[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]](
+ "deeply nested type - non-standard",
"""
|message root {
- | optional group _1 (MAP) {
- | repeated group map (MAP_KEY_VALUE) {
+ | optional group _1 (MAP_KEY_VALUE) {
+ | repeated group map {
| required int32 key;
| optional group value {
| optional binary _1 (UTF8);
| optional group _2 (LIST) {
| repeated group bag {
- | optional group array {
+ | optional group element {
| required int32 _1;
| required double _2;
| }
@@ -134,43 +271,76 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest {
|}
""".stripMargin)
- testSchema[(Option[Int], Map[Int, Option[Double]])](
- "optional types",
+ testSchemaInference[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]](
+ "deeply nested type - standard",
"""
|message root {
- | optional int32 _1;
- | optional group _2 (MAP) {
- | repeated group map (MAP_KEY_VALUE) {
+ | optional group _1 (MAP) {
+ | repeated group key_value {
| required int32 key;
- | optional double value;
+ | optional group value {
+ | optional binary _1 (UTF8);
+ | optional group _2 (LIST) {
+ | repeated group list {
+ | optional group element {
+ | required int32 _1;
+ | required double _2;
+ | }
+ | }
+ | }
+ | }
| }
| }
|}
- """.stripMargin)
+ """.stripMargin,
+ followParquetFormatSpec = true)
- // Test for SPARK-4520 -- ensure that thrift generated parquet schema is generated
- // as expected from attributes
- testSchema[(Array[Byte], Array[Byte], Array[Byte], Seq[Int], Map[Array[Byte], Seq[Int]])](
- "thrift generated parquet schema",
+ testSchemaInference[(Option[Int], Map[Int, Option[Double]])](
+ "optional types",
"""
|message root {
- | optional binary _1 (UTF8);
- | optional binary _2 (UTF8);
- | optional binary _3 (UTF8);
- | optional group _4 (LIST) {
- | repeated int32 _4_tuple;
- | }
- | optional group _5 (MAP) {
- | repeated group map (MAP_KEY_VALUE) {
- | required binary key (UTF8);
- | optional group value (LIST) {
- | repeated int32 value_tuple;
- | }
+ | optional int32 _1;
+ | optional group _2 (MAP) {
+ | repeated group key_value {
+ | required int32 key;
+ | optional double value;
| }
| }
|}
- """.stripMargin, isThriftDerived = true)
+ """.stripMargin,
+ followParquetFormatSpec = true)
+ // Parquet files generated by parquet-thrift are already handled by the schema converter, but
+ // let's leave this test here until both the read path and the write path are updated.
+ ignore("thrift generated parquet schema") {
+ // Test for SPARK-4520 -- ensure that thrift generated parquet schema is generated
+ // as expected from attributes
+ testSchemaInference[(
+ Array[Byte], Array[Byte], Array[Byte], Seq[Int], Map[Array[Byte], Seq[Int]])](
+ "thrift generated parquet schema",
+ """
+ |message root {
+ | optional binary _1 (UTF8);
+ | optional binary _2 (UTF8);
+ | optional binary _3 (UTF8);
+ | optional group _4 (LIST) {
+ | repeated int32 _4_tuple;
+ | }
+ | optional group _5 (MAP) {
+ | repeated group map (MAP_KEY_VALUE) {
+ | required binary key (UTF8);
+ | optional group value (LIST) {
+ | repeated int32 value_tuple;
+ | }
+ | }
+ | }
+ |}
+ """.stripMargin,
+ isThriftDerived = true)
+ }
+}
+
+class ParquetSchemaSuite extends ParquetSchemaTest {
test("DataType string parser compatibility") {
// This is the generated string from previous versions of the Spark SQL, using the following:
// val schema = StructType(List(
@@ -180,10 +350,7 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest {
"StructType(List(StructField(c1,IntegerType,false), StructField(c2,BinaryType,true)))"
// scalastyle:off
- val jsonString =
- """
- |{"type":"struct","fields":[{"name":"c1","type":"integer","nullable":false,"metadata":{}},{"name":"c2","type":"binary","nullable":true,"metadata":{}}]}
- """.stripMargin
+ val jsonString = """{"type":"struct","fields":[{"name":"c1","type":"integer","nullable":false,"metadata":{}},{"name":"c2","type":"binary","nullable":true,"metadata":{}}]}"""
// scalastyle:on
val fromCaseClassString = ParquetTypesConverter.convertFromString(caseClassString)
@@ -277,4 +444,465 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest {
StructField("secondField", StringType, nullable = true))))
}.getMessage.contains("detected conflicting schemas"))
}
+
+ // =======================================================
+ // Tests for converting Parquet LIST to Catalyst ArrayType
+ // =======================================================
+
+ testParquetToCatalyst(
+ "Backwards-compatibility: LIST with nullable element type - 1 - standard",
+ StructType(Seq(
+ StructField(
+ "f1",
+ ArrayType(IntegerType, containsNull = true),
+ nullable = true))),
+ """message root {
+ | optional group f1 (LIST) {
+ | repeated group list {
+ | optional int32 element;
+ | }
+ | }
+ |}
+ """.stripMargin)
+
+ testParquetToCatalyst(
+ "Backwards-compatibility: LIST with nullable element type - 2",
+ StructType(Seq(
+ StructField(
+ "f1",
+ ArrayType(IntegerType, containsNull = true),
+ nullable = true))),
+ """message root {
+ | optional group f1 (LIST) {
+ | repeated group element {
+ | optional int32 num;
+ | }
+ | }
+ |}
+ """.stripMargin)
+
+ testParquetToCatalyst(
+ "Backwards-compatibility: LIST with non-nullable element type - 1 - standard",
+ StructType(Seq(
+ StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))),
+ """message root {
+ | optional group f1 (LIST) {
+ | repeated group list {
+ | required int32 element;
+ | }
+ | }
+ |}
+ """.stripMargin)
+
+ testParquetToCatalyst(
+ "Backwards-compatibility: LIST with non-nullable element type - 2",
+ StructType(Seq(
+ StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))),
+ """message root {
+ | optional group f1 (LIST) {
+ | repeated group element {
+ | required int32 num;
+ | }
+ | }
+ |}
+ """.stripMargin)
+
+ testParquetToCatalyst(
+ "Backwards-compatibility: LIST with non-nullable element type - 3",
+ StructType(Seq(
+ StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))),
+ """message root {
+ | optional group f1 (LIST) {
+ | repeated int32 element;
+ | }
+ |}
+ """.stripMargin)
+
+ testParquetToCatalyst(
+ "Backwards-compatibility: LIST with non-nullable element type - 4",
+ StructType(Seq(
+ StructField(
+ "f1",
+ ArrayType(
+ StructType(Seq(
+ StructField("str", StringType, nullable = false),
+ StructField("num", IntegerType, nullable = false))),
+ containsNull = false),
+ nullable = true))),
+ """message root {
+ | optional group f1 (LIST) {
+ | repeated group element {
+ | required binary str (UTF8);
+ | required int32 num;
+ | }
+ | }
+ |}
+ """.stripMargin)
+
+ testParquetToCatalyst(
+ "Backwards-compatibility: LIST with non-nullable element type - 5 - parquet-avro style",
+ StructType(Seq(
+ StructField(
+ "f1",
+ ArrayType(
+ StructType(Seq(
+ StructField("str", StringType, nullable = false))),
+ containsNull = false),
+ nullable = true))),
+ """message root {
+ | optional group f1 (LIST) {
+ | repeated group array {
+ | required binary str (UTF8);
+ | }
+ | }
+ |}
+ """.stripMargin)
+
+ testParquetToCatalyst(
+ "Backwards-compatibility: LIST with non-nullable element type - 6 - parquet-thrift style",
+ StructType(Seq(
+ StructField(
+ "f1",
+ ArrayType(
+ StructType(Seq(
+ StructField("str", StringType, nullable = false))),
+ containsNull = false),
+ nullable = true))),
+ """message root {
+ | optional group f1 (LIST) {
+ | repeated group f1_tuple {
+ | required binary str (UTF8);
+ | }
+ | }
+ |}
+ """.stripMargin)
+
+ // =======================================================
+ // Tests for converting Catalyst ArrayType to Parquet LIST
+ // =======================================================
+
+ testCatalystToParquet(
+ "Backwards-compatibility: LIST with nullable element type - 1 - standard",
+ StructType(Seq(
+ StructField(
+ "f1",
+ ArrayType(IntegerType, containsNull = true),
+ nullable = true))),
+ """message root {
+ | optional group f1 (LIST) {
+ | repeated group list {
+ | optional int32 element;
+ | }
+ | }
+ |}
+ """.stripMargin,
+ followParquetFormatSpec = true)
+
+ testCatalystToParquet(
+ "Backwards-compatibility: LIST with nullable element type - 2 - prior to 1.4.x",
+ StructType(Seq(
+ StructField(
+ "f1",
+ ArrayType(IntegerType, containsNull = true),
+ nullable = true))),
+ """message root {
+ | optional group f1 (LIST) {
+ | repeated group bag {
+ | optional int32 element;
+ | }
+ | }
+ |}
+ """.stripMargin)
+
+ testCatalystToParquet(
+ "Backwards-compatibility: LIST with non-nullable element type - 1 - standard",
+ StructType(Seq(
+ StructField(
+ "f1",
+ ArrayType(IntegerType, containsNull = false),
+ nullable = true))),
+ """message root {
+ | optional group f1 (LIST) {
+ | repeated group list {
+ | required int32 element;
+ | }
+ | }
+ |}
+ """.stripMargin,
+ followParquetFormatSpec = true)
+
+ testCatalystToParquet(
+ "Backwards-compatibility: LIST with non-nullable element type - 2 - prior to 1.4.x",
+ StructType(Seq(
+ StructField(
+ "f1",
+ ArrayType(IntegerType, containsNull = false),
+ nullable = true))),
+ """message root {
+ | optional group f1 (LIST) {
+ | repeated int32 element;
+ | }
+ |}
+ """.stripMargin)
+
+ // ====================================================
+ // Tests for converting Parquet Map to Catalyst MapType
+ // ====================================================
+
+ testParquetToCatalyst(
+ "Backwards-compatibility: MAP with non-nullable value type - 1 - standard",
+ StructType(Seq(
+ StructField(
+ "f1",
+ MapType(IntegerType, StringType, valueContainsNull = false),
+ nullable = true))),
+ """message root {
+ | optional group f1 (MAP) {
+ | repeated group key_value {
+ | required int32 key;
+ | required binary value (UTF8);
+ | }
+ | }
+ |}
+ """.stripMargin)
+
+ testParquetToCatalyst(
+ "Backwards-compatibility: MAP with non-nullable value type - 2",
+ StructType(Seq(
+ StructField(
+ "f1",
+ MapType(IntegerType, StringType, valueContainsNull = false),
+ nullable = true))),
+ """message root {
+ | optional group f1 (MAP_KEY_VALUE) {
+ | repeated group map {
+ | required int32 num;
+ | required binary str (UTF8);
+ | }
+ | }
+ |}
+ """.stripMargin)
+
+ testParquetToCatalyst(
+ "Backwards-compatibility: MAP with non-nullable value type - 3 - prior to 1.4.x",
+ StructType(Seq(
+ StructField(
+ "f1",
+ MapType(IntegerType, StringType, valueContainsNull = false),
+ nullable = true))),
+ """message root {
+ | optional group f1 (MAP) {
+ | repeated group map (MAP_KEY_VALUE) {
+ | required int32 key;
+ | required binary value (UTF8);
+ | }
+ | }
+ |}
+ """.stripMargin)
+
+ testParquetToCatalyst(
+ "Backwards-compatibility: MAP with nullable value type - 1 - standard",
+ StructType(Seq(
+ StructField(
+ "f1",
+ MapType(IntegerType, StringType, valueContainsNull = true),
+ nullable = true))),
+ """message root {
+ | optional group f1 (MAP) {
+ | repeated group key_value {
+ | required int32 key;
+ | optional binary value (UTF8);
+ | }
+ | }
+ |}
+ """.stripMargin)
+
+ testParquetToCatalyst(
+ "Backwards-compatibility: MAP with nullable value type - 2",
+ StructType(Seq(
+ StructField(
+ "f1",
+ MapType(IntegerType, StringType, valueContainsNull = true),
+ nullable = true))),
+ """message root {
+ | optional group f1 (MAP_KEY_VALUE) {
+ | repeated group map {
+ | required int32 num;
+ | optional binary str (UTF8);
+ | }
+ | }
+ |}
+ """.stripMargin)
+
+ testParquetToCatalyst(
+ "Backwards-compatibility: MAP with nullable value type - 3 - parquet-avro style",
+ StructType(Seq(
+ StructField(
+ "f1",
+ MapType(IntegerType, StringType, valueContainsNull = true),
+ nullable = true))),
+ """message root {
+ | optional group f1 (MAP) {
+ | repeated group map (MAP_KEY_VALUE) {
+ | required int32 key;
+ | optional binary value (UTF8);
+ | }
+ | }
+ |}
+ """.stripMargin)
+
+ // ====================================================
+ // Tests for converting Catalyst MapType to Parquet Map
+ // ====================================================
+
+ testCatalystToParquet(
+ "Backwards-compatibility: MAP with non-nullable value type - 1 - standard",
+ StructType(Seq(
+ StructField(
+ "f1",
+ MapType(IntegerType, StringType, valueContainsNull = false),
+ nullable = true))),
+ """message root {
+ | optional group f1 (MAP) {
+ | repeated group key_value {
+ | required int32 key;
+ | required binary value (UTF8);
+ | }
+ | }
+ |}
+ """.stripMargin,
+ followParquetFormatSpec = true)
+
+ testCatalystToParquet(
+ "Backwards-compatibility: MAP with non-nullable value type - 2 - prior to 1.4.x",
+ StructType(Seq(
+ StructField(
+ "f1",
+ MapType(IntegerType, StringType, valueContainsNull = false),
+ nullable = true))),
+ """message root {
+ | optional group f1 (MAP) {
+ | repeated group map (MAP_KEY_VALUE) {
+ | required int32 key;
+ | required binary value (UTF8);
+ | }
+ | }
+ |}
+ """.stripMargin)
+
+ testCatalystToParquet(
+ "Backwards-compatibility: MAP with nullable value type - 1 - standard",
+ StructType(Seq(
+ StructField(
+ "f1",
+ MapType(IntegerType, StringType, valueContainsNull = true),
+ nullable = true))),
+ """message root {
+ | optional group f1 (MAP) {
+ | repeated group key_value {
+ | required int32 key;
+ | optional binary value (UTF8);
+ | }
+ | }
+ |}
+ """.stripMargin,
+ followParquetFormatSpec = true)
+
+ testCatalystToParquet(
+ "Backwards-compatibility: MAP with nullable value type - 3 - prior to 1.4.x",
+ StructType(Seq(
+ StructField(
+ "f1",
+ MapType(IntegerType, StringType, valueContainsNull = true),
+ nullable = true))),
+ """message root {
+ | optional group f1 (MAP) {
+ | repeated group map (MAP_KEY_VALUE) {
+ | required int32 key;
+ | optional binary value (UTF8);
+ | }
+ | }
+ |}
+ """.stripMargin)
+
+ // =================================
+ // Tests for conversion for decimals
+ // =================================
+
+ testSchema(
+ "DECIMAL(1, 0) - standard",
+ StructType(Seq(StructField("f1", DecimalType(1, 0)))),
+ """message root {
+ | optional int32 f1 (DECIMAL(1, 0));
+ |}
+ """.stripMargin,
+ followParquetFormatSpec = true)
+
+ testSchema(
+ "DECIMAL(8, 3) - standard",
+ StructType(Seq(StructField("f1", DecimalType(8, 3)))),
+ """message root {
+ | optional int32 f1 (DECIMAL(8, 3));
+ |}
+ """.stripMargin,
+ followParquetFormatSpec = true)
+
+ testSchema(
+ "DECIMAL(9, 3) - standard",
+ StructType(Seq(StructField("f1", DecimalType(9, 3)))),
+ """message root {
+ | optional int32 f1 (DECIMAL(9, 3));
+ |}
+ """.stripMargin,
+ followParquetFormatSpec = true)
+
+ testSchema(
+ "DECIMAL(18, 3) - standard",
+ StructType(Seq(StructField("f1", DecimalType(18, 3)))),
+ """message root {
+ | optional int64 f1 (DECIMAL(18, 3));
+ |}
+ """.stripMargin,
+ followParquetFormatSpec = true)
+
+ testSchema(
+ "DECIMAL(19, 3) - standard",
+ StructType(Seq(StructField("f1", DecimalType(19, 3)))),
+ """message root {
+ | optional fixed_len_byte_array(9) f1 (DECIMAL(19, 3));
+ |}
+ """.stripMargin,
+ followParquetFormatSpec = true)
+
+ testSchema(
+ "DECIMAL(1, 0) - prior to 1.4.x",
+ StructType(Seq(StructField("f1", DecimalType(1, 0)))),
+ """message root {
+ | optional fixed_len_byte_array(1) f1 (DECIMAL(1, 0));
+ |}
+ """.stripMargin)
+
+ testSchema(
+ "DECIMAL(8, 3) - prior to 1.4.x",
+ StructType(Seq(StructField("f1", DecimalType(8, 3)))),
+ """message root {
+ | optional fixed_len_byte_array(4) f1 (DECIMAL(8, 3));
+ |}
+ """.stripMargin)
+
+ testSchema(
+ "DECIMAL(9, 3) - prior to 1.4.x",
+ StructType(Seq(StructField("f1", DecimalType(9, 3)))),
+ """message root {
+ | optional fixed_len_byte_array(5) f1 (DECIMAL(9, 3));
+ |}
+ """.stripMargin)
+
+ testSchema(
+ "DECIMAL(18, 3) - prior to 1.4.x",
+ StructType(Seq(StructField("f1", DecimalType(18, 3)))),
+ """message root {
+ | optional fixed_len_byte_array(8) f1 (DECIMAL(18, 3));
+ |}
+ """.stripMargin)
}
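The physical types expected by the decimal tests above follow from the byte-width arithmetic in CatalystSchemaConverter: n bytes of two's-complement hold precision p as long as 2^(8n-1) - 1 >= 10^p - 1. A standalone sanity check of the thresholds (sketch only; it mirrors maxPrecisionForBytes and minBytesForPrecision rather than calling them):

    object DecimalWidthCheck {
      // Largest decimal precision representable in `numBytes` signed bytes.
      def maxPrecision(numBytes: Int): Int =
        math.floor(math.log10(math.pow(2, 8 * numBytes - 1) - 1)).toInt

      // Smallest byte count that can hold the given precision.
      def minBytes(precision: Int): Int =
        Iterator.from(1).find(n => math.pow(2, 8 * n - 1) >= math.pow(10, precision)).get

      def main(args: Array[String]): Unit = {
        println(maxPrecision(4))  // 9  -> DECIMAL(9, 3)  fits in INT32 under the format spec
        println(maxPrecision(8))  // 18 -> DECIMAL(18, 3) fits in INT64 under the format spec
        println(minBytes(19))     // 9  -> DECIMAL(19, 3) needs fixed_len_byte_array(9)
      }
    }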
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index cf05c6c989655..8021f915bb821 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -442,7 +442,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
HiveCommandStrategy(self),
HiveDDLStrategy,
DDLStrategy,
- TakeOrdered,
+ TakeOrderedAndProject,
ParquetOperations,
InMemoryScans,
ParquetConversion, // Must be before HiveTableScans
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
index 705f48f1cd9f0..0fd7b3a91d6dd 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
@@ -27,7 +27,7 @@ import org.apache.hadoop.hive.ql.io.orc.{OrcInputFormat, OrcOutputFormat, OrcSer
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils
import org.apache.hadoop.io.{NullWritable, Writable}
-import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, RecordWriter, Reporter}
+import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter}
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
@@ -194,6 +194,16 @@ private[sql] class OrcRelation(
}
override def prepareJobForWrite(job: Job): OutputWriterFactory = {
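+ // Register OrcOutputFormat on the job configuration: set it directly when the configuration
+ // is a mapred JobConf, otherwise fall back to the "mapred.output.format.class" property.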
+ job.getConfiguration match {
+ case conf: JobConf =>
+ conf.setOutputFormat(classOf[OrcOutputFormat])
+ case conf =>
+ conf.setClass(
+ "mapred.output.format.class",
+ classOf[OrcOutputFormat],
+ classOf[MapRedOutputFormat[_, _]])
+ }
+
new OutputWriterFactory {
override def newInstance(
path: String,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
index ab443032be20d..b875e52b986ab 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.hive
import java.io.File
+import scala.sys.process.{ProcessLogger, Process}
+
import org.apache.spark._
import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext}
import org.apache.spark.util.{ResetSystemProperties, Utils}
@@ -82,12 +84,18 @@ class HiveSparkSubmitSuite
// This is copied from org.apache.spark.deploy.SparkSubmitSuite
private def runSparkSubmit(args: Seq[String]): Unit = {
val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
- val process = Utils.executeCommand(
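+ // Run spark-submit as a child process and echo its stdout/stderr with "out>"/"err>"
+ // prefixes so the output shows up in the unit test logs.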
+ val process = Process(
Seq("./bin/spark-submit") ++ args,
new File(sparkHome),
- Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome))
+ "SPARK_TESTING" -> "1",
+ "SPARK_HOME" -> sparkHome
+ ).run(ProcessLogger(
+ (line: String) => { println(s"out> $line") },
+ (line: String) => { println(s"err> $line") }
+ ))
+
try {
- val exitCode = failAfter(120 seconds) { process.waitFor() }
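+ // exitValue() blocks until the child process terminates; allow up to 180 seconds.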
+ val exitCode = failAfter(180 seconds) { process.exitValue() }
if (exitCode != 0) {
fail(s"Process returned with exit code $exitCode. See the log4j logs for more detail.")
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala
index f0f04f8c73fb4..197e9bfb02c4e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala
@@ -59,10 +59,4 @@ class HiveTypeCoercionSuite extends HiveComparisonTest {
}
assert(numEquals === 1)
}
-
- test("COALESCE with different types") {
- intercept[RuntimeException] {
- TestHive.sql("""SELECT COALESCE(1, true, "abc") FROM src limit 1""").collect()
- }
- }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index a2e666586c186..f0aad8dbbe64d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -638,7 +638,7 @@ class SQLQuerySuite extends QueryTest {
test("SPARK-5203 union with different decimal precision") {
Seq.empty[(Decimal, Decimal)]
.toDF("d1", "d2")
- .select($"d1".cast(DecimalType(10, 15)).as("d"))
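+ // A decimal's scale must not exceed its precision, so DecimalType(10, 15) was invalid.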
+ .select($"d1".cast(DecimalType(10, 5)).as("d"))
.registerTempTable("dn")
sql("select d from dn union all select d * 2 from dn")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 5d7cd16c129cd..e8141923a9b5c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -119,6 +119,8 @@ class SimpleTextRelation(
}
override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory {
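+ // Register TextOutputFormat as this job's output format class.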
+ job.setOutputFormatClass(classOf[TextOutputFormat[_, _]])
+
override def newInstance(
path: String,
dataSchema: StructType,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
index e0d8277a8ed3f..afecf9675e11f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
@@ -17,10 +17,16 @@
package org.apache.spark.sql.sources
+import scala.collection.JavaConversions._
+
import java.io.File
import com.google.common.io.Files
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter
+import org.apache.parquet.hadoop.ParquetOutputCommitter
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.deploy.SparkHadoopUtil
@@ -476,7 +482,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
// more cores, the issue can be reproduced steadily. Fortunately our Jenkins builder meets this
// requirement. We probably want to move this test case to spark-integration-tests or spark-perf
// later.
- test("SPARK-8406: Avoids name collision while writing Parquet files") {
+ test("SPARK-8406: Avoids name collision while writing files") {
withTempPath { dir =>
val path = dir.getCanonicalPath
sqlContext
@@ -497,6 +503,81 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
}
}
}
+
+ test("SPARK-8578 specified custom output committer will not be used to append data") {
+ val clonedConf = new Configuration(configuration)
+ try {
+ val df = sqlContext.range(1, 10).toDF("i")
+ withTempPath { dir =>
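+ // Write some data first so the target directory already exists when the failing committers are set.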
+ df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath)
+ configuration.set(
+ SQLConf.OUTPUT_COMMITTER_CLASS.key,
+ classOf[AlwaysFailOutputCommitter].getName)
+ // Since Parquet has its own output committer setting, also set it
+ // to AlwaysFailParquetOutputCommitter here.
+ configuration.set("spark.sql.parquet.output.committer.class",
+ classOf[AlwaysFailParquetOutputCommitter].getName)
+ // Because the data already exists, this append should succeed: the output committer
+ // associated with the file format is used, and AlwaysFailOutputCommitter is never invoked.
+ df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath)
+ checkAnswer(
+ sqlContext.read
+ .format(dataSourceName)
+ .option("dataSchema", df.schema.json)
+ .load(dir.getCanonicalPath),
+ df.unionAll(df))
+
+ // This overwrite will fail because AlwaysFailOutputCommitter is used for non-append writes.
+ intercept[Exception] {
+ df.write.mode("overwrite").format(dataSourceName).save(dir.getCanonicalPath)
+ }
+ }
+ withTempPath { dir =>
+ configuration.set(
+ SQLConf.OUTPUT_COMMITTER_CLASS.key,
+ classOf[AlwaysFailOutputCommitter].getName)
+ // Since Parquet has its own output committer setting, also set it
+ // to AlwaysFailParquetOutputCommitter here.
+ configuration.set("spark.sql.parquet.output.committer.class",
+ classOf[AlwaysFailParquetOutputCommitter].getName)
+ // Because there is no existing data, this append will fail: with nothing to append to,
+ // the custom AlwaysFailOutputCommitter is used and it fails the job commit.
+ intercept[Exception] {
+ df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath)
+ }
+ }
+ } finally {
+ // Hadoop 1 doesn't have `Configuration.unset`
+ configuration.clear()
+ clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue))
+ }
+ }
+}
+
+// This class is used to test SPARK-8578. We should not use any custom output committer when
+// we actually append data to an existing dir.
+class AlwaysFailOutputCommitter(
+ outputPath: Path,
+ context: TaskAttemptContext)
+ extends FileOutputCommitter(outputPath, context) {
+
+ override def commitJob(context: JobContext): Unit = {
+ sys.error("Intentional job commit failure for testing purposes.")
+ }
+}
+
+// This class is used to test SPARK-8578. We should not use any custom output committer when
+// we actually append data to an existing dir.
+class AlwaysFailParquetOutputCommitter(
+ outputPath: Path,
+ context: TaskAttemptContext)
+ extends ParquetOutputCommitter(outputPath, context) {
+
+ override def commitJob(context: JobContext): Unit = {
+ sys.error("Intentional job commit failure for testing purposes.")
+ }
}
class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
@@ -638,4 +719,25 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
}
}
}
+
+ test("SPARK-8604: Parquet data source should write summary file while doing appending") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ val df = sqlContext.range(0, 5)
+ df.write.mode(SaveMode.Overwrite).parquet(path)
+
+ val summaryPath = new Path(path, "_metadata")
+ val commonSummaryPath = new Path(path, "_common_metadata")
+
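+ // Delete the summary files written by the first job, then verify that appending regenerates them.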
+ val fs = summaryPath.getFileSystem(configuration)
+ fs.delete(summaryPath, true)
+ fs.delete(commonSummaryPath, true)
+
+ df.write.mode(SaveMode.Append).parquet(path)
+ checkAnswer(sqlContext.read.parquet(path), df.unionAll(df))
+
+ assert(fs.exists(summaryPath))
+ assert(fs.exists(commonSummaryPath))
+ }
+ }
}
diff --git a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js
index 75251f493ad22..4886b68eeaf76 100644
--- a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js
+++ b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js
@@ -31,6 +31,8 @@ var maxXForHistogram = 0;
var histogramBinCount = 10;
var yValueFormat = d3.format(",.2f");
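+// Vertical offset (in pixels) for positioning the unit labels on the axes.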
+var unitLabelYOffset = -10;
+
// Show a tooltip "text" for "node"
function showBootstrapTooltip(node, text) {
$(node).tooltip({title: text, trigger: "manual", container: "body"});
@@ -133,7 +135,7 @@ function drawTimeline(id, data, minX, maxX, minY, maxY, unitY, batchInterval) {
.attr("class", "y axis")
.call(yAxis)
.append("text")
- .attr("transform", "translate(0," + (-3) + ")")
+ .attr("transform", "translate(0," + unitLabelYOffset + ")")
.text(unitY);
@@ -223,10 +225,10 @@ function drawHistogram(id, values, minY, maxY, unitY, batchInterval) {
.style("border-left", "0px solid white");
var margin = {top: 20, right: 30, bottom: 30, left: 10};
- var width = 300 - margin.left - margin.right;
+ var width = 350 - margin.left - margin.right;
var height = 150 - margin.top - margin.bottom;
- var x = d3.scale.linear().domain([0, maxXForHistogram]).range([0, width]);
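+ // Reserve 50px on the right of the histogram for the "#batches" axis label.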
+ var x = d3.scale.linear().domain([0, maxXForHistogram]).range([0, width - 50]);
var y = d3.scale.linear().domain([minY, maxY]).range([height, 0]);
var xAxis = d3.svg.axis().scale(x).orient("top").ticks(5);
@@ -248,7 +250,7 @@ function drawHistogram(id, values, minY, maxY, unitY, batchInterval) {
.attr("class", "x axis")
.call(xAxis)
.append("text")
- .attr("transform", "translate(" + (margin.left + width - 40) + ", 15)")
+ .attr("transform", "translate(" + (margin.left + width - 45) + ", " + unitLabelYOffset + ")")
.text("#batches");
svg.append("g")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala
index 4ee7a486e370b..87af902428ec8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala
@@ -310,7 +310,7 @@ private[ui] class StreamingPage(parent: StreamingTab)
|
Timelines (Last {batchTimes.length} batches, {numActiveBatches} active, {numCompletedBatches} completed) |
- <th style="width: 300px;">Histograms</th>
+ <th style="width: 350px;">Histograms</th>
@@ -456,7 +456,7 @@ private[ui] class StreamingPage(parent: StreamingTab)
{receiverActive} |
{receiverLocation} |
{receiverLastErrorTime} |
- {receiverLastError} |
+ {receiverLastError} |
|