From 1bb5d716c0351cd0b4c11b397fd778f30db39bd9 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 3 Jun 2015 00:59:50 +0800 Subject: [PATCH 01/18] [SPARK-8037] [SQL] Ignores files whose name starts with dot in HadoopFsRelation Author: Cheng Lian Closes #6581 from liancheng/spark-8037 and squashes the following commits: d08e97b [Cheng Lian] Ignores files whose name starts with dot in HadoopFsRelation --- .../spark/sql/sources/PartitioningUtils.scala | 2 +- .../apache/spark/sql/sources/interfaces.scala | 11 +++++++---- .../ParquetPartitionDiscoverySuite.scala | 19 ++++++++++++++++++- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala index dafdf0f8b4564..c4c99de5a38dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala @@ -187,7 +187,7 @@ private[sql] object PartitioningUtils { Seq.empty } else { assert(distinctPartitionsColNames.size == 1, { - val list = distinctPartitionsColNames.mkString("\t", "\n", "") + val list = distinctPartitionsColNames.mkString("\t", "\n\t", "") s"Conflicting partition column names detected:\n$list" }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index b1b997c030a60..c4ffa8de52640 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -379,10 +379,10 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]] def refresh(): Unit = { - // We don't filter files/directories whose name start with "_" or "." here, as specific data - // sources may take advantages over them (e.g. Parquet _metadata and _common_metadata files). - // But "_temporary" directories are explicitly ignored since failed tasks/jobs may leave - // partial/corrupted data files there. + // We don't filter files/directories whose name start with "_" except "_temporary" here, as + // specific data sources may take advantages over them (e.g. Parquet _metadata and + // _common_metadata files). "_temporary" directories are explicitly ignored since failed + // tasks/jobs may leave partial/corrupted data files there. def listLeafFilesAndDirs(fs: FileSystem, status: FileStatus): Set[FileStatus] = { if (status.getPath.getName.toLowerCase == "_temporary") { Set.empty @@ -400,6 +400,9 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) Try(fs.getFileStatus(qualified)).toOption.toArray.flatMap(listLeafFilesAndDirs(fs, _)) + }.filterNot { status => + // SPARK-8037: Ignores files like ".DS_Store" and other hidden files/directories + status.getPath.getName.startsWith(".") } val files = statuses.filterNot(_.isDir) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index f231589e9674d..3b29979452ad9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql.parquet import java.io.File import java.math.BigInteger -import java.sql.{Timestamp, Date} +import java.sql.Timestamp import scala.collection.mutable.ArrayBuffer +import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.apache.spark.sql.catalyst.expressions.Literal @@ -432,4 +433,20 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { checkAnswer(read.load(dir.toString).select(fields: _*), row) } } + + test("SPARK-8037: Ignores files whose name starts with dot") { + withTempPath { dir => + val df = (1 to 3).map(i => (i, i, i, i)).toDF("a", "b", "c", "d") + + df.write + .format("parquet") + .partitionBy("b", "c", "d") + .save(dir.getCanonicalPath) + + Files.touch(new File(s"${dir.getCanonicalPath}/b=1", ".DS_Store")) + Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) + + checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df) + } + } } From 0071bd8d31f13abfe73b9d141a818412d374dce0 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 2 Jun 2015 11:20:33 -0700 Subject: [PATCH 02/18] [SPARK-8015] [FLUME] Remove Guava dependency from flume-sink. The minimal change would be to disable shading of Guava in the module, and rely on the transitive dependency from other libraries instead. But since Guava's use is so localized, I think it's better to just not use it instead, so I replaced that code and removed all traces of Guava from the module's build. Author: Marcelo Vanzin Closes #6555 from vanzin/SPARK-8015 and squashes the following commits: c0ceea8 [Marcelo Vanzin] Add comments about dependency management. c38228d [Marcelo Vanzin] Add guava dep in test scope. b7a0349 [Marcelo Vanzin] Add libthrift exclusion. 6e0942d [Marcelo Vanzin] Add comment in pom. 2d79260 [Marcelo Vanzin] [SPARK-8015] [flume] Remove Guava dependency from flume-sink. --- external/flume-sink/pom.xml | 39 +++++++++++++++++++ .../flume/sink/SparkAvroCallbackHandler.scala | 4 +- .../flume/sink/SparkSinkThreadFactory.scala | 35 +++++++++++++++++ .../streaming/flume/sink/SparkSinkSuite.scala | 6 +-- 4 files changed, 77 insertions(+), 7 deletions(-) create mode 100644 external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 1f3e619d97a24..71f2b6fe18bd1 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -42,15 +42,46 @@ org.apache.flume flume-ng-sdk + + + + com.google.guava + guava + + + + org.apache.thrift + libthrift + + org.apache.flume flume-ng-core + + + com.google.guava + guava + + + org.apache.thrift + libthrift + + org.scala-lang scala-library + + + com.google.guava + guava + test + + + + diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala index fd01807fc3ac4..dc2a4ab138e18 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -21,7 +21,6 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.flume.Channel import org.apache.commons.lang3.RandomStringUtils @@ -45,8 +44,7 @@ import org.apache.commons.lang3.RandomStringUtils private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Channel, val transactionTimeout: Int, val backOffInterval: Int) extends SparkFlumeProtocol with Logging { val transactionExecutorOpt = Option(Executors.newFixedThreadPool(threads, - new ThreadFactoryBuilder().setDaemon(true) - .setNameFormat("Spark Sink Processor Thread - %d").build())) + new SparkSinkThreadFactory("Spark Sink Processor Thread - %d"))) // Protected by `sequenceNumberToProcessor` private val sequenceNumberToProcessor = mutable.HashMap[CharSequence, TransactionProcessor]() // This sink will not persist sequence numbers and reuses them if it gets restarted. diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala new file mode 100644 index 0000000000000..845fc8debda75 --- /dev/null +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.flume.sink + +import java.util.concurrent.ThreadFactory +import java.util.concurrent.atomic.AtomicLong + +/** + * Thread factory that generates daemon threads with a specified name format. + */ +private[sink] class SparkSinkThreadFactory(nameFormat: String) extends ThreadFactory { + + private val threadId = new AtomicLong() + + override def newThread(r: Runnable): Thread = { + val t = new Thread(r, nameFormat.format(threadId.incrementAndGet())) + t.setDaemon(true) + t + } + +} diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index 605b3fe71017f..fa43629d49771 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConversions._ import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} -import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.avro.ipc.NettyTransceiver import org.apache.avro.ipc.specific.SpecificRequestor import org.apache.flume.Context @@ -194,9 +193,8 @@ class SparkSinkSuite extends FunSuite { count: Int): Seq[(NettyTransceiver, SparkFlumeProtocol.Callback)] = { (1 to count).map(_ => { - lazy val channelFactoryExecutor = - Executors.newCachedThreadPool(new ThreadFactoryBuilder().setDaemon(true). - setNameFormat("Flume Receiver Channel Thread - %d").build()) + lazy val channelFactoryExecutor = Executors.newCachedThreadPool( + new SparkSinkThreadFactory("Flume Receiver Channel Thread - %d")) lazy val channelFactory = new NioClientSocketChannelFactory(channelFactoryExecutor, channelFactoryExecutor) val transceiver = new NettyTransceiver(address, channelFactory) From ad06727fe985ca243ebdaaba55cd7d35a4749d0a Mon Sep 17 00:00:00 2001 From: Mike Dusenberry Date: Tue, 2 Jun 2015 12:38:14 -0700 Subject: [PATCH 03/18] [SPARK-7985] [ML] [MLlib] [Docs] Remove "fittingParamMap" references. Updating ML Doc "Estimator, Transformer, and Param" examples. Updating ML Doc's *"Estimator, Transformer, and Param"* example to use `model.extractParamMap` instead of `model.fittingParamMap`, which no longer exists. mengxr, I believe this addresses (part of) the *update documentation* TODO list item from [PR 5820](https://github.com/apache/spark/pull/5820). Author: Mike Dusenberry Closes #6514 from dusenberrymw/Fix_ML_Doc_Estimator_Transformer_Param_Example and squashes the following commits: 6366e1f [Mike Dusenberry] Updating instances of model.extractParamMap to model.parent.extractParamMap, since the Params of the parent Estimator could possibly differ from thos of the Model. d850e0e [Mike Dusenberry] Removing all references to "fittingParamMap" throughout Spark, since it has been removed. 0480304 [Mike Dusenberry] Updating the ML Doc "Estimator, Transformer, and Param" Java example to use model.extractParamMap() instead of model.fittingParamMap(), which no longer exists. 7d34939 [Mike Dusenberry] Updating ML Doc "Estimator, Transformer, and Param" example to use model.extractParamMap instead of model.fittingParamMap, which no longer exists. --- docs/ml-guide.md | 8 ++++---- .../apache/spark/ml/classification/GBTClassifier.scala | 2 +- .../spark/ml/classification/RandomForestClassifier.scala | 2 +- .../org/apache/spark/ml/regression/GBTRegressor.scala | 2 +- .../spark/ml/regression/RandomForestRegressor.scala | 2 +- .../ml/classification/DecisionTreeClassifierSuite.scala | 2 +- .../spark/ml/classification/GBTClassifierSuite.scala | 2 +- .../ml/classification/RandomForestClassifierSuite.scala | 2 +- .../spark/ml/regression/DecisionTreeRegressorSuite.scala | 2 +- .../apache/spark/ml/regression/GBTRegressorSuite.scala | 2 +- .../spark/ml/regression/RandomForestRegressorSuite.scala | 2 +- 11 files changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index c5f50ed7990f1..4eb622d4b95e8 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -207,7 +207,7 @@ val model1 = lr.fit(training.toDF) // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this // LogisticRegression instance. -println("Model 1 was fit using parameters: " + model1.fittingParamMap) +println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) // We may alternatively specify parameters using a ParamMap, // which supports several methods for specifying parameters. @@ -222,7 +222,7 @@ val paramMapCombined = paramMap ++ paramMap2 // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. val model2 = lr.fit(training.toDF, paramMapCombined) -println("Model 2 was fit using parameters: " + model2.fittingParamMap) +println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) // Prepare test data. val test = sc.parallelize(Seq( @@ -289,7 +289,7 @@ LogisticRegressionModel model1 = lr.fit(training); // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this // LogisticRegression instance. -System.out.println("Model 1 was fit using parameters: " + model1.fittingParamMap()); +System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap()); // We may alternatively specify parameters using a ParamMap. ParamMap paramMap = new ParamMap(); @@ -305,7 +305,7 @@ ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. LogisticRegressionModel model2 = lr.fit(training, paramMapCombined); -System.out.println("Model 2 was fit using parameters: " + model2.fittingParamMap()); +System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap()); // Prepare test documents. List localTest = Lists.newArrayList( diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index d8592eb2d947d..62f4b51f770e9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -208,7 +208,7 @@ private[ml] object GBTClassificationModel { require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" + s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc") diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 67600ebd7b38e..852a67e066322 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -170,7 +170,7 @@ private[ml] object RandomForestClassificationModel { require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 69f4f5414c8c6..b7e374bb6cb49 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -198,7 +198,7 @@ private[ml] object GBTRegressionModel { require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" + s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index ae767a17329d2..49a1f7ce8c995 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -152,7 +152,7 @@ private[ml] object RandomForestRegressionModel { require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } new RandomForestRegressionModel(parent.uid, newTrees) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 40554f6ef94a8..ae40b0b8ff854 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -265,7 +265,7 @@ private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { val oldTree = OldDecisionTree.train(data, oldStrategy) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) val newTree = dt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldTreeAsNew = DecisionTreeClassificationModel.fromOld( oldTree, newTree.parent.asInstanceOf[DecisionTreeClassifier], categoricalFeatures) TreeTests.checkEqual(oldTreeAsNew, newTree) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 09327051621e0..1302da3c373ff 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -127,7 +127,7 @@ private object GBTClassifierSuite { val oldModel = oldGBT.run(data) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) val newModel = gbt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldModelAsNew = GBTClassificationModel.fromOld( oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index f699d0c374d2f..eee9355a67be3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -157,7 +157,7 @@ private object RandomForestClassifierSuite { data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) val newModel = rf.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldModelAsNew = RandomForestClassificationModel.fromOld( oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 1182b89a8e3aa..33aa9d0d62343 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -82,7 +82,7 @@ private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { val oldTree = OldDecisionTree.train(data, oldStrategy) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newTree = dt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldTreeAsNew = DecisionTreeRegressionModel.fromOld( oldTree, newTree.parent.asInstanceOf[DecisionTreeRegressor], categoricalFeatures) TreeTests.checkEqual(oldTreeAsNew, newTree) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index f8a1469fee313..98fb3d3f5f22c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -128,7 +128,7 @@ private object GBTRegressorSuite { val oldModel = oldGBT.run(data) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newModel = gbt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldModelAsNew = GBTRegressionModel.fromOld( oldModel, newModel.parent.asInstanceOf[GBTRegressor], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 78911560945a2..b24ecaa57c89b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -113,7 +113,7 @@ private object RandomForestRegressorSuite extends SparkFunSuite { data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newModel = rf.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldModelAsNew = RandomForestRegressionModel.fromOld( oldModel, newModel.parent.asInstanceOf[RandomForestRegressor], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) From 686a45f0b9c50ede2a80854ed6a155ee8a9a4f5c Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 2 Jun 2015 13:32:13 -0700 Subject: [PATCH 04/18] [SPARK-8014] [SQL] Avoid premature metadata discovery when writing a HadoopFsRelation with a save mode other than Append The current code references the schema of the DataFrame to be written before checking save mode. This triggers expensive metadata discovery prematurely. For save mode other than `Append`, this metadata discovery is useless since we either ignore the result (for `Ignore` and `ErrorIfExists`) or delete existing files (for `Overwrite`) later. This PR fixes this issue by deferring metadata discovery after save mode checking. Author: Cheng Lian Closes #6583 from liancheng/spark-8014 and squashes the following commits: 1aafabd [Cheng Lian] Updates comments 088abaa [Cheng Lian] Avoids schema merging and partition discovery when data schema and partition schema are defined 8fbd93f [Cheng Lian] Fixes SPARK-8014 --- .../apache/spark/sql/parquet/newParquet.scala | 2 +- .../apache/spark/sql/sources/commands.scala | 20 +++++-- .../org/apache/spark/sql/sources/ddl.scala | 16 ++--- .../apache/spark/sql/sources/interfaces.scala | 2 +- .../sql/sources/hadoopFsRelationSuites.scala | 59 ++++++++++++++----- 5 files changed, 67 insertions(+), 32 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index e439a18ac43aa..824ae36968c32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -190,7 +190,7 @@ private[sql] class ParquetRelation2( } } - override def dataSchema: StructType = metadataCache.dataSchema + override def dataSchema: StructType = maybeDataSchema.getOrElse(metadataCache.dataSchema) override private[sql] def refresh(): Unit = { super.refresh() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 3132067d562f6..71f016b1f14de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -30,9 +30,10 @@ import org.apache.spark._ import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext, SaveMode} @@ -94,10 +95,19 @@ private[sql] case class InsertIntoHadoopFsRelation( // We create a DataFrame by applying the schema of relation to the data to make sure. // We are writing data based on the expected schema, - val df = sqlContext.createDataFrame( - DataFrame(sqlContext, query).queryExecution.toRdd, - relation.schema, - needsConversion = false) + val df = { + // For partitioned relation r, r.schema's column ordering can be different from the column + // ordering of data.logicalPlan (partition columns are all moved after data column). We + // need a Project to adjust the ordering, so that inside InsertIntoHadoopFsRelation, we can + // safely apply the schema of r.schema to the data. + val project = Project( + relation.schema.map(field => new UnresolvedAttribute(Seq(field.name))), query) + + sqlContext.createDataFrame( + DataFrame(sqlContext, project).queryExecution.toRdd, + relation.schema, + needsConversion = false) + } val partitionColumns = relation.partitionColumns.fieldNames if (partitionColumns.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 22587f5a1c6f1..20afd60cb7767 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.catalyst.AbstractSparkSQLParser -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.RunnableCommand @@ -322,19 +322,13 @@ private[sql] object ResolvedDataSource { Some(partitionColumnsSchema(data.schema, partitionColumns)), caseInsensitiveOptions) - // For partitioned relation r, r.schema's column ordering is different with the column - // ordering of data.logicalPlan. We need a Project to adjust the ordering. - // So, inside InsertIntoHadoopFsRelation, we can safely apply the schema of r.schema to - // the data. - val project = - Project( - r.schema.map(field => new UnresolvedAttribute(Seq(field.name))), - data.logicalPlan) - + // For partitioned relation r, r.schema's column ordering can be different from the column + // ordering of data.logicalPlan (partition columns are all moved after data column). This + // will be adjusted within InsertIntoHadoopFsRelation. sqlContext.executePlan( InsertIntoHadoopFsRelation( r, - project, + data.logicalPlan, mode)).toRdd r case _ => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index c4ffa8de52640..f5bd2d2941ca0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -503,7 +503,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio */ override lazy val schema: StructType = { val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet - StructType(dataSchema ++ partitionSpec.partitionColumns.filterNot { column => + StructType(dataSchema ++ partitionColumns.filterNot { column => dataSchemaColumnNames.contains(column.name.toLowerCase) }) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index af36fa6f1faae..74095426741e3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.sources +import java.io.File + +import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.apache.spark.{SparkException, SparkFunSuite} @@ -453,6 +456,20 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } } } + + test("SPARK-7616: adjust column name order accordingly when saving partitioned table") { + val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c") + + df.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("c", "a") + .saveAsTable("t") + + withTable("t") { + checkAnswer(table("t"), df.select('b, 'c, 'a).collect()) + } + } } class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { @@ -534,20 +551,6 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { } } - test("SPARK-7616: adjust column name order accordingly when saving partitioned table") { - val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c") - - df.write - .format("parquet") - .mode(SaveMode.Overwrite) - .partitionBy("c", "a") - .saveAsTable("t") - - withTable("t") { - checkAnswer(table("t"), df.select('b, 'c, 'a).collect()) - } - } - test("SPARK-7868: _temporary directories should be ignored") { withTempPath { dir => val df = Seq("a", "b", "c").zipWithIndex.toDF() @@ -563,4 +566,32 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect()) } } + + test("SPARK-8014: Avoid scanning output directory when SaveMode isn't SaveMode.Append") { + withTempDir { dir => + val path = dir.getCanonicalPath + val df = Seq(1 -> "a").toDF() + + // Creates an arbitrary file. If this directory gets scanned, ParquetRelation2 will throw + // since it's not a valid Parquet file. + val emptyFile = new File(path, "empty") + Files.createParentDirs(emptyFile) + Files.touch(emptyFile) + + // This shouldn't throw anything. + df.write.format("parquet").mode(SaveMode.Ignore).save(path) + + // This should only complain that the destination directory already exists, rather than file + // "empty" is not a Parquet file. + assert { + intercept[RuntimeException] { + df.write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) + }.getMessage.contains("already exists") + } + + // This shouldn't throw anything. + df.write.format("parquet").mode(SaveMode.Overwrite).save(path) + checkAnswer(read.format("parquet").load(path), df) + } + } } From 605ddbb27c8482fc0107b21c19d4e4ae19348f35 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 2 Jun 2015 13:38:06 -0700 Subject: [PATCH 05/18] [SPARK-8038] [SQL] [PYSPARK] fix Column.when() and otherwise() Thanks ogirardot, closes #6580 cc rxin JoshRosen Author: Davies Liu Closes #6590 from davies/when and squashes the following commits: c0f2069 [Davies Liu] fix Column.when() and otherwise() --- python/pyspark/sql/column.py | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 8dc5039f587f0..1ecec5b126505 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -315,6 +315,14 @@ def between(self, lowerBound, upperBound): """ A boolean expression that is evaluated to true if the value of this expression is between the given columns. + + >>> df.select(df.name, df.age.between(2, 4)).show() + +-----+--------------------------+ + | name|((age >= 2) && (age <= 4))| + +-----+--------------------------+ + |Alice| true| + | Bob| false| + +-----+--------------------------+ """ return (self >= lowerBound) & (self <= upperBound) @@ -328,12 +336,20 @@ def when(self, condition, value): :param condition: a boolean :class:`Column` expression. :param value: a literal value, or a :class:`Column` expression. + + >>> from pyspark.sql import functions as F + >>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show() + +-----+--------------------------------------------------------+ + | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0| + +-----+--------------------------------------------------------+ + |Alice| -1| + | Bob| 1| + +-----+--------------------------------------------------------+ """ - sc = SparkContext._active_spark_context if not isinstance(condition, Column): raise TypeError("condition should be a Column") v = value._jc if isinstance(value, Column) else value - jc = sc._jvm.functions.when(condition._jc, v) + jc = self._jc.when(condition._jc, v) return Column(jc) @since(1.4) @@ -345,9 +361,18 @@ def otherwise(self, value): See :func:`pyspark.sql.functions.when` for example usage. :param value: a literal value, or a :class:`Column` expression. + + >>> from pyspark.sql import functions as F + >>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show() + +-----+---------------------------------+ + | name|CASE WHEN (age > 3) THEN 1 ELSE 0| + +-----+---------------------------------+ + |Alice| 0| + | Bob| 1| + +-----+---------------------------------+ """ v = value._jc if isinstance(value, Column) else value - jc = self._jc.otherwise(value) + jc = self._jc.otherwise(v) return Column(jc) @since(1.4) From 89f21f66b5549524d1a6e4fb576a4f80d9fef903 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 2 Jun 2015 16:51:17 -0700 Subject: [PATCH 06/18] [SPARK-8049] [MLLIB] drop tmp col from OneVsRest output The temporary column should be dropped after we get the prediction column. harsha2010 Author: Xiangrui Meng Closes #6592 from mengxr/SPARK-8049 and squashes the following commits: 1d89107 [Xiangrui Meng] use SparkFunSuite 6ee70de [Xiangrui Meng] drop tmp col from OneVsRest output --- .../org/apache/spark/ml/classification/OneVsRest.scala | 1 + .../apache/spark/ml/classification/OneVsRestSuite.scala | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 7b726da388075..825f9ed1b54b2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -131,6 +131,7 @@ final class OneVsRestModel private[ml] ( // output label and label metadata as prediction val labelUdf = callUDF(label, DoubleType, col(accColName)) aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata)) + .drop(accColName) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index f439f3261f06f..1d04ccb509057 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -93,6 +93,15 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features) ova.fit(datasetWithLabelMetadata) } + + test("SPARK-8049: OneVsRest shouldn't output temp columns") { + val logReg = new LogisticRegression() + .setMaxIter(1) + val ovr = new OneVsRest() + .setClassifier(logReg) + val output = ovr.fit(dataset).transform(dataset) + assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction")) + } } private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) { From 5cd6a63d9692d153751747e0293dc030d73a6194 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 2 Jun 2015 17:07:13 -0700 Subject: [PATCH 07/18] [SQL] [TEST] [MINOR] Follow-up of PR #6493, use Guava API to ensure Java 6 friendliness This is a follow-up of PR #6493, which has been reverted in branch-1.4 because it uses Java 7 specific APIs and breaks Java 6 build. This PR replaces those APIs with equivalent Guava ones to ensure Java 6 friendliness. cc andrewor14 pwendell, this should also be back ported to branch-1.4. Author: Cheng Lian Closes #6547 from liancheng/override-log4j and squashes the following commits: c900cfd [Cheng Lian] Addresses Shixiong's comment 72da795 [Cheng Lian] Uses Guava API to ensure Java 6 friendliness --- .../sql/hive/thriftserver/HiveThriftServer2Suites.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index da511ebd05ad2..a93a3dee43511 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.hive.thriftserver import java.io.File import java.net.URL -import java.nio.charset.StandardCharsets -import java.nio.file.{Files, Paths} import java.sql.{Date, DriverManager, Statement} import scala.collection.mutable.ArrayBuffer @@ -29,6 +27,8 @@ import scala.concurrent.{Await, Promise} import scala.sys.process.{Process, ProcessLogger} import scala.util.{Random, Try} +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.Files import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.jdbc.HiveDriver import org.apache.hive.service.auth.PlainSaslHelper @@ -441,13 +441,14 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl val tempLog4jConf = Utils.createTempDir().getCanonicalPath Files.write( - Paths.get(s"$tempLog4jConf/log4j.properties"), """log4j.rootCategory=INFO, console |log4j.appender.console=org.apache.log4j.ConsoleAppender |log4j.appender.console.target=System.err |log4j.appender.console.layout=org.apache.log4j.PatternLayout |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n - """.stripMargin.getBytes(StandardCharsets.UTF_8)) + """.stripMargin, + new File(s"$tempLog4jConf/log4j.properties"), + UTF_8) tempLog4jConf + File.pathSeparator + sys.props("java.class.path") } From c3f4c3257194ba34ccd298d13ea1edcfc75f7552 Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Tue, 2 Jun 2015 18:53:04 -0700 Subject: [PATCH 08/18] [SPARK-7387] [ML] [DOC] CrossValidator example code in Python Author: Ram Sriharsha Closes #6358 from harsha2010/SPARK-7387 and squashes the following commits: 63efda2 [Ram Sriharsha] more examples for classifier to distinguish mapreduce from spark properly aeb6bb6 [Ram Sriharsha] Python Style Fix 54a500c [Ram Sriharsha] Merge branch 'master' into SPARK-7387 615e91c [Ram Sriharsha] cleanup 204c4e3 [Ram Sriharsha] Merge branch 'master' into SPARK-7387 7246d35 [Ram Sriharsha] [SPARK-7387][ml][doc] CrossValidator example code in Python --- .../src/main/python/ml/cross_validator.py | 96 +++++++++++++++++++ .../main/python/ml/simple_params_example.py | 4 +- 2 files changed, 98 insertions(+), 2 deletions(-) create mode 100644 examples/src/main/python/ml/cross_validator.py diff --git a/examples/src/main/python/ml/cross_validator.py b/examples/src/main/python/ml/cross_validator.py new file mode 100644 index 0000000000000..f0ca97c724940 --- /dev/null +++ b/examples/src/main/python/ml/cross_validator.py @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.ml import Pipeline +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.evaluation import BinaryClassificationEvaluator +from pyspark.ml.feature import HashingTF, Tokenizer +from pyspark.ml.tuning import CrossValidator, ParamGridBuilder +from pyspark.sql import Row, SQLContext + +""" +A simple example demonstrating model selection using CrossValidator. +This example also demonstrates how Pipelines are Estimators. +Run with: + + bin/spark-submit examples/src/main/python/ml/cross_validator.py +""" + +if __name__ == "__main__": + sc = SparkContext(appName="CrossValidatorExample") + sqlContext = SQLContext(sc) + + # Prepare training documents, which are labeled. + LabeledDocument = Row("id", "text", "label") + training = sc.parallelize([(0, "a b c d e spark", 1.0), + (1, "b d", 0.0), + (2, "spark f g h", 1.0), + (3, "hadoop mapreduce", 0.0), + (4, "b spark who", 1.0), + (5, "g d a y", 0.0), + (6, "spark fly", 1.0), + (7, "was mapreduce", 0.0), + (8, "e spark program", 1.0), + (9, "a e c l", 0.0), + (10, "spark compile", 1.0), + (11, "hadoop software", 0.0) + ]) \ + .map(lambda x: LabeledDocument(*x)).toDF() + + # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. + tokenizer = Tokenizer(inputCol="text", outputCol="words") + hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") + lr = LogisticRegression(maxIter=10) + pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) + + # We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. + # This will allow us to jointly choose parameters for all Pipeline stages. + # A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + # We use a ParamGridBuilder to construct a grid of parameters to search over. + # With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, + # this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. + paramGrid = ParamGridBuilder() \ + .addGrid(hashingTF.numFeatures, [10, 100, 1000]) \ + .addGrid(lr.regParam, [0.1, 0.01]) \ + .build() + + crossval = CrossValidator(estimator=pipeline, + estimatorParamMaps=paramGrid, + evaluator=BinaryClassificationEvaluator(), + numFolds=2) # use 3+ folds in practice + + # Run cross-validation, and choose the best set of parameters. + cvModel = crossval.fit(training) + + # Prepare test documents, which are unlabeled. + Document = Row("id", "text") + test = sc.parallelize([(4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")]) \ + .map(lambda x: Document(*x)).toDF() + + # Make predictions on test documents. cvModel uses the best model found (lrModel). + prediction = cvModel.transform(test) + selected = prediction.select("id", "text", "probability", "prediction") + for row in selected.collect(): + print(row) + + sc.stop() diff --git a/examples/src/main/python/ml/simple_params_example.py b/examples/src/main/python/ml/simple_params_example.py index 3933d59b52cd1..a9f29dab2d602 100644 --- a/examples/src/main/python/ml/simple_params_example.py +++ b/examples/src/main/python/ml/simple_params_example.py @@ -41,8 +41,8 @@ # prepare training data. # We create an RDD of LabeledPoints and convert them into a DataFrame. - # Spark DataFrames can automatically infer the schema from named tuples - # and LabeledPoint implements __reduce__ to behave like a named tuple. + # A LabeledPoint is an Object with two fields named label and features + # and Spark SQL identifies these fields and creates the schema appropriately. training = sc.parallelize([ LabeledPoint(1.0, DenseVector([0.0, 1.1, 0.1])), LabeledPoint(0.0, DenseVector([2.0, 1.0, -1.0])), From a86b3e9b9b75f5af4fdbba22e87769058f023204 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 2 Jun 2015 19:12:08 -0700 Subject: [PATCH 09/18] [SPARK-7547] [ML] Scala Example code for ElasticNet This is scala example code for both linear and logistic regression. Python and Java versions are to be added. Author: DB Tsai Closes #6576 from dbtsai/elasticNetExample and squashes the following commits: e7ca406 [DB Tsai] fix test 6bb6d77 [DB Tsai] fix suite and remove duplicated setMaxIter 136e0dd [DB Tsai] address feedback 1ec29d4 [DB Tsai] fix style 9462f5f [DB Tsai] add example --- .../examples/ml/LinearRegressionExample.scala | 142 ++++++++++++++++ .../ml/LogisticRegressionExample.scala | 159 ++++++++++++++++++ .../classification/LogisticRegression.scala | 8 +- .../ml/param/shared/SharedParamsCodeGen.scala | 2 +- .../spark/ml/param/shared/sharedParams.scala | 4 +- .../ml/regression/LinearRegression.scala | 2 +- .../apache/spark/ml/param/ParamsSuite.scala | 6 +- 7 files changed, 314 insertions(+), 9 deletions(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala new file mode 100644 index 0000000000000..b54466fd48bc5 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} +import org.apache.spark.sql.DataFrame + +/** + * An example runner for linear regression with elastic-net (mixing L1/L2) regularization. + * Run with + * {{{ + * bin/run-example ml.LinearRegressionExample [options] + * }}} + * A synthetic dataset can be found at `data/mllib/sample_linear_regression_data.txt` which can be + * trained by + * {{{ + * bin/run-example ml.LinearRegressionExample --regParam 0.15 --elasticNetParam 1.0 \ + * data/mllib/sample_linear_regression_data.txt + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object LinearRegressionExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + regParam: Double = 0.0, + elasticNetParam: Double = 0.0, + maxIter: Int = 100, + tol: Double = 1E-6, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("LinearRegressionExample") { + head("LinearRegressionExample: an example Linear Regression with Elastic-Net app.") + opt[Double]("regParam") + .text(s"regularization parameter, default: ${defaultParams.regParam}") + .action((x, c) => c.copy(regParam = x)) + opt[Double]("elasticNetParam") + .text(s"ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. " + + s"For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of " + + s"L1 and L2, default: ${defaultParams.elasticNetParam}") + .action((x, c) => c.copy(elasticNetParam = x)) + opt[Int]("maxIter") + .text(s"maximum number of iterations, default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Double]("tol") + .text(s"the convergence tolerance of iterations, Smaller value will lead " + + s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}") + .action((x, c) => c.copy(tol = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"LinearRegressionExample with $params") + val sc = new SparkContext(conf) + + println(s"LinearRegressionExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, "regression", params.fracTest) + + val lir = new LinearRegression() + .setFeaturesCol("features") + .setLabelCol("label") + .setRegParam(params.regParam) + .setElasticNetParam(params.elasticNetParam) + .setMaxIter(params.maxIter) + .setTol(params.tol) + + // Train the model + val startTime = System.nanoTime() + val lirModel = lir.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + // Print the weights and intercept for linear regression. + println(s"Weights: ${lirModel.weights} Intercept: ${lirModel.intercept}") + + println("Training data results:") + DecisionTreeExample.evaluateRegressionModel(lirModel, training, "label") + println("Test data results:") + DecisionTreeExample.evaluateRegressionModel(lirModel, test, "label") + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala new file mode 100644 index 0000000000000..b12f833ce94c8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.feature.StringIndexer +import org.apache.spark.sql.DataFrame + +/** + * An example runner for logistic regression with elastic-net (mixing L1/L2) regularization. + * Run with + * {{{ + * bin/run-example ml.LogisticRegressionExample [options] + * }}} + * A synthetic dataset can be found at `data/mllib/sample_libsvm_data.txt` which can be + * trained by + * {{{ + * bin/run-example ml.LogisticRegressionExample --regParam 0.3 --elasticNetParam 0.8 \ + * data/mllib/sample_libsvm_data.txt + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object LogisticRegressionExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + regParam: Double = 0.0, + elasticNetParam: Double = 0.0, + maxIter: Int = 100, + fitIntercept: Boolean = true, + tol: Double = 1E-6, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("LogisticRegressionExample") { + head("LogisticRegressionExample: an example Logistic Regression with Elastic-Net app.") + opt[Double]("regParam") + .text(s"regularization parameter, default: ${defaultParams.regParam}") + .action((x, c) => c.copy(regParam = x)) + opt[Double]("elasticNetParam") + .text(s"ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. " + + s"For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of " + + s"L1 and L2, default: ${defaultParams.elasticNetParam}") + .action((x, c) => c.copy(elasticNetParam = x)) + opt[Int]("maxIter") + .text(s"maximum number of iterations, default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Boolean]("fitIntercept") + .text(s"whether to fit an intercept term, default: ${defaultParams.fitIntercept}") + .action((x, c) => c.copy(fitIntercept = x)) + opt[Double]("tol") + .text(s"the convergence tolerance of iterations, Smaller value will lead " + + s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}") + .action((x, c) => c.copy(tol = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"LogisticRegressionExample with $params") + val sc = new SparkContext(conf) + + println(s"LogisticRegressionExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, "classification", params.fracTest) + + // Set up Pipeline + val stages = new mutable.ArrayBuffer[PipelineStage]() + + val labelIndexer = new StringIndexer() + .setInputCol("labelString") + .setOutputCol("indexedLabel") + stages += labelIndexer + + val lor = new LogisticRegression() + .setFeaturesCol("features") + .setLabelCol("indexedLabel") + .setRegParam(params.regParam) + .setElasticNetParam(params.elasticNetParam) + .setMaxIter(params.maxIter) + .setTol(params.tol) + + stages += lor + val pipeline = new Pipeline().setStages(stages.toArray) + + // Fit the Pipeline + val startTime = System.nanoTime() + val pipelineModel = pipeline.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + val lirModel = pipelineModel.stages.last.asInstanceOf[LogisticRegressionModel] + // Print the weights and intercept for logistic regression. + println(s"Weights: ${lirModel.weights} Intercept: ${lirModel.intercept}") + + println("Training data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, "indexedLabel") + println("Test data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, "indexedLabel") + + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index d13109d9da4c0..f136bcee9cf2b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -74,7 +74,7 @@ class LogisticRegression(override val uid: String) setDefault(elasticNetParam -> 0.0) /** - * Set the maximal number of iterations. + * Set the maximum number of iterations. * Default is 100. * @group setParam */ @@ -90,7 +90,11 @@ class LogisticRegression(override val uid: String) def setTol(value: Double): this.type = set(tol, value) setDefault(tol -> 1E-6) - /** @group setParam */ + /** + * Whether to fit an intercept term. + * Default is true. + * @group setParam + * */ def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) setDefault(fitIntercept -> true) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 1ffb5eddc36bd..8ffbcf0d8bc71 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -33,7 +33,7 @@ private[shared] object SharedParamsCodeGen { val params = Seq( ParamDesc[Double]("regParam", "regularization parameter (>= 0)", isValid = "ParamValidators.gtEq(0)"), - ParamDesc[Int]("maxIter", "max number of iterations (>= 0)", + ParamDesc[Int]("maxIter", "maximum number of iterations (>= 0)", isValid = "ParamValidators.gtEq(0)"), ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")), ParamDesc[String]("labelCol", "label column name", Some("\"label\"")), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index ed08417bd4df8..a0c8ccdac9ad9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -45,10 +45,10 @@ private[ml] trait HasRegParam extends Params { private[ml] trait HasMaxIter extends Params { /** - * Param for max number of iterations (>= 0). + * Param for maximum number of iterations (>= 0). * @group param */ - final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations (>= 0)", ParamValidators.gtEq(0)) + final val maxIter: IntParam = new IntParam(this, "maxIter", "maximum number of iterations (>= 0)", ParamValidators.gtEq(0)) /** @group getParam */ final def getMaxIter: Int = $(maxIter) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index fe2a71a331694..70cd8e9e87fae 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -83,7 +83,7 @@ class LinearRegression(override val uid: String) setDefault(elasticNetParam -> 0.0) /** - * Set the maximal number of iterations. + * Set the maximum number of iterations. * Default is 100. * @group setParam */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index f80e7749098a5..96094d7a099aa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -27,7 +27,7 @@ class ParamsSuite extends SparkFunSuite { import solver.{maxIter, inputCol} assert(maxIter.name === "maxIter") - assert(maxIter.doc === "max number of iterations (>= 0)") + assert(maxIter.doc === "maximum number of iterations (>= 0)") assert(maxIter.parent === uid) assert(maxIter.toString === s"${uid}__maxIter") assert(!maxIter.isValid(-1)) @@ -36,7 +36,7 @@ class ParamsSuite extends SparkFunSuite { solver.setMaxIter(5) assert(solver.explainParam(maxIter) === - "maxIter: max number of iterations (>= 0) (default: 10, current: 5)") + "maxIter: maximum number of iterations (>= 0) (default: 10, current: 5)") assert(inputCol.toString === s"${uid}__inputCol") @@ -120,7 +120,7 @@ class ParamsSuite extends SparkFunSuite { intercept[NoSuchElementException](solver.getInputCol) assert(solver.explainParam(maxIter) === - "maxIter: max number of iterations (>= 0) (default: 10, current: 100)") + "maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)") assert(solver.explainParams() === Seq(inputCol, maxIter).map(solver.explainParam).mkString("\n")) From cafd5056e12a15f0ebf8015d52dfab999c4443b8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 2 Jun 2015 22:11:03 -0700 Subject: [PATCH 10/18] [SPARK-7691] [SQL] Refactor CatalystTypeConverter to use type-specific row accessors This patch significantly refactors CatalystTypeConverters to both clean up the code and enable these conversions to work with future Project Tungsten features. At a high level, I've reorganized the code so that all functions dealing with the same type are grouped together into type-specific subclasses of `CatalystTypeConveter`. In addition, I've added new methods that allow the Catalyst Row -> Scala Row conversions to access the Catalyst row's fields through type-specific `getTYPE()` methods rather than the generic `get()` / `Row.apply` methods. This refactoring is a blocker to being able to unit test new operators that I'm developing as part of Project Tungsten, since those operators may output `UnsafeRow` instances which don't support the generic `get()`. The stricter type usage of types here has uncovered some bugs in other parts of Spark SQL: - #6217: DescribeCommand is assigned wrong output attributes in SparkStrategies - #6218: DataFrame.describe() should cast all aggregates to String - #6400: Use output schema, not relation schema, for data source input conversion Spark SQL current has undefined behavior for what happens when you try to create a DataFrame from user-specified rows whose values don't match the declared schema. According to the `createDataFrame()` Scaladoc: > It is important to make sure that the structure of every [[Row]] of the provided RDD matches the provided schema. Otherwise, there will be runtime exception. Given this, it sounds like it's technically not a break of our API contract to fail-fast when the data types don't match. However, there appear to be many cases where we don't fail even though the types don't match. For example, `JavaHashingTFSuite.hasingTF` passes a column of integers values for a "label" column which is supposed to contain floats. This column isn't actually read or modified as part of query processing, so its actual concrete type doesn't seem to matter. In other cases, there could be situations where we have generic numeric aggregates that tolerate being called with different numeric types than the schema specified, but this can be okay due to numeric conversions. In the long run, we will probably want to come up with precise semantics for implicit type conversions / widening when converting Java / Scala rows to Catalyst rows. Until then, though, I think that failing fast with a ClassCastException is a reasonable behavior; this is the approach taken in this patch. Note that certain optimizations in the inbound conversion functions for primitive types mean that we'll probably preserve the old undefined behavior in a majority of cases. Author: Josh Rosen Closes #6222 from JoshRosen/catalyst-converters-refactoring and squashes the following commits: 740341b [Josh Rosen] Optimize method dispatch for primitive type conversions befc613 [Josh Rosen] Add tests to document Option-handling behavior. 5989593 [Josh Rosen] Use new SparkFunSuite base in CatalystTypeConvertersSuite 6edf7f8 [Josh Rosen] Re-add convertToScala(), since a Hive test still needs it 3f7b2d8 [Josh Rosen] Initialize converters lazily so that the attributes are resolved first 6ad0ebb [Josh Rosen] Fix JavaHashingTFSuite ClassCastException 677ff27 [Josh Rosen] Fix null handling bug; add tests. 8033d4c [Josh Rosen] Fix serialization error in UserDefinedGenerator. 85bba9d [Josh Rosen] Fix wrong input data in InMemoryColumnarQuerySuite 9c0e4e1 [Josh Rosen] Remove last use of convertToScala(). ae3278d [Josh Rosen] Throw ClassCastException errors during inbound conversions. 7ca7fcb [Josh Rosen] Comments and cleanup 1e87a45 [Josh Rosen] WIP refactoring of CatalystTypeConverters --- .../spark/ml/feature/JavaHashingTFSuite.java | 6 +- .../sql/catalyst/CatalystTypeConverters.scala | 558 ++++++++++-------- .../sql/catalyst/expressions/generators.scala | 19 +- .../CatalystTypeConvertersSuite.scala | 62 ++ .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- 5 files changed, 382 insertions(+), 265 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java index da2218056307e..599e9cfd23ad4 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java @@ -55,9 +55,9 @@ public void tearDown() { @Test public void hashingTF() { JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(0, "I wish Java could use case classes"), - RowFactory.create(1, "Logistic regression models are neat") + RowFactory.create(0.0, "Hi I heard about Spark"), + RowFactory.create(0.0, "I wish Java could use case classes"), + RowFactory.create(1.0, "Logistic regression models are neat") )); StructType schema = new StructType(new StructField[]{ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 1c0ddb5093d17..2e7b4c236d8f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -18,7 +18,10 @@ package org.apache.spark.sql.catalyst import java.lang.{Iterable => JavaIterable} +import java.math.{BigDecimal => JavaBigDecimal} +import java.sql.Date import java.util.{Map => JavaMap} +import javax.annotation.Nullable import scala.collection.mutable.HashMap @@ -34,197 +37,338 @@ object CatalystTypeConverters { // Since the map values can be mutable, we explicitly import scala.collection.Map at here. import scala.collection.Map + private def isPrimitive(dataType: DataType): Boolean = { + dataType match { + case BooleanType => true + case ByteType => true + case ShortType => true + case IntegerType => true + case LongType => true + case FloatType => true + case DoubleType => true + case _ => false + } + } + + private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = { + val converter = dataType match { + case udt: UserDefinedType[_] => UDTConverter(udt) + case arrayType: ArrayType => ArrayConverter(arrayType.elementType) + case mapType: MapType => MapConverter(mapType.keyType, mapType.valueType) + case structType: StructType => StructConverter(structType) + case StringType => StringConverter + case DateType => DateConverter + case dt: DecimalType => BigDecimalConverter + case BooleanType => BooleanConverter + case ByteType => ByteConverter + case ShortType => ShortConverter + case IntegerType => IntConverter + case LongType => LongConverter + case FloatType => FloatConverter + case DoubleType => DoubleConverter + case _ => IdentityConverter + } + converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]] + } + /** - * Converts Scala objects to catalyst rows / types. This method is slow, and for batch - * conversion you should be using converter produced by createToCatalystConverter. - * Note: This is always called after schemaFor has been called. - * This ordering is important for UDT registration. + * Converts a Scala type to its Catalyst equivalent (and vice versa). + * + * @tparam ScalaInputType The type of Scala values that can be converted to Catalyst. + * @tparam ScalaOutputType The type of Scala values returned when converting Catalyst to Scala. + * @tparam CatalystType The internal Catalyst type used to represent values of this Scala type. */ - def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { - // Check UDT first since UDTs can override other types - case (obj, udt: UserDefinedType[_]) => - udt.serialize(obj) - - case (o: Option[_], _) => - o.map(convertToCatalyst(_, dataType)).orNull - - case (s: Seq[_], arrayType: ArrayType) => - s.map(convertToCatalyst(_, arrayType.elementType)) - - case (jit: JavaIterable[_], arrayType: ArrayType) => { - val iter = jit.iterator - var listOfItems: List[Any] = List() - while (iter.hasNext) { - val item = iter.next() - listOfItems :+= convertToCatalyst(item, arrayType.elementType) + private abstract class CatalystTypeConverter[ScalaInputType, ScalaOutputType, CatalystType] + extends Serializable { + + /** + * Converts a Scala type to its Catalyst equivalent while automatically handling nulls + * and Options. + */ + final def toCatalyst(@Nullable maybeScalaValue: Any): CatalystType = { + if (maybeScalaValue == null) { + null.asInstanceOf[CatalystType] + } else if (maybeScalaValue.isInstanceOf[Option[ScalaInputType]]) { + val opt = maybeScalaValue.asInstanceOf[Option[ScalaInputType]] + if (opt.isDefined) { + toCatalystImpl(opt.get) + } else { + null.asInstanceOf[CatalystType] + } + } else { + toCatalystImpl(maybeScalaValue.asInstanceOf[ScalaInputType]) } - listOfItems } - case (s: Array[_], arrayType: ArrayType) => - s.toSeq.map(convertToCatalyst(_, arrayType.elementType)) + /** + * Given a Catalyst row, convert the value at column `column` to its Scala equivalent. + */ + final def toScala(row: Row, column: Int): ScalaOutputType = { + if (row.isNullAt(column)) null.asInstanceOf[ScalaOutputType] else toScalaImpl(row, column) + } + + /** + * Convert a Catalyst value to its Scala equivalent. + */ + def toScala(@Nullable catalystValue: CatalystType): ScalaOutputType + + /** + * Converts a Scala value to its Catalyst equivalent. + * @param scalaValue the Scala value, guaranteed not to be null. + * @return the Catalyst value. + */ + protected def toCatalystImpl(scalaValue: ScalaInputType): CatalystType + + /** + * Given a Catalyst row, convert the value at column `column` to its Scala equivalent. + * This method will only be called on non-null columns. + */ + protected def toScalaImpl(row: Row, column: Int): ScalaOutputType + } - case (m: Map[_, _], mapType: MapType) => - m.map { case (k, v) => - convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) - } + private object IdentityConverter extends CatalystTypeConverter[Any, Any, Any] { + override def toCatalystImpl(scalaValue: Any): Any = scalaValue + override def toScala(catalystValue: Any): Any = catalystValue + override def toScalaImpl(row: Row, column: Int): Any = row(column) + } - case (jmap: JavaMap[_, _], mapType: MapType) => - val iter = jmap.entrySet.iterator - var listOfEntries: List[(Any, Any)] = List() - while (iter.hasNext) { - val entry = iter.next() - listOfEntries :+= (convertToCatalyst(entry.getKey, mapType.keyType), - convertToCatalyst(entry.getValue, mapType.valueType)) + private case class UDTConverter( + udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] { + override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue) + override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue) + override def toScalaImpl(row: Row, column: Int): Any = toScala(row(column)) + } + + /** Converter for arrays, sequences, and Java iterables. */ + private case class ArrayConverter( + elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], Seq[Any]] { + + private[this] val elementConverter = getConverterForType(elementType) + + override def toCatalystImpl(scalaValue: Any): Seq[Any] = { + scalaValue match { + case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst) + case s: Seq[_] => s.map(elementConverter.toCatalyst) + case i: JavaIterable[_] => + val iter = i.iterator + var convertedIterable: List[Any] = List() + while (iter.hasNext) { + val item = iter.next() + convertedIterable :+= elementConverter.toCatalyst(item) + } + convertedIterable } - listOfEntries.toMap - - case (p: Product, structType: StructType) => - val ar = new Array[Any](structType.size) - val iter = p.productIterator - var idx = 0 - while (idx < structType.size) { - ar(idx) = convertToCatalyst(iter.next(), structType.fields(idx).dataType) - idx += 1 + } + + override def toScala(catalystValue: Seq[Any]): Seq[Any] = { + if (catalystValue == null) { + null + } else { + catalystValue.asInstanceOf[Seq[_]].map(elementConverter.toScala) } - new GenericRowWithSchema(ar, structType) + } - case (d: String, _) => - UTF8String(d) + override def toScalaImpl(row: Row, column: Int): Seq[Any] = + toScala(row(column).asInstanceOf[Seq[Any]]) + } + + private case class MapConverter( + keyType: DataType, + valueType: DataType) + extends CatalystTypeConverter[Any, Map[Any, Any], Map[Any, Any]] { - case (d: BigDecimal, _) => - Decimal(d) + private[this] val keyConverter = getConverterForType(keyType) + private[this] val valueConverter = getConverterForType(valueType) - case (d: java.math.BigDecimal, _) => - Decimal(d) + override def toCatalystImpl(scalaValue: Any): Map[Any, Any] = scalaValue match { + case m: Map[_, _] => + m.map { case (k, v) => + keyConverter.toCatalyst(k) -> valueConverter.toCatalyst(v) + } - case (d: java.sql.Date, _) => - DateUtils.fromJavaDate(d) + case jmap: JavaMap[_, _] => + val iter = jmap.entrySet.iterator + val convertedMap: HashMap[Any, Any] = HashMap() + while (iter.hasNext) { + val entry = iter.next() + val key = keyConverter.toCatalyst(entry.getKey) + convertedMap(key) = valueConverter.toCatalyst(entry.getValue) + } + convertedMap + } - case (r: Row, structType: StructType) => - val converters = structType.fields.map { - f => (item: Any) => convertToCatalyst(item, f.dataType) + override def toScala(catalystValue: Map[Any, Any]): Map[Any, Any] = { + if (catalystValue == null) { + null + } else { + catalystValue.map { case (k, v) => + keyConverter.toScala(k) -> valueConverter.toScala(v) + } } - convertRowWithConverters(r, structType, converters) + } - case (other, _) => - other + override def toScalaImpl(row: Row, column: Int): Map[Any, Any] = + toScala(row(column).asInstanceOf[Map[Any, Any]]) } - /** - * Creates a converter function that will convert Scala objects to the specified catalyst type. - * Typical use case would be converting a collection of rows that have the same schema. You will - * call this function once to get a converter, and apply it to every row. - */ - private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = { - def extractOption(item: Any): Any = item match { - case opt: Option[_] => opt.orNull - case other => other - } + private case class StructConverter( + structType: StructType) extends CatalystTypeConverter[Any, Row, Row] { - dataType match { - // Check UDT first since UDTs can override other types - case udt: UserDefinedType[_] => - (item) => extractOption(item) match { - case null => null - case other => udt.serialize(other) - } + private[this] val converters = structType.fields.map { f => getConverterForType(f.dataType) } - case arrayType: ArrayType => - val elementConverter = createToCatalystConverter(arrayType.elementType) - (item: Any) => { - extractOption(item) match { - case a: Array[_] => a.toSeq.map(elementConverter) - case s: Seq[_] => s.map(elementConverter) - case i: JavaIterable[_] => { - val iter = i.iterator - var convertedIterable: List[Any] = List() - while (iter.hasNext) { - val item = iter.next() - convertedIterable :+= elementConverter(item) - } - convertedIterable - } - case null => null - } + override def toCatalystImpl(scalaValue: Any): Row = scalaValue match { + case row: Row => + val ar = new Array[Any](row.size) + var idx = 0 + while (idx < row.size) { + ar(idx) = converters(idx).toCatalyst(row(idx)) + idx += 1 } - - case mapType: MapType => - val keyConverter = createToCatalystConverter(mapType.keyType) - val valueConverter = createToCatalystConverter(mapType.valueType) - (item: Any) => { - extractOption(item) match { - case m: Map[_, _] => - m.map { case (k, v) => - keyConverter(k) -> valueConverter(v) - } - - case jmap: JavaMap[_, _] => - val iter = jmap.entrySet.iterator - val convertedMap: HashMap[Any, Any] = HashMap() - while (iter.hasNext) { - val entry = iter.next() - convertedMap(keyConverter(entry.getKey)) = valueConverter(entry.getValue) - } - convertedMap - - case null => null - } + new GenericRowWithSchema(ar, structType) + + case p: Product => + val ar = new Array[Any](structType.size) + val iter = p.productIterator + var idx = 0 + while (idx < structType.size) { + ar(idx) = converters(idx).toCatalyst(iter.next()) + idx += 1 } + new GenericRowWithSchema(ar, structType) + } - case structType: StructType => - val converters = structType.fields.map(f => createToCatalystConverter(f.dataType)) - (item: Any) => { - extractOption(item) match { - case r: Row => - convertRowWithConverters(r, structType, converters) - - case p: Product => - val ar = new Array[Any](structType.size) - val iter = p.productIterator - var idx = 0 - while (idx < structType.size) { - ar(idx) = converters(idx)(iter.next()) - idx += 1 - } - new GenericRowWithSchema(ar, structType) - - case null => - null - } + override def toScala(row: Row): Row = { + if (row == null) { + null + } else { + val ar = new Array[Any](row.size) + var idx = 0 + while (idx < row.size) { + ar(idx) = converters(idx).toScala(row, idx) + idx += 1 } - - case dateType: DateType => (item: Any) => extractOption(item) match { - case d: java.sql.Date => DateUtils.fromJavaDate(d) - case other => other + new GenericRowWithSchema(ar, structType) } + } - case dataType: StringType => (item: Any) => extractOption(item) match { - case s: String => UTF8String(s) - case other => other - } + override def toScalaImpl(row: Row, column: Int): Row = toScala(row(column).asInstanceOf[Row]) + } + + private object StringConverter extends CatalystTypeConverter[Any, String, Any] { + override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match { + case str: String => UTF8String(str) + case utf8: UTF8String => utf8 + } + override def toScala(catalystValue: Any): String = catalystValue match { + case null => null + case str: String => str + case utf8: UTF8String => utf8.toString() + } + override def toScalaImpl(row: Row, column: Int): String = row(column).toString + } + + private object DateConverter extends CatalystTypeConverter[Date, Date, Any] { + override def toCatalystImpl(scalaValue: Date): Int = DateUtils.fromJavaDate(scalaValue) + override def toScala(catalystValue: Any): Date = + if (catalystValue == null) null else DateUtils.toJavaDate(catalystValue.asInstanceOf[Int]) + override def toScalaImpl(row: Row, column: Int): Date = toScala(row.getInt(column)) + } + + private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { + override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match { + case d: BigDecimal => Decimal(d) + case d: JavaBigDecimal => Decimal(d) + case d: Decimal => d + } + override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal + override def toScalaImpl(row: Row, column: Int): JavaBigDecimal = row.get(column) match { + case d: JavaBigDecimal => d + case d: Decimal => d.toJavaBigDecimal + } + } + + private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] { + final override def toScala(catalystValue: Any): Any = catalystValue + final override def toCatalystImpl(scalaValue: T): Any = scalaValue + } + + private object BooleanConverter extends PrimitiveConverter[Boolean] { + override def toScalaImpl(row: Row, column: Int): Boolean = row.getBoolean(column) + } + + private object ByteConverter extends PrimitiveConverter[Byte] { + override def toScalaImpl(row: Row, column: Int): Byte = row.getByte(column) + } + + private object ShortConverter extends PrimitiveConverter[Short] { + override def toScalaImpl(row: Row, column: Int): Short = row.getShort(column) + } + + private object IntConverter extends PrimitiveConverter[Int] { + override def toScalaImpl(row: Row, column: Int): Int = row.getInt(column) + } + + private object LongConverter extends PrimitiveConverter[Long] { + override def toScalaImpl(row: Row, column: Int): Long = row.getLong(column) + } + + private object FloatConverter extends PrimitiveConverter[Float] { + override def toScalaImpl(row: Row, column: Int): Float = row.getFloat(column) + } - case _ => - (item: Any) => extractOption(item) match { - case d: BigDecimal => Decimal(d) - case d: java.math.BigDecimal => Decimal(d) - case other => other + private object DoubleConverter extends PrimitiveConverter[Double] { + override def toScalaImpl(row: Row, column: Int): Double = row.getDouble(column) + } + + /** + * Converts Scala objects to catalyst rows / types. This method is slow, and for batch + * conversion you should be using converter produced by createToCatalystConverter. + * Note: This is always called after schemaFor has been called. + * This ordering is important for UDT registration. + */ + def convertToCatalyst(scalaValue: Any, dataType: DataType): Any = { + getConverterForType(dataType).toCatalyst(scalaValue) + } + + /** + * Creates a converter function that will convert Scala objects to the specified Catalyst type. + * Typical use case would be converting a collection of rows that have the same schema. You will + * call this function once to get a converter, and apply it to every row. + */ + private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = { + if (isPrimitive(dataType)) { + // Although the `else` branch here is capable of handling inbound conversion of primitives, + // we add some special-case handling for those types here. The motivation for this relates to + // Java method invocation costs: if we have rows that consist entirely of primitive columns, + // then returning the same conversion function for all of the columns means that the call site + // will be monomorphic instead of polymorphic. In microbenchmarks, this actually resulted in + // a measurable performance impact. Note that this optimization will be unnecessary if we + // use code generation to construct Scala Row -> Catalyst Row converters. + def convert(maybeScalaValue: Any): Any = { + if (maybeScalaValue.isInstanceOf[Option[Any]]) { + maybeScalaValue.asInstanceOf[Option[Any]].orNull + } else { + maybeScalaValue } + } + convert + } else { + getConverterForType(dataType).toCatalyst } } /** - * Converts Scala objects to catalyst rows / types. + * Converts Scala objects to Catalyst rows / types. * * Note: This should be called before do evaluation on Row * (It does not support UDT) * This is used to create an RDD or test results with correct types for Catalyst. */ def convertToCatalyst(a: Any): Any = a match { - case s: String => UTF8String(s) - case d: java.sql.Date => DateUtils.fromJavaDate(d) - case d: BigDecimal => Decimal(d) - case d: java.math.BigDecimal => Decimal(d) + case s: String => StringConverter.toCatalyst(s) + case d: Date => DateConverter.toCatalyst(d) + case d: BigDecimal => BigDecimalConverter.toCatalyst(d) + case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d) case seq: Seq[Any] => seq.map(convertToCatalyst) case r: Row => Row(r.toSeq.map(convertToCatalyst): _*) case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray @@ -238,33 +382,8 @@ object CatalystTypeConverters { * This method is slow, and for batch conversion you should be using converter * produced by createToScalaConverter. */ - def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { - // Check UDT first since UDTs can override other types - case (d, udt: UserDefinedType[_]) => - udt.deserialize(d) - - case (s: Seq[_], arrayType: ArrayType) => - s.map(convertToScala(_, arrayType.elementType)) - - case (m: Map[_, _], mapType: MapType) => - m.map { case (k, v) => - convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) - } - - case (r: Row, s: StructType) => - convertRowToScala(r, s) - - case (d: Decimal, _: DecimalType) => - d.toJavaBigDecimal - - case (i: Int, DateType) => - DateUtils.toJavaDate(i) - - case (s: UTF8String, StringType) => - s.toString() - - case (other, _) => - other + def convertToScala(catalystValue: Any, dataType: DataType): Any = { + getConverterForType(dataType).toScala(catalystValue) } /** @@ -272,82 +391,7 @@ object CatalystTypeConverters { * Typical use case would be converting a collection of rows that have the same schema. You will * call this function once to get a converter, and apply it to every row. */ - private[sql] def createToScalaConverter(dataType: DataType): Any => Any = dataType match { - // Check UDT first since UDTs can override other types - case udt: UserDefinedType[_] => - (item: Any) => if (item == null) null else udt.deserialize(item) - - case arrayType: ArrayType => - val elementConverter = createToScalaConverter(arrayType.elementType) - (item: Any) => if (item == null) null else item.asInstanceOf[Seq[_]].map(elementConverter) - - case mapType: MapType => - val keyConverter = createToScalaConverter(mapType.keyType) - val valueConverter = createToScalaConverter(mapType.valueType) - (item: Any) => if (item == null) { - null - } else { - item.asInstanceOf[Map[_, _]].map { case (k, v) => - keyConverter(k) -> valueConverter(v) - } - } - - case s: StructType => - val converters = s.fields.map(f => createToScalaConverter(f.dataType)) - (item: Any) => { - if (item == null) { - null - } else { - convertRowWithConverters(item.asInstanceOf[Row], s, converters) - } - } - - case _: DecimalType => - (item: Any) => item match { - case d: Decimal => d.toJavaBigDecimal - case other => other - } - - case DateType => - (item: Any) => item match { - case i: Int => DateUtils.toJavaDate(i) - case other => other - } - - case StringType => - (item: Any) => item match { - case s: UTF8String => s.toString() - case other => other - } - - case other => - (item: Any) => item - } - - def convertRowToScala(r: Row, schema: StructType): Row = { - val ar = new Array[Any](r.size) - var idx = 0 - while (idx < r.size) { - ar(idx) = convertToScala(r(idx), schema.fields(idx).dataType) - idx += 1 - } - new GenericRowWithSchema(ar, schema) - } - - /** - * Converts a row by applying the provided set of converter functions. It is used for both - * toScala and toCatalyst conversions. - */ - private[sql] def convertRowWithConverters( - row: Row, - schema: StructType, - converters: Array[Any => Any]): Row = { - val ar = new Array[Any](row.size) - var idx = 0 - while (idx < row.size) { - ar(idx) = converters(idx)(row(idx)) - idx += 1 - } - new GenericRowWithSchema(ar, schema) + private[sql] def createToScalaConverter(dataType: DataType): Any => Any = { + getConverterForType(dataType).toScala } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 634138010fd21..b6191eafba71b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -71,12 +71,23 @@ case class UserDefinedGenerator( children: Seq[Expression]) extends Generator { + @transient private[this] var inputRow: InterpretedProjection = _ + @transient private[this] var convertToScala: (Row) => Row = _ + + private def initializeConverters(): Unit = { + inputRow = new InterpretedProjection(children) + convertToScala = { + val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) + CatalystTypeConverters.createToScalaConverter(inputSchema) + }.asInstanceOf[(Row => Row)] + } + override def eval(input: Row): TraversableOnce[Row] = { - // TODO(davies): improve this + if (inputRow == null) { + initializeConverters() + } // Convert the objects into Scala Type before calling function, we need schema to support UDT - val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) - val inputRow = new InterpretedProjection(children) - function(CatalystTypeConverters.convertToScala(inputRow(input), inputSchema).asInstanceOf[Row]) + function(convertToScala(inputRow(input))) } override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala new file mode 100644 index 0000000000000..df0f04563edcf --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +class CatalystTypeConvertersSuite extends SparkFunSuite { + + private val simpleTypes: Seq[DataType] = Seq( + StringType, + DateType, + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType) + + test("null handling in rows") { + val schema = StructType(simpleTypes.map(t => StructField(t.getClass.getName, t))) + val convertToCatalyst = CatalystTypeConverters.createToCatalystConverter(schema) + val convertToScala = CatalystTypeConverters.createToScalaConverter(schema) + + val scalaRow = Row.fromSeq(Seq.fill(simpleTypes.length)(null)) + assert(convertToScala(convertToCatalyst(scalaRow)) === scalaRow) + } + + test("null handling for individual values") { + for (dataType <- simpleTypes) { + assert(CatalystTypeConverters.createToScalaConverter(dataType)(null) === null) + } + } + + test("option handling in convertToCatalyst") { + // convertToCatalyst doesn't handle unboxing from Options. This is inconsistent with + // createToCatalystConverter but it may not actually matter as this is only called internally + // in a handful of places where we don't expect to receive Options. + assert(CatalystTypeConverters.convertToCatalyst(Some(123)) === Some(123)) + } + + test("option handling in createToCatalystConverter") { + assert(CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123)) === 123) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 56591d9dba29e..055453e688e73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -173,7 +173,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { new Timestamp(i), (1 to i).toSeq, (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, - Row((i - 0.25).toFloat, (1 to i).toSeq)) + Row((i - 0.25).toFloat, Seq(true, false, null))) } createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") // Cache the table. From 07c16cb5ba9cb0bfe34e8c0efbf06540a22d4e4e Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 2 Jun 2015 22:56:56 -0700 Subject: [PATCH 11/18] [SPARK-8053] [MLLIB] renamed scalingVector to scalingVec I searched the Spark codebase for all occurrences of "scalingVector" CC: mengxr Author: Joseph K. Bradley Closes #6596 from jkbradley/scalingVec-rename and squashes the following commits: d3812f8 [Joseph K. Bradley] renamed scalingVector to scalingVec --- .../spark/ml/feature/ElementwiseProduct.scala | 2 +- .../spark/mllib/feature/ElementwiseProduct.scala | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index 3ae1833390152..1e758cb775de7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -41,7 +41,7 @@ class ElementwiseProduct(override val uid: String) * the vector to multiply with input vectors * @group param */ - val scalingVec: Param[Vector] = new Param(this, "scalingVector", "vector for hadamard product") + val scalingVec: Param[Vector] = new Param(this, "scalingVec", "vector for hadamard product") /** @group setParam */ def setScalingVec(value: Vector): this.type = set(scalingVec, value) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala index b0985baf9b278..d67fe6c3ee4f8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala @@ -25,10 +25,10 @@ import org.apache.spark.mllib.linalg._ * Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a * provided "weight" vector. In other words, it scales each column of the dataset by a scalar * multiplier. - * @param scalingVector The values used to scale the reference vector's individual components. + * @param scalingVec The values used to scale the reference vector's individual components. */ @Experimental -class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer { +class ElementwiseProduct(val scalingVec: Vector) extends VectorTransformer { /** * Does the hadamard product transformation. @@ -37,15 +37,15 @@ class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer { * @return transformed vector. */ override def transform(vector: Vector): Vector = { - require(vector.size == scalingVector.size, - s"vector sizes do not match: Expected ${scalingVector.size} but found ${vector.size}") + require(vector.size == scalingVec.size, + s"vector sizes do not match: Expected ${scalingVec.size} but found ${vector.size}") vector match { case dv: DenseVector => val values: Array[Double] = dv.values.clone() - val dim = scalingVector.size + val dim = scalingVec.size var i = 0 while (i < dim) { - values(i) *= scalingVector(i) + values(i) *= scalingVec(i) i += 1 } Vectors.dense(values) @@ -54,7 +54,7 @@ class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer { val dim = values.length var i = 0 while (i < dim) { - values(i) *= scalingVector(indices(i)) + values(i) *= scalingVec(indices(i)) i += 1 } Vectors.sparse(size, indices, values) From ccaa823290cbe859cd224ac0f7071dfd0218b669 Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Tue, 2 Jun 2015 22:59:48 -0700 Subject: [PATCH 12/18] [MINOR] make the launcher project name consistent with others I found this by chance while building spark and think it is better to keep its name consistent with other sub-projects (Spark Project *). I am not gonna file JIRA as it is a pretty small issue. Author: WangTaoTheTonic Closes #6603 from WangTaoTheTonic/projName and squashes the following commits: 994b3ba [WangTaoTheTonic] make the project name consistent --- launcher/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/launcher/pom.xml b/launcher/pom.xml index ebfa7685eaa18..cc177d23dff77 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -29,7 +29,7 @@ org.apache.spark spark-launcher_2.10 jar - Spark Launcher Project + Spark Project Launcher http://spark.apache.org/ launcher From 43adbd56114ba80039a23909b0a30d393eaacc62 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 2 Jun 2015 23:15:38 -0700 Subject: [PATCH 13/18] [SPARK-8043] [MLLIB] [DOC] update NaiveBayes and SVM examples in doc jira: https://issues.apache.org/jira/browse/SPARK-8043 I found some issues during testing the save/load examples in markdown Documents, as a part of 1.4 QA plan Author: Yuhao Yang Closes #6584 from hhbyyh/naiveDocExample and squashes the following commits: a01a206 [Yuhao Yang] fix for Gaussian mixture 2fb8b96 [Yuhao Yang] update NaiveBayes and SVM examples in doc --- docs/mllib-clustering.md | 6 +++--- docs/mllib-linear-methods.md | 24 ++++++++++-------------- docs/mllib-naive-bayes.md | 2 +- 3 files changed, 14 insertions(+), 18 deletions(-) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index dac22f736e8cb..1b088969ddc25 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -249,11 +249,11 @@ public class GaussianMixtureExample { GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd()); // Save and load GaussianMixtureModel - gmm.save(sc, "myGMMModel") - GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc, "myGMMModel") + gmm.save(sc.sc(), "myGMMModel"); + GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc.sc(), "myGMMModel"); // Output the parameters of the mixture model for(int j=0; j
-The following example shows how to load a sample dataset, build Logistic Regression model, +The following example shows how to load a sample dataset, build SVM model, and make predictions with the resulting model to compute the training error. -Note that the Python API does not yet support model save/load but will in the future. - {% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithSGD +from pyspark.mllib.classification import SVMWithSGD, SVMModel from pyspark.mllib.regression import LabeledPoint -from numpy import array # Load and parse the data def parsePoint(line): @@ -334,12 +326,16 @@ data = sc.textFile("data/mllib/sample_svm_data.txt") parsedData = data.map(parsePoint) # Build the model -model = LogisticRegressionWithSGD.train(parsedData) +model = SVMWithSGD.train(parsedData, iterations=100) # Evaluating the model on training data labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) print("Training Error = " + str(trainErr)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = SVMModel.load(sc, "myModelPath") {% endhighlight %}
diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index acdcc371487f8..bf6d124fd5d8d 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -53,7 +53,7 @@ val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) val training = splits(0) val test = splits(1) -val model = NaiveBayes.train(training, lambda = 1.0, model = "multinomial") +val model = NaiveBayes.train(training, lambda = 1.0, modelType = "multinomial") val predictionAndLabel = test.map(p => (model.predict(p.features), p.label)) val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count() From 452eb82dd722e5dfd00ee47bb8b6353933b0016e Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 2 Jun 2015 23:24:47 -0700 Subject: [PATCH 14/18] [SPARK-8032] [PYSPARK] Make version checking for NumPy in MLlib more robust The current checking does version `1.x' is less than `1.4' this will fail if x has greater than 1 digit, since x > 4, however `1.x` < `1.4` It fails in my system since I have version `1.10` :P Author: MechCoder Closes #6579 from MechCoder/np_ver and squashes the following commits: 15430f8 [MechCoder] fix syntax error 893fb7e [MechCoder] remove equal to e35f0d4 [MechCoder] minor e89376c [MechCoder] Better checking 22703dd [MechCoder] [SPARK-8032] Make version checking for NumPy in MLlib more robust --- python/pyspark/mllib/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index b11aed2c3afda..acba3a717d21a 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -23,7 +23,9 @@ # MLlib currently needs NumPy 1.4+, so complain if lower import numpy -if numpy.version.version < '1.4': + +ver = [int(x) for x in numpy.version.version.split('.')[:2]] +if ver < [1, 4]: raise Exception("MLlib requires NumPy 1.4+") __all__ = ['classification', 'clustering', 'feature', 'fpm', 'linalg', 'random', From ce320cb2dbf28825f80795ce569735888f98d6e8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 3 Jun 2015 00:23:34 -0700 Subject: [PATCH 15/18] [SPARK-8060] Improve DataFrame Python test coverage and documentation. Author: Reynold Xin Closes #6601 from rxin/python-read-write-test-and-doc and squashes the following commits: baa8ad5 [Reynold Xin] Code review feedback. f081d47 [Reynold Xin] More documentation updates. c9902fa [Reynold Xin] [SPARK-8060] Improve DataFrame Python reader/writer interface doc and testing. --- .rat-excludes | 1 + python/pyspark/sql/__init__.py | 13 +- python/pyspark/sql/context.py | 89 +++---- python/pyspark/sql/dataframe.py | 82 +++---- python/pyspark/sql/readwriter.py | 217 ++++++++---------- python/pyspark/sql/tests.py | 2 + .../sql/parquet_partitioned/_SUCCESS | 0 .../sql/parquet_partitioned/_common_metadata | Bin 0 -> 210 bytes .../sql/parquet_partitioned/_metadata | Bin 0 -> 743 bytes .../day=1/.part-r-00008.gz.parquet.crc | Bin 0 -> 12 bytes .../month=9/day=1/part-r-00008.gz.parquet | Bin 0 -> 322 bytes .../day=25/.part-r-00002.gz.parquet.crc | Bin 0 -> 12 bytes .../day=25/.part-r-00004.gz.parquet.crc | Bin 0 -> 12 bytes .../month=10/day=25/part-r-00002.gz.parquet | Bin 0 -> 343 bytes .../month=10/day=25/part-r-00004.gz.parquet | Bin 0 -> 343 bytes .../day=26/.part-r-00005.gz.parquet.crc | Bin 0 -> 12 bytes .../month=10/day=26/part-r-00005.gz.parquet | Bin 0 -> 333 bytes .../day=1/.part-r-00007.gz.parquet.crc | Bin 0 -> 12 bytes .../month=9/day=1/part-r-00007.gz.parquet | Bin 0 -> 343 bytes python/test_support/sql/people.json | 3 + 20 files changed, 180 insertions(+), 227 deletions(-) create mode 100644 python/test_support/sql/parquet_partitioned/_SUCCESS create mode 100644 python/test_support/sql/parquet_partitioned/_common_metadata create mode 100644 python/test_support/sql/parquet_partitioned/_metadata create mode 100644 python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc create mode 100644 python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet create mode 100644 python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00002.gz.parquet.crc create mode 100644 python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00004.gz.parquet.crc create mode 100644 python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00002.gz.parquet create mode 100644 python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00004.gz.parquet create mode 100644 python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc create mode 100644 python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet create mode 100644 python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc create mode 100644 python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet create mode 100644 python/test_support/sql/people.json diff --git a/.rat-excludes b/.rat-excludes index c0f81b57fe09d..8f2722cbd001f 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -82,3 +82,4 @@ local-1426633911242/* local-1430917381534/* DESCRIPTION NAMESPACE +test_support/* diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index 726d288d97b2e..ad9c891ba1c04 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -45,11 +45,20 @@ def since(version): + """ + A decorator that annotates a function to append the version of Spark the function was added. + """ + import re + indent_p = re.compile(r'\n( +)') + def deco(f): - f.__doc__ = f.__doc__.rstrip() + "\n\n.. versionadded:: %s" % version + indents = indent_p.findall(f.__doc__) + indent = ' ' * (min(len(m) for m in indents) if indents else 0) + f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version) return f return deco + from pyspark.sql.types import Row from pyspark.sql.context import SQLContext, HiveContext from pyspark.sql.column import Column @@ -58,7 +67,9 @@ def deco(f): from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter from pyspark.sql.window import Window, WindowSpec + __all__ = [ 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', 'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec', + 'DataFrameReader', 'DataFrameWriter' ] diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 22f6257dfe02d..9fdf43c3e6eb5 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -124,7 +124,10 @@ def getConf(self, key, defaultValue): @property @since("1.3.1") def udf(self): - """Returns a :class:`UDFRegistration` for UDF registration.""" + """Returns a :class:`UDFRegistration` for UDF registration. + + :return: :class:`UDFRegistration` + """ return UDFRegistration(self) @since(1.4) @@ -138,7 +141,7 @@ def range(self, start, end, step=1, numPartitions=None): :param end: the end value (exclusive) :param step: the incremental step (default: 1) :param numPartitions: the number of partitions of the DataFrame - :return: A new DataFrame + :return: :class:`DataFrame` >>> sqlContext.range(1, 7, 2).collect() [Row(id=1), Row(id=3), Row(id=5)] @@ -195,8 +198,8 @@ def _inferSchema(self, rdd, samplingRatio=None): raise ValueError("The first row in RDD is empty, " "can not infer schema") if type(first) is dict: - warnings.warn("Using RDD of dict to inferSchema is deprecated," - "please use pyspark.sql.Row instead") + warnings.warn("Using RDD of dict to inferSchema is deprecated. " + "Use pyspark.sql.Row instead") if samplingRatio is None: schema = _infer_schema(first) @@ -219,7 +222,7 @@ def inferSchema(self, rdd, samplingRatio=None): """ .. note:: Deprecated in 1.3, use :func:`createDataFrame` instead. """ - warnings.warn("inferSchema is deprecated, please use createDataFrame instead") + warnings.warn("inferSchema is deprecated, please use createDataFrame instead.") if isinstance(rdd, DataFrame): raise TypeError("Cannot apply schema to DataFrame") @@ -262,6 +265,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): :class:`list`, or :class:`pandas.DataFrame`. :param schema: a :class:`StructType` or list of column names. default None. :param samplingRatio: the sample ratio of rows used for inferring + :return: :class:`DataFrame` >>> l = [('Alice', 1)] >>> sqlContext.createDataFrame(l).collect() @@ -359,18 +363,15 @@ def registerDataFrameAsTable(self, df, tableName): else: raise ValueError("Can only register DataFrame as table") - @since(1.0) def parquetFile(self, *paths): """Loads a Parquet file, returning the result as a :class:`DataFrame`. - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlContext.parquetFile(parquetFile) - >>> sorted(df.collect()) == sorted(df2.collect()) - True + .. note:: Deprecated in 1.4, use :func:`DataFrameReader.parquet` instead. + + >>> sqlContext.parquetFile('python/test_support/sql/parquet_partitioned').dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] """ + warnings.warn("parquetFile is deprecated. Use read.parquet() instead.") gateway = self._sc._gateway jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths)) for i in range(0, len(paths)): @@ -378,39 +379,15 @@ def parquetFile(self, *paths): jdf = self._ssql_ctx.parquetFile(jpaths) return DataFrame(jdf, self) - @since(1.0) def jsonFile(self, path, schema=None, samplingRatio=1.0): """Loads a text file storing one JSON object per line as a :class:`DataFrame`. - If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema. - - >>> import tempfile, shutil - >>> jsonFile = tempfile.mkdtemp() - >>> shutil.rmtree(jsonFile) - >>> with open(jsonFile, 'w') as f: - ... f.writelines(jsonStrings) - >>> df1 = sqlContext.jsonFile(jsonFile) - >>> df1.printSchema() - root - |-- field1: long (nullable = true) - |-- field2: string (nullable = true) - |-- field3: struct (nullable = true) - | |-- field4: long (nullable = true) + .. note:: Deprecated in 1.4, use :func:`DataFrameReader.json` instead. - >>> from pyspark.sql.types import * - >>> schema = StructType([ - ... StructField("field2", StringType()), - ... StructField("field3", - ... StructType([StructField("field5", ArrayType(IntegerType()))]))]) - >>> df2 = sqlContext.jsonFile(jsonFile, schema) - >>> df2.printSchema() - root - |-- field2: string (nullable = true) - |-- field3: struct (nullable = true) - | |-- field5: array (nullable = true) - | | |-- element: integer (containsNull = true) + >>> sqlContext.jsonFile('python/test_support/sql/people.json').dtypes + [('age', 'bigint'), ('name', 'string')] """ + warnings.warn("jsonFile is deprecated. Use read.json() instead.") if schema is None: df = self._ssql_ctx.jsonFile(path, samplingRatio) else: @@ -462,21 +439,16 @@ def func(iterator): df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) return DataFrame(df, self) - @since(1.3) def load(self, path=None, source=None, schema=None, **options): """Returns the dataset in a data source as a :class:`DataFrame`. - The data source is specified by the ``source`` and a set of ``options``. - If ``source`` is not specified, the default data source configured by - ``spark.sql.sources.default`` will be used. - - Optionally, a schema can be provided as the schema of the returned DataFrame. + .. note:: Deprecated in 1.4, use :func:`DataFrameReader.load` instead. """ + warnings.warn("load is deprecated. Use read.load() instead.") return self.read.load(path, source, schema, **options) @since(1.3) - def createExternalTable(self, tableName, path=None, source=None, - schema=None, **options): + def createExternalTable(self, tableName, path=None, source=None, schema=None, **options): """Creates an external table based on the dataset in a data source. It returns the DataFrame associated with the external table. @@ -487,6 +459,8 @@ def createExternalTable(self, tableName, path=None, source=None, Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and created external table. + + :return: :class:`DataFrame` """ if path is not None: options["path"] = path @@ -508,6 +482,8 @@ def createExternalTable(self, tableName, path=None, source=None, def sql(self, sqlQuery): """Returns a :class:`DataFrame` representing the result of the given query. + :return: :class:`DataFrame` + >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1") >>> df2.collect() @@ -519,6 +495,8 @@ def sql(self, sqlQuery): def table(self, tableName): """Returns the specified table as a :class:`DataFrame`. + :return: :class:`DataFrame` + >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.table("table1") >>> sorted(df.collect()) == sorted(df2.collect()) @@ -536,6 +514,9 @@ def tables(self, dbName=None): The returned DataFrame has two columns: ``tableName`` and ``isTemporary`` (a column with :class:`BooleanType` indicating if a table is a temporary one or not). + :param dbName: string, name of the database to use. + :return: :class:`DataFrame` + >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.tables() >>> df2.filter("tableName = 'table1'").first() @@ -550,7 +531,8 @@ def tables(self, dbName=None): def tableNames(self, dbName=None): """Returns a list of names of tables in the database ``dbName``. - If ``dbName`` is not specified, the current database will be used. + :param dbName: string, name of the database to use. Default to the current database. + :return: list of table names, in string >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> "table1" in sqlContext.tableNames() @@ -585,8 +567,7 @@ def read(self): Returns a :class:`DataFrameReader` that can be used to read data in as a :class:`DataFrame`. - >>> sqlContext.read - + :return: :class:`DataFrameReader` """ return DataFrameReader(self) @@ -644,10 +625,14 @@ def register(self, name, f, returnType=StringType()): def _test(): + import os import doctest from pyspark.context import SparkContext from pyspark.sql import Row, SQLContext import pyspark.sql.context + + os.chdir(os.environ["SPARK_HOME"]) + globs = pyspark.sql.context.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index a82b6b87c413e..7673153abe0e2 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -44,7 +44,7 @@ class DataFrame(object): A :class:`DataFrame` is equivalent to a relational table in Spark SQL, and can be created using various functions in :class:`SQLContext`:: - people = sqlContext.parquetFile("...") + people = sqlContext.read.parquet("...") Once created, it can be manipulated using the various domain-specific-language (DSL) functions defined in: :class:`DataFrame`, :class:`Column`. @@ -56,8 +56,8 @@ class DataFrame(object): A more concrete example:: # To create DataFrame using SQLContext - people = sqlContext.parquetFile("...") - department = sqlContext.parquetFile("...") + people = sqlContext.read.parquet("...") + department = sqlContext.read.parquet("...") people.filter(people.age > 30).join(department, people.deptId == department.id)) \ .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) @@ -120,21 +120,12 @@ def toJSON(self, use_unicode=True): rdd = self._jdf.toJSON() return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) - @since(1.3) def saveAsParquetFile(self, path): """Saves the contents as a Parquet file, preserving the schema. - Files that are written out using this method can be read back in as - a :class:`DataFrame` using :func:`SQLContext.parquetFile`. - - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlContext.parquetFile(parquetFile) - >>> sorted(df2.collect()) == sorted(df.collect()) - True + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.parquet` instead. """ + warnings.warn("saveAsParquetFile is deprecated. Use write.parquet() instead.") self._jdf.saveAsParquetFile(path) @since(1.3) @@ -151,69 +142,45 @@ def registerTempTable(self, name): """ self._jdf.registerTempTable(name) - @since(1.3) def registerAsTable(self, name): - """DEPRECATED: use :func:`registerTempTable` instead""" - warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) + """ + .. note:: Deprecated in 1.4, use :func:`registerTempTable` instead. + """ + warnings.warn("Use registerTempTable instead of registerAsTable.") self.registerTempTable(name) - @since(1.3) def insertInto(self, tableName, overwrite=False): """Inserts the contents of this :class:`DataFrame` into the specified table. - Optionally overwriting any existing data. + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.insertInto` instead. """ + warnings.warn("insertInto is deprecated. Use write.insertInto() instead.") self.write.insertInto(tableName, overwrite) - @since(1.3) def saveAsTable(self, tableName, source=None, mode="error", **options): """Saves the contents of this :class:`DataFrame` to a data source as a table. - The data source is specified by the ``source`` and a set of ``options``. - If ``source`` is not specified, the default data source configured by - ``spark.sql.sources.default`` will be used. - - Additionally, mode is used to specify the behavior of the saveAsTable operation when - table already exists in the data source. There are four modes: - - * `append`: Append contents of this :class:`DataFrame` to existing data. - * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. - * `ignore`: Silently ignore this operation if data already exists. + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.saveAsTable` instead. """ + warnings.warn("insertInto is deprecated. Use write.saveAsTable() instead.") self.write.saveAsTable(tableName, source, mode, **options) @since(1.3) def save(self, path=None, source=None, mode="error", **options): """Saves the contents of the :class:`DataFrame` to a data source. - The data source is specified by the ``source`` and a set of ``options``. - If ``source`` is not specified, the default data source configured by - ``spark.sql.sources.default`` will be used. - - Additionally, mode is used to specify the behavior of the save operation when - data already exists in the data source. There are four modes: - - * `append`: Append contents of this :class:`DataFrame` to existing data. - * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. - * `ignore`: Silently ignore this operation if data already exists. + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.save` instead. """ + warnings.warn("insertInto is deprecated. Use write.save() instead.") return self.write.save(path, source, mode, **options) @property @since(1.4) def write(self): """ - Interface for saving the content of the :class:`DataFrame` out - into external storage. - - :return :class:`DataFrameWriter` + Interface for saving the content of the :class:`DataFrame` out into external storage. - .. note:: Experimental - - >>> df.write - + :return: :class:`DataFrameWriter` """ return DataFrameWriter(self) @@ -636,6 +603,9 @@ def describe(self, *cols): This include count, mean, stddev, min, and max. If no columns are given, this function computes statistics for all numerical columns. + .. note:: This function is meant for exploratory data analysis, as we make no \ + guarantee about the backward compatibility of the schema of the resulting DataFrame. + >>> df.describe().show() +-------+---+ |summary|age| @@ -653,9 +623,11 @@ def describe(self, *cols): @ignore_unicode_prefix @since(1.3) def head(self, n=None): - """ - Returns the first ``n`` rows as a list of :class:`Row`, - or the first :class:`Row` if ``n`` is ``None.`` + """Returns the first ``n`` rows. + + :param n: int, default 1. Number of rows to return. + :return: If n is greater than 1, return a list of :class:`Row`. + If n is 1, return a single Row. >>> df.head() Row(age=2, name=u'Alice') @@ -1170,8 +1142,8 @@ def freqItems(self, cols, support=None): "http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou". :func:`DataFrame.freqItems` and :func:`DataFrameStatFunctions.freqItems` are aliases. - This function is meant for exploratory data analysis, as we make no guarantee about the - backward compatibility of the schema of the resulting DataFrame. + .. note:: This function is meant for exploratory data analysis, as we make no \ + guarantee about the backward compatibility of the schema of the resulting DataFrame. :param cols: Names of the columns to calculate frequent items for as a list or tuple of strings. diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index d17d87419fe3d..f036644acc961 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -45,18 +45,24 @@ def _df(self, jdf): @since(1.4) def format(self, source): - """ - Specifies the input data source format. + """Specifies the input data source format. + + :param source: string, name of the data source, e.g. 'json', 'parquet'. + + >>> df = sqlContext.read.format('json').load('python/test_support/sql/people.json') + >>> df.dtypes + [('age', 'bigint'), ('name', 'string')] + """ self._jreader = self._jreader.format(source) return self @since(1.4) def schema(self, schema): - """ - Specifies the input schema. Some data sources (e.g. JSON) can - infer the input schema automatically from data. By specifying - the schema here, the underlying data source can skip the schema + """Specifies the input schema. + + Some data sources (e.g. JSON) can infer the input schema automatically from data. + By specifying the schema here, the underlying data source can skip the schema inference step, and thus speed up data loading. :param schema: a StructType object @@ -69,8 +75,7 @@ def schema(self, schema): @since(1.4) def options(self, **options): - """ - Adds input options for the underlying data source. + """Adds input options for the underlying data source. """ for k in options: self._jreader = self._jreader.option(k, options[k]) @@ -84,6 +89,10 @@ def load(self, path=None, format=None, schema=None, **options): :param format: optional string for format of the data source. Default to 'parquet'. :param schema: optional :class:`StructType` for the input schema. :param options: all other string options + + >>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned') + >>> df.dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] """ if format is not None: self.format(format) @@ -107,31 +116,10 @@ def json(self, path, schema=None): :param path: string, path to the JSON dataset. :param schema: an optional :class:`StructType` for the input schema. - >>> import tempfile, shutil - >>> jsonFile = tempfile.mkdtemp() - >>> shutil.rmtree(jsonFile) - >>> with open(jsonFile, 'w') as f: - ... f.writelines(jsonStrings) - >>> df1 = sqlContext.read.json(jsonFile) - >>> df1.printSchema() - root - |-- field1: long (nullable = true) - |-- field2: string (nullable = true) - |-- field3: struct (nullable = true) - | |-- field4: long (nullable = true) - - >>> from pyspark.sql.types import * - >>> schema = StructType([ - ... StructField("field2", StringType()), - ... StructField("field3", - ... StructType([StructField("field5", ArrayType(IntegerType()))]))]) - >>> df2 = sqlContext.read.json(jsonFile, schema) - >>> df2.printSchema() - root - |-- field2: string (nullable = true) - |-- field3: struct (nullable = true) - | |-- field5: array (nullable = true) - | | |-- element: integer (containsNull = true) + >>> df = sqlContext.read.json('python/test_support/sql/people.json') + >>> df.dtypes + [('age', 'bigint'), ('name', 'string')] + """ if schema is not None: self.schema(schema) @@ -141,10 +129,12 @@ def json(self, path, schema=None): def table(self, tableName): """Returns the specified table as a :class:`DataFrame`. - >>> sqlContext.registerDataFrameAsTable(df, "table1") - >>> df2 = sqlContext.read.table("table1") - >>> sorted(df.collect()) == sorted(df2.collect()) - True + :param tableName: string, name of the table. + + >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned') + >>> df.registerTempTable('tmpTable') + >>> sqlContext.read.table('tmpTable').dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] """ return self._df(self._jreader.table(tableName)) @@ -152,13 +142,9 @@ def table(self, tableName): def parquet(self, *path): """Loads a Parquet file, returning the result as a :class:`DataFrame`. - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlContext.read.parquet(parquetFile) - >>> sorted(df.collect()) == sorted(df2.collect()) - True + >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned') + >>> df.dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] """ return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, path))) @@ -221,30 +207,34 @@ def __init__(self, df): @since(1.4) def mode(self, saveMode): - """ - Specifies the behavior when data or table already exists. Options include: + """Specifies the behavior when data or table already exists. + + Options include: * `append`: Append contents of this :class:`DataFrame` to existing data. * `overwrite`: Overwrite existing data. * `error`: Throw an exception if data already exists. * `ignore`: Silently ignore this operation if data already exists. + + >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ self._jwrite = self._jwrite.mode(saveMode) return self @since(1.4) def format(self, source): - """ - Specifies the underlying output data source. Built-in options include - "parquet", "json", etc. + """Specifies the underlying output data source. + + :param source: string, name of the data source, e.g. 'json', 'parquet'. + + >>> df.write.format('json').save(os.path.join(tempfile.mkdtemp(), 'data')) """ self._jwrite = self._jwrite.format(source) return self @since(1.4) def options(self, **options): - """ - Adds output options for the underlying data source. + """Adds output options for the underlying data source. """ for k in options: self._jwrite = self._jwrite.option(k, options[k]) @@ -252,12 +242,14 @@ def options(self, **options): @since(1.4) def partitionBy(self, *cols): - """ - Partitions the output by the given columns on the file system. + """Partitions the output by the given columns on the file system. + If specified, the output is laid out on the file system similar to Hive's partitioning scheme. :param cols: name of columns + + >>> df.write.partitionBy('year', 'month').parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ if len(cols) == 1 and isinstance(cols[0], (list, tuple)): cols = cols[0] @@ -266,25 +258,23 @@ def partitionBy(self, *cols): @since(1.4) def save(self, path=None, format=None, mode="error", **options): - """ - Saves the contents of the :class:`DataFrame` to a data source. + """Saves the contents of the :class:`DataFrame` to a data source. The data source is specified by the ``format`` and a set of ``options``. If ``format`` is not specified, the default data source configured by ``spark.sql.sources.default`` will be used. - Additionally, mode is used to specify the behavior of the save operation when - data already exists in the data source. There are four modes: - - * `append`: Append contents of this :class:`DataFrame` to existing data. - * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. - * `ignore`: Silently ignore this operation if data already exists. - :param path: the path in a Hadoop supported file system :param format: the format used to save - :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. :param options: all other string options + + >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode).options(**options) if format is not None: @@ -296,8 +286,8 @@ def save(self, path=None, format=None, mode="error", **options): @since(1.4) def insertInto(self, tableName, overwrite=False): - """ - Inserts the content of the :class:`DataFrame` to the specified table. + """Inserts the content of the :class:`DataFrame` to the specified table. + It requires that the schema of the class:`DataFrame` is the same as the schema of the table. @@ -307,8 +297,7 @@ def insertInto(self, tableName, overwrite=False): @since(1.4) def saveAsTable(self, name, format=None, mode="error", **options): - """ - Saves the content of the :class:`DataFrame` as the specified table. + """Saves the content of the :class:`DataFrame` as the specified table. In the case the table already exists, behavior of this function depends on the save mode, specified by the `mode` function (default to throwing an exception). @@ -328,67 +317,58 @@ def saveAsTable(self, name, format=None, mode="error", **options): self.mode(mode).options(**options) if format is not None: self.format(format) - return self._jwrite.saveAsTable(name) + self._jwrite.saveAsTable(name) @since(1.4) def json(self, path, mode="error"): - """ - Saves the content of the :class:`DataFrame` in JSON format at the - specified path. + """Saves the content of the :class:`DataFrame` in JSON format at the specified path. - Additionally, mode is used to specify the behavior of the save operation when - data already exists in the data source. There are four modes: + :param path: the path in any Hadoop supported file system + :param mode: specifies the behavior of the save operation when data already exists. - * `append`: Append contents of this :class:`DataFrame` to existing data. - * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. - * `ignore`: Silently ignore this operation if data already exists. + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. - :param path: the path in any Hadoop supported file system - :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) + >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ - return self._jwrite.mode(mode).json(path) + self._jwrite.mode(mode).json(path) @since(1.4) def parquet(self, path, mode="error"): - """ - Saves the content of the :class:`DataFrame` in Parquet format at the - specified path. + """Saves the content of the :class:`DataFrame` in Parquet format at the specified path. - Additionally, mode is used to specify the behavior of the save operation when - data already exists in the data source. There are four modes: + :param path: the path in any Hadoop supported file system + :param mode: specifies the behavior of the save operation when data already exists. - * `append`: Append contents of this :class:`DataFrame` to existing data. - * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. - * `ignore`: Silently ignore this operation if data already exists. + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. - :param path: the path in any Hadoop supported file system - :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) + >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ - return self._jwrite.mode(mode).parquet(path) + self._jwrite.mode(mode).parquet(path) @since(1.4) def jdbc(self, url, table, mode="error", properties={}): - """ - Saves the content of the :class:`DataFrame` to a external database table - via JDBC. - - In the case the table already exists in the external database, - behavior of this function depends on the save mode, specified by the `mode` - function (default to throwing an exception). There are four modes: + """Saves the content of the :class:`DataFrame` to a external database table via JDBC. - * `append`: Append contents of this :class:`DataFrame` to existing data. - * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. - * `ignore`: Silently ignore this operation if data already exists. + .. note:: Don't create too many partitions in parallel on a large cluster;\ + otherwise Spark might crash your external database systems. - :param url: a JDBC URL of the form `jdbc:subprotocol:subname` + :param url: a JDBC URL of the form ``jdbc:subprotocol:subname`` :param table: Name of the table in the external database. - :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. :param properties: JDBC database connection arguments, a list of - arbitrary string tag/value. Normally at least a - "user" and "password" property should be included. + arbitrary string tag/value. Normally at least a + "user" and "password" property should be included. """ jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)() for k in properties: @@ -398,24 +378,23 @@ def jdbc(self, url, table, mode="error", properties={}): def _test(): import doctest + import os + import tempfile from pyspark.context import SparkContext from pyspark.sql import Row, SQLContext import pyspark.sql.readwriter + + os.chdir(os.environ["SPARK_HOME"]) + globs = pyspark.sql.readwriter.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') + + globs['tempfile'] = tempfile + globs['os'] = os globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) - globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ - .toDF(StructType([StructField('age', IntegerType()), - StructField('name', StringType())])) - jsonStrings = [ - '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', - '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' - '"field6":[{"field7": "row2"}]}', - '{"field1" : null, "field2": "row3", ' - '"field3":{"field4":33, "field5": []}}' - ] - globs['jsonStrings'] = jsonStrings + globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned') + (failure_count, test_count) = doctest.testmod( pyspark.sql.readwriter, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 76384d31f1bf4..6e498f0af0af5 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -753,8 +753,10 @@ def setUpClass(cls): try: cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() except py4j.protocol.Py4JError: + cls.tearDownClass() raise unittest.SkipTest("Hive is not available") except TypeError: + cls.tearDownClass() raise unittest.SkipTest("Hive is not available") os.unlink(cls.tempdir.name) _scala_HiveContext =\ diff --git a/python/test_support/sql/parquet_partitioned/_SUCCESS b/python/test_support/sql/parquet_partitioned/_SUCCESS new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/python/test_support/sql/parquet_partitioned/_common_metadata b/python/test_support/sql/parquet_partitioned/_common_metadata new file mode 100644 index 0000000000000000000000000000000000000000..7ef2320651dee5f2a4d588ecf8d0b281a2bc3018 GIT binary patch literal 210 zcmYk1!3u&v5QYcw=+#i_AOju(TauuIw{9J!W6@#L&7^f#ch@4s*Xz03qM*|Z%=dr8 zpKo@l?}W+LRZ<$?0pE+Az!kJ%F~9^uFPsH)sVYKST3i^>Emc>dJ5KD<^~?|@@1$Xd zmekN-KcIQE3^UY5^@YI%&o$$v#_TZQTWe3Bk^F(Rs4OUY&gqF;!bVwwKPhIzI37m` brr(!~MnyNKbS*`ck~LYXVg*kC$ZeY!e^o%_ literal 0 HcmV?d00001 diff --git a/python/test_support/sql/parquet_partitioned/_metadata b/python/test_support/sql/parquet_partitioned/_metadata new file mode 100644 index 0000000000000000000000000000000000000000..78a1ca7d38279147e00a826e627f7514d988bd8e GIT binary patch literal 743 zcmb7?&riZI6vyjCA%^RgE^;7EFfke3h76U+ftwc+!pZpKP`3&TWgA`5Oyuq#;D6^o zW7!Xai3#`)eXs57`{ea~hy9VQD!Or7;$bLM1*p}A0!smz(FOq8iT~h>;mEB2-`{-EoU3j+6jrYuY)zEJn-EKp==XmwCF#y_WraHO@felu$%|` z(K_3`IXh{d_L=r}G$4ZdFmoBn%lg_3s`$k}26efUv-!gz5!`pDu$%|Kx;hW}7?X&& z6N+Ow_$iL(tWW^v;TxV&K|CS|yk8=bL=<&VEcn6|$UrYXWnPTB4@g~45h?>03XD` A@&Et; literal 0 HcmV?d00001 diff --git a/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc new file mode 100644 index 0000000000000000000000000000000000000000..e93f42ed6f35018df1ae76c9746b26bd299c0339 GIT binary patch literal 12 TcmYc;N@ieSU}9KgSR(=e5vKy4 literal 0 HcmV?d00001 diff --git a/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet new file mode 100644 index 0000000000000000000000000000000000000000..461c382937ecdf13db1003bd2bf8c9a79c510d3e GIT binary patch literal 322 zcmWG=3^EjD5S0?O(-CC?GT1~pWF(j!b27n%7y}Tga9#tj@mb}E=R8}Ei*MIrC7--x>^b35TF#8(m_&~nU@Y! zm{*#UlbDnPQ~}hQs-pxmRLQEkwl=nwK|&g8rEYGKLRo52ab|v=f}x(7owc~qGsSn?8cQ&kxl#F!*yBxTe%WJGx+ zP1zVYBq1`QEMiPz1!7Ye)i`Y6w!s--Yk|^C43aVun)yZPdWi*z$r-77#RZ8)*?Pr= zIeI`wVQFfKUQvFzUT$hhVoG93qC`}+Qb}b&s*;sbaY<2Wa*2|TQd(wePD-(oRdlry z$VEUYFr|Z}Ff%V5s4%ZICnqr}2dDz5HC0CmW~h=?b!}~IErWzK)JomlB89TlqT4=Dm5m#e6wK014Xam4qu!d8N{n7iq=X0 zb^h^Qf1OXk&E@BCXbc2#aBV9oHG%*Q#?Z5KmhmwFF2p|eCytK>__PNc{No_og>K=# zSrg}?YwN_m*4PkW-1E*!d)FUmof*P@{xTZ=z(~N7DFwMN%hUmKBBqXI) zRjf%s)+rZBNy58^>@G6ao`QeDG~bwDUJ1cg!X(Tn56ItA5;kpn-vaO8xAG`cqbIJ) UROX`@J)_4eJ^_{mz{0%r8w^idvH$=8 literal 0 HcmV?d00001 diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc new file mode 100644 index 0000000000000000000000000000000000000000..ae94a15d08c816a6fb60117395462aa640da5026 GIT binary patch literal 12 TcmYc;N@ieSU}AV;?~?@p63+t^ literal 0 HcmV?d00001 diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet new file mode 100644 index 0000000000000000000000000000000000000000..6cb8538aa890475f9128c8025e2e9b0f72b1b785 GIT binary patch literal 333 zcmZut%SyvQ6di}O1}V6jFoVICffR%|SVR}Wjjr4X#ib%-noMhO^5|ruNXTAs=gv>? zC;T(PTdm;2vpI)*&V6vFrOxdZ`77WuvSx<%7tTm8rCnUbWmlR*FZwwx&re5BWS( zI<0wh-SX8nV0}~gCzurr2o{aja;6~xtt#ZdLwVG8-A#w+&U)p3ZbtXY)LB`KCgNBe NnB)+B!ULx8$S(}ES498- literal 0 HcmV?d00001 diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc new file mode 100644 index 0000000000000000000000000000000000000000..58d9bb5fc5883f0b65b637a8deea2672b9fee8a2 GIT binary patch literal 12 TcmYc;N@ieSU}C7hE>;2n5}^Yd literal 0 HcmV?d00001 diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet new file mode 100644 index 0000000000000000000000000000000000000000..9b00805481e7b311763d48ff965966b4d09daa99 GIT binary patch literal 343 zcmYjN%SyvQ6ulWjn?*Mw6EYaE45TQ;!6Lc{g1Asz2A7JEX)>*$d37>TB;*6cow)In z`~?5Ok8oNwcsA$2IrpB+4bQKq7%;_`K1Ny$u;n_#kSm$S%U;-^vHN1JNh6*`Q8Z76 zug0@?@&54%+h==UTiU>g_*bSZON9~Ok%t_!;JNSsY(!k*PAnIX$ngLy^5bCBMs{Vt z858TYZ|lXTR@(@O>+F|u!Fa{vd%^08%O$H<8Pj6b2*qUi$a0~0!WDOJTB@EZK?7PV z*~E(abe@VVscCTA()C5!+K~S*m=+5iESfCivrH%SsPO6EQW~^fch`Zl^ILh4%khJd Vby^nVDLY|@GCl&s00{L Date: Wed, 3 Jun 2015 00:47:52 -0700 Subject: [PATCH 16/18] [SPARK-7562][SPARK-6444][SQL] Improve error reporting for expression data type mismatch It seems hard to find a common pattern of checking types in `Expression`. Sometimes we know what input types we need(like `And`, we know we need two booleans), sometimes we just have some rules(like `Add`, we need 2 numeric types which are equal). So I defined a general interface `checkInputDataTypes` in `Expression` which returns a `TypeCheckResult`. `TypeCheckResult` can tell whether this expression passes the type checking or what the type mismatch is. This PR mainly works on apply input types checking for arithmetic and predicate expressions. TODO: apply type checking interface to more expressions. Author: Wenchen Fan Closes #6405 from cloud-fan/6444 and squashes the following commits: b5ff31b [Wenchen Fan] address comments b917275 [Wenchen Fan] rebase 39929d9 [Wenchen Fan] add todo 0808fd2 [Wenchen Fan] make constrcutor of TypeCheckResult private 3bee157 [Wenchen Fan] and decimal type coercion rule for binary comparison 8883025 [Wenchen Fan] apply type check interface to CaseWhen cffb67c [Wenchen Fan] to have resolved call the data type check function 6eaadff [Wenchen Fan] add equal type constraint to EqualTo 3affbd8 [Wenchen Fan] more fixes 654d46a [Wenchen Fan] improve tests e0a3628 [Wenchen Fan] improve error message 1524ff6 [Wenchen Fan] fix style 69ca3fe [Wenchen Fan] add error message and tests c71d02c [Wenchen Fan] fix hive tests 6491721 [Wenchen Fan] use value class TypeCheckResult 7ae76b9 [Wenchen Fan] address comments cb77e4f [Wenchen Fan] Improve error reporting for expression data type mismatch --- .../org/apache/spark/SparkFunSuite.scala | 4 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 12 +- .../catalyst/analysis/HiveTypeCoercion.scala | 132 ++++---- .../catalyst/analysis/TypeCheckResult.scala | 45 +++ .../sql/catalyst/expressions/Expression.scala | 28 +- .../sql/catalyst/expressions/arithmetic.scala | 308 +++++++----------- .../expressions/mathfuncs/binary.scala | 17 +- .../sql/catalyst/expressions/predicates.scala | 226 ++++++------- .../sql/catalyst/optimizer/Optimizer.scala | 4 + .../spark/sql/catalyst/util/DateUtils.scala | 2 +- .../spark/sql/catalyst/util/TypeUtils.scala | 56 ++++ .../org/apache/spark/sql/types/DataType.scala | 2 +- .../analysis/DecimalPrecisionSuite.scala | 6 +- .../analysis/HiveTypeCoercionSuite.scala | 15 +- .../ExpressionTypeCheckingSuite.scala | 143 ++++++++ .../apache/spark/sql/json/InferSchema.scala | 2 +- .../org/apache/spark/sql/json/JsonRDD.scala | 2 +- 17 files changed, 583 insertions(+), 421 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 8cb344332668f..9be9db01c7de9 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -30,8 +30,8 @@ private[spark] abstract class SparkFunSuite extends FunSuite with Logging { * Log the suite name and the test name before and after each test. * * Subclasses should never override this method. If they wish to run - * custom code before and after each test, they should should mix in - * the {{org.scalatest.BeforeAndAfter}} trait instead. + * custom code before and after each test, they should mix in the + * {{org.scalatest.BeforeAndAfter}} trait instead. */ final protected override def withFixture(test: NoArgTest): Outcome = { val testName = test.text diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 193dc6b6546b5..c0695ae369421 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -62,15 +62,17 @@ trait CheckAnalysis { val from = operator.inputSet.map(_.name).mkString(", ") a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") + case e: Expression if e.checkInputDataTypes().isFailure => + e.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(message) => + e.failAnalysis( + s"cannot resolve '${e.prettyString}' due to data type mismatch: $message") + } + case c: Cast if !c.resolved => failAnalysis( s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") - case b: BinaryExpression if !b.resolved => - failAnalysis( - s"invalid expression ${b.prettyString} " + - s"between ${b.left.dataType.simpleString} and ${b.right.dataType.simpleString}") - case WindowExpression(UnresolvedWindowFunction(name, _), _) => failAnalysis( s"Could not resolve window function '$name'. " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index edcc918bfe921..b064600e94fac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -41,7 +41,7 @@ object HiveTypeCoercion { * with primitive types, because in that case the precision and scale of the result depends on * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]]. */ - val findTightestCommonType: (DataType, DataType) => Option[DataType] = { + val findTightestCommonTypeOfTwo: (DataType, DataType) => Option[DataType] = { case (t1, t2) if t1 == t2 => Some(t1) case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) @@ -57,6 +57,17 @@ object HiveTypeCoercion { case _ => None } + + /** + * Find the tightest common type of a set of types by continuously applying + * `findTightestCommonTypeOfTwo` on these types. + */ + private def findTightestCommonType(types: Seq[DataType]) = { + types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { + case None => None + case Some(d) => findTightestCommonTypeOfTwo(d, c) + }) + } } /** @@ -180,7 +191,7 @@ trait HiveTypeCoercion { case (l, r) if l.dataType != r.dataType => logDebug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") - findTightestCommonType(l.dataType, r.dataType).map { widestType => + findTightestCommonTypeOfTwo(l.dataType, r.dataType).map { widestType => val newLeft = if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)() val newRight = @@ -217,7 +228,7 @@ trait HiveTypeCoercion { case e if !e.childrenResolved => e case b: BinaryExpression if b.left.dataType != b.right.dataType => - findTightestCommonType(b.left.dataType, b.right.dataType).map { widestType => + findTightestCommonTypeOfTwo(b.left.dataType, b.right.dataType).map { widestType => val newLeft = if (b.left.dataType == widestType) b.left else Cast(b.left, widestType) val newRight = @@ -441,21 +452,18 @@ trait HiveTypeCoercion { DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) ) - case LessThan(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case GreaterThan(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + // When we compare 2 decimal types with different precisions, cast them to the smallest + // common precision. + case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + val resultType = DecimalType(max(p1, p2), max(s1, s2)) + b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType))) + case b @ BinaryComparison(e1 @ DecimalType.Fixed(_, _), e2) + if e2.dataType == DecimalType.Unlimited => + b.makeCopy(Array(Cast(e1, DecimalType.Unlimited), e2)) + case b @ BinaryComparison(e1, e2 @ DecimalType.Fixed(_, _)) + if e1.dataType == DecimalType.Unlimited => + b.makeCopy(Array(e1, Cast(e2, DecimalType.Unlimited))) // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles @@ -570,7 +578,7 @@ trait HiveTypeCoercion { case a @ CreateArray(children) if !a.resolved => val commonType = a.childTypes.reduce( - (a, b) => findTightestCommonType(a, b).getOrElse(StringType)) + (a, b) => findTightestCommonTypeOfTwo(a, b).getOrElse(StringType)) CreateArray( children.map(c => if (c.dataType == commonType) c else Cast(c, commonType))) @@ -599,14 +607,9 @@ trait HiveTypeCoercion { // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. case Coalesce(es) if es.map(_.dataType).distinct.size > 1 => - val dt: Option[DataType] = Some(NullType) val types = es.map(_.dataType) - val rt = types.foldLeft(dt)((r, c) => r match { - case None => None - case Some(d) => findTightestCommonType(d, c) - }) - rt match { - case Some(finaldt) => Coalesce(es.map(Cast(_, finaldt))) + findTightestCommonType(types) match { + case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) case None => sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}") } @@ -619,17 +622,13 @@ trait HiveTypeCoercion { */ object Division extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - // Skip nodes who's children have not been resolved yet. - case e if !e.childrenResolved => e + // Skip nodes who has not been resolved yet, + // as this is an extra rule which should be applied at last. + case e if !e.resolved => e // Decimal and Double remain the same - case d: Divide if d.resolved && d.dataType == DoubleType => d - case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d - - case Divide(l, r) if l.dataType.isInstanceOf[DecimalType] => - Divide(l, Cast(r, DecimalType.Unlimited)) - case Divide(l, r) if r.dataType.isInstanceOf[DecimalType] => - Divide(Cast(l, DecimalType.Unlimited), r) + case d: Divide if d.dataType == DoubleType => d + case d: Divide if d.dataType.isInstanceOf[DecimalType] => d case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType)) } @@ -642,42 +641,33 @@ trait HiveTypeCoercion { import HiveTypeCoercion._ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case cw: CaseWhenLike if cw.childrenResolved && !cw.valueTypesEqual => - logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}") - val commonType = cw.valueTypes.reduce { (v1, v2) => - findTightestCommonType(v1, v2).getOrElse(sys.error( - s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2")) - } - val transformedBranches = cw.branches.sliding(2, 2).map { - case Seq(when, value) if value.dataType != commonType => - Seq(when, Cast(value, commonType)) - case Seq(elseVal) if elseVal.dataType != commonType => - Seq(Cast(elseVal, commonType)) - case s => s - }.reduce(_ ++ _) - cw match { - case _: CaseWhen => - CaseWhen(transformedBranches) - case CaseKeyWhen(key, _) => - CaseKeyWhen(key, transformedBranches) - } - - case ckw: CaseKeyWhen if ckw.childrenResolved && !ckw.resolved => - val commonType = (ckw.key +: ckw.whenList).map(_.dataType).reduce { (v1, v2) => - findTightestCommonType(v1, v2).getOrElse(sys.error( - s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2")) - } - val transformedBranches = ckw.branches.sliding(2, 2).map { - case Seq(when, then) if when.dataType != commonType => - Seq(Cast(when, commonType), then) - case s => s - }.reduce(_ ++ _) - val transformedKey = if (ckw.key.dataType != commonType) { - Cast(ckw.key, commonType) - } else { - ckw.key - } - CaseKeyWhen(transformedKey, transformedBranches) + case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual => + logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}") + val maybeCommonType = findTightestCommonType(c.valueTypes) + maybeCommonType.map { commonType => + val castedBranches = c.branches.grouped(2).map { + case Seq(when, value) if value.dataType != commonType => + Seq(when, Cast(value, commonType)) + case Seq(elseVal) if elseVal.dataType != commonType => + Seq(Cast(elseVal, commonType)) + case other => other + }.reduce(_ ++ _) + c match { + case _: CaseWhen => CaseWhen(castedBranches) + case CaseKeyWhen(key, _) => CaseKeyWhen(key, castedBranches) + } + }.getOrElse(c) + + case c: CaseKeyWhen if c.childrenResolved && !c.resolved => + val maybeCommonType = findTightestCommonType((c.key +: c.whenList).map(_.dataType)) + maybeCommonType.map { commonType => + val castedBranches = c.branches.grouped(2).map { + case Seq(when, then) if when.dataType != commonType => + Seq(Cast(when, commonType), then) + case other => other + }.reduce(_ ++ _) + CaseKeyWhen(Cast(c.key, commonType), castedBranches) + }.getOrElse(c) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala new file mode 100644 index 0000000000000..79c3528a522d3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +/** + * Represents the result of `Expression.checkInputDataTypes`. + * We will throw `AnalysisException` in `CheckAnalysis` if `isFailure` is true. + */ +trait TypeCheckResult { + def isFailure: Boolean = !isSuccess + def isSuccess: Boolean +} + +object TypeCheckResult { + + /** + * Represents the successful result of `Expression.checkInputDataTypes`. + */ + object TypeCheckSuccess extends TypeCheckResult { + def isSuccess: Boolean = true + } + + /** + * Represents the failing result of `Expression.checkInputDataTypes`, + * with a error message to show the reason of failure. + */ + case class TypeCheckFailure(message: String) extends TypeCheckResult { + def isSuccess: Boolean = false + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index adc6505d69cdf..3cf851aec15ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ @@ -53,11 +53,12 @@ abstract class Expression extends TreeNode[Expression] { /** * Returns `true` if this expression and all its children have been resolved to a specific schema - * and `false` if it still contains any unresolved placeholders. Implementations of expressions - * should override this if the resolution of this type of expression involves more than just - * the resolution of its children. + * and input data types checking passed, and `false` if it still contains any unresolved + * placeholders or has data types mismatch. + * Implementations of expressions should override this if the resolution of this type of + * expression involves more than just the resolution of its children and type checking. */ - lazy val resolved: Boolean = childrenResolved + lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess /** * Returns the [[DataType]] of the result of evaluating this expression. It is @@ -94,12 +95,21 @@ abstract class Expression extends TreeNode[Expression] { case (i1, i2) => i1 == i2 } } + + /** + * Checks the input data types, returns `TypeCheckResult.success` if it's valid, + * or returns a `TypeCheckResult` with an error message if invalid. + * Note: it's not valid to call this method until `childrenResolved == true` + * TODO: we should remove the default implementation and implement it for all + * expressions with proper error message. + */ + def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess } abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { self: Product => - def symbol: String + def symbol: String = sys.error(s"BinaryExpressions must override either toString or symbol") override def foldable: Boolean = left.foldable && right.foldable @@ -133,7 +143,13 @@ case class GroupExpression(children: Seq[Expression]) extends Expression { * so that the proper type conversions can be performed in the analyzer. */ trait ExpectsInputTypes { + self: Expression => def expectedChildTypes: Seq[DataType] + override def checkInputDataTypes(): TypeCheckResult = { + // We will always do type casting for `ExpectsInputTypes` in `HiveTypeCoercion`, + // so type mismatch error won't be reported here, but for underling `Cast`s. + TypeCheckResult.TypeCheckSuccess + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index f2299d5db6e9f..2ac53f8f6613f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -17,72 +17,89 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -case class UnaryMinus(child: Expression) extends UnaryExpression { +abstract class UnaryArithmetic extends UnaryExpression { + self: Product => - override def dataType: DataType = child.dataType override def foldable: Boolean = child.foldable override def nullable: Boolean = child.nullable - override def toString: String = s"-$child" - - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } + override def dataType: DataType = child.dataType override def eval(input: Row): Any = { val evalE = child.eval(input) if (evalE == null) { null } else { - numeric.negate(evalE) + evalInternal(evalE) } } + + protected def evalInternal(evalE: Any): Any = + sys.error(s"UnaryArithmetics must override either eval or evalInternal") } -case class Sqrt(child: Expression) extends UnaryExpression { +case class UnaryMinus(child: Expression) extends UnaryArithmetic { + override def toString: String = s"-$child" + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "operator -") + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE: Any) = numeric.negate(evalE) +} + +case class Sqrt(child: Expression) extends UnaryArithmetic { override def dataType: DataType = DoubleType - override def foldable: Boolean = child.foldable override def nullable: Boolean = true override def toString: String = s"SQRT($child)" - lazy val numeric = child.dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support non-negative numeric operations") - } + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function sqrt") - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - val value = numeric.toDouble(evalE) - if (value < 0) null - else math.sqrt(value) - } + private lazy val numeric = TypeUtils.getNumeric(child.dataType) + + protected override def evalInternal(evalE: Any) = { + val value = numeric.toDouble(evalE) + if (value < 0) null + else math.sqrt(value) } } +/** + * A function that get the absolute value of the numeric value. + */ +case class Abs(child: Expression) extends UnaryArithmetic { + override def toString: String = s"Abs($child)" + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function abs") + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE: Any) = numeric.abs(evalE) +} + abstract class BinaryArithmetic extends BinaryExpression { self: Product => - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType && - !DecimalType.isFixed(left.dataType) - - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + override def dataType: DataType = left.dataType + + override def checkInputDataTypes(): TypeCheckResult = { + if (left.dataType != right.dataType) { + TypeCheckResult.TypeCheckFailure( + s"differing types in ${this.getClass.getSimpleName} " + + s"(${left.dataType} and ${right.dataType}).") + } else { + checkTypesInternal(dataType) } - left.dataType } + protected def checkTypesInternal(t: DataType): TypeCheckResult + override def eval(input: Row): Any = { val evalE1 = left.eval(input) if(evalE1 == null) { @@ -97,88 +114,65 @@ abstract class BinaryArithmetic extends BinaryExpression { } } - def evalInternal(evalE1: Any, evalE2: Any): Any = - sys.error(s"BinaryExpressions must either override eval or evalInternal") + protected def evalInternal(evalE1: Any, evalE2: Any): Any = + sys.error(s"BinaryArithmetics must override either eval or evalInternal") } case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - numeric.plus(evalE1, evalE2) - } - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.plus(evalE1, evalE2) } case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "-" - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - numeric.minus(evalE1, evalE2) - } - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.minus(evalE1, evalE2) } case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "*" - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - numeric.times(evalE1, evalE2) - } - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.times(evalE1, evalE2) } case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "/" - override def nullable: Boolean = true - lazy val div: (Any, Any) => Any = dataType match { + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) + + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val div: (Any, Any) => Any = dataType match { case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot - case other => sys.error(s"Type $other does not support numeric operations") } override def eval(input: Row): Any = { @@ -198,13 +192,17 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "%" - override def nullable: Boolean = true - lazy val integral = dataType match { + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) + + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val integral = dataType match { case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]] - case other => sys.error(s"Type $other does not support numeric operations") } override def eval(input: Row): Any = { @@ -228,7 +226,10 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "&" - lazy val and: (Any, Any) => Any = dataType match { + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + + private lazy val and: (Any, Any) => Any = dataType match { case ByteType => ((evalE1: Byte, evalE2: Byte) => (evalE1 & evalE2).toByte).asInstanceOf[(Any, Any) => Any] case ShortType => @@ -237,10 +238,9 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme ((evalE1: Int, evalE2: Int) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] case LongType => ((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] - case other => sys.error(s"Unsupported bitwise & operation on $other") } - override def evalInternal(evalE1: Any, evalE2: Any): Any = and(evalE1, evalE2) + protected override def evalInternal(evalE1: Any, evalE2: Any) = and(evalE1, evalE2) } /** @@ -249,7 +249,10 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "|" - lazy val or: (Any, Any) => Any = dataType match { + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + + private lazy val or: (Any, Any) => Any = dataType match { case ByteType => ((evalE1: Byte, evalE2: Byte) => (evalE1 | evalE2).toByte).asInstanceOf[(Any, Any) => Any] case ShortType => @@ -258,10 +261,9 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet ((evalE1: Int, evalE2: Int) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] case LongType => ((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] - case other => sys.error(s"Unsupported bitwise | operation on $other") } - override def evalInternal(evalE1: Any, evalE2: Any): Any = or(evalE1, evalE2) + protected override def evalInternal(evalE1: Any, evalE2: Any) = or(evalE1, evalE2) } /** @@ -270,7 +272,10 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "^" - lazy val xor: (Any, Any) => Any = dataType match { + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + + private lazy val xor: (Any, Any) => Any = dataType match { case ByteType => ((evalE1: Byte, evalE2: Byte) => (evalE1 ^ evalE2).toByte).asInstanceOf[(Any, Any) => Any] case ShortType => @@ -279,23 +284,21 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme ((evalE1: Int, evalE2: Int) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] case LongType => ((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] - case other => sys.error(s"Unsupported bitwise ^ operation on $other") } - override def evalInternal(evalE1: Any, evalE2: Any): Any = xor(evalE1, evalE2) + protected override def evalInternal(evalE1: Any, evalE2: Any): Any = xor(evalE1, evalE2) } /** * A function that calculates bitwise not(~) of a number. */ -case class BitwiseNot(child: Expression) extends UnaryExpression { - - override def dataType: DataType = child.dataType - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable +case class BitwiseNot(child: Expression) extends UnaryArithmetic { override def toString: String = s"~$child" - lazy val not: (Any) => Any = dataType match { + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForBitwiseExpr(child.dataType, "operator ~") + + private lazy val not: (Any) => Any = dataType match { case ByteType => ((evalE: Byte) => (~evalE).toByte).asInstanceOf[(Any) => Any] case ShortType => @@ -304,43 +307,18 @@ case class BitwiseNot(child: Expression) extends UnaryExpression { ((evalE: Int) => ~evalE).asInstanceOf[(Any) => Any] case LongType => ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] - case other => sys.error(s"Unsupported bitwise ~ operation on $other") } - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - not(evalE) - } - } + protected override def evalInternal(evalE: Any) = not(evalE) } -case class MaxOf(left: Expression, right: Expression) extends Expression { - - override def foldable: Boolean = left.foldable && right.foldable - +case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = left.nullable && right.nullable - override def children: Seq[Expression] = left :: right :: Nil - - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(t, "function maxOf") - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") - } - left.dataType - } - - lazy val ordering = left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } + private lazy val ordering = TypeUtils.getOrdering(dataType) override def eval(input: Row): Any = { val evalE1 = left.eval(input) @@ -361,30 +339,13 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { override def toString: String = s"MaxOf($left, $right)" } -case class MinOf(left: Expression, right: Expression) extends Expression { - - override def foldable: Boolean = left.foldable && right.foldable - +case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = left.nullable && right.nullable - override def children: Seq[Expression] = left :: right :: Nil + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(t, "function minOf") - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType - - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") - } - left.dataType - } - - lazy val ordering = left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } + private lazy val ordering = TypeUtils.getOrdering(dataType) override def eval(input: Row): Any = { val evalE1 = left.eval(input) @@ -404,28 +365,3 @@ case class MinOf(left: Expression, right: Expression) extends Expression { override def toString: String = s"MinOf($left, $right)" } - -/** - * A function that get the absolute value of the numeric value. - */ -case class Abs(child: Expression) extends UnaryExpression { - - override def dataType: DataType = child.dataType - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable - override def toString: String = s"Abs($child)" - - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } - - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - numeric.abs(evalE) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala index 01f62ba0442e9..db853a2b97fad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala @@ -29,17 +29,10 @@ import org.apache.spark.sql.types._ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => - override def symbol: String = null override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) - override def nullable: Boolean = left.nullable || right.nullable override def toString: String = s"$name($left, $right)" - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType && - !DecimalType.isFixed(left.dataType) - override def dataType: DataType = DoubleType override def eval(input: Row): Any = { @@ -58,9 +51,8 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) } } -case class Atan2( - left: Expression, - right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { +case class Atan2(left: Expression, right: Expression) + extends BinaryMathExpression(math.atan2, "ATAN2") { override def eval(input: Row): Any = { val evalE1 = left.eval(input) @@ -80,8 +72,7 @@ case class Atan2( } } -case class Hypot( - left: Expression, - right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") +case class Hypot(left: Expression, right: Expression) + extends BinaryMathExpression(math.hypot, "HYPOT") case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 4f422d69c4382..807021d50e8e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, AtomicType} +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType} object InterpretedPredicate { def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = @@ -171,22 +171,51 @@ case class Or(left: Expression, right: Expression) abstract class BinaryComparison extends BinaryExpression with Predicate { self: Product => -} -case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "=" + override def checkInputDataTypes(): TypeCheckResult = { + if (left.dataType != right.dataType) { + TypeCheckResult.TypeCheckFailure( + s"differing types in ${this.getClass.getSimpleName} " + + s"(${left.dataType} and ${right.dataType}).") + } else { + checkTypesInternal(dataType) + } + } + + protected def checkTypesInternal(t: DataType): TypeCheckResult override def eval(input: Row): Any = { - val l = left.eval(input) - if (l == null) { + val evalE1 = left.eval(input) + if (evalE1 == null) { null } else { - val r = right.eval(input) - if (r == null) null - else if (left.dataType != BinaryType) l == r - else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + evalInternal(evalE1, evalE2) + } } } + + protected def evalInternal(evalE1: Any, evalE2: Any): Any = + sys.error(s"BinaryComparisons must override either eval or evalInternal") +} + +object BinaryComparison { + def unapply(b: BinaryComparison): Option[(Expression, Expression)] = + Some((b.left, b.right)) +} + +case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { + override def symbol: String = "=" + + override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess + + protected override def evalInternal(l: Any, r: Any) = { + if (left.dataType != BinaryType) l == r + else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) + } } case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { @@ -194,6 +223,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp override def nullable: Boolean = false + override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess + override def eval(input: Row): Any = { val l = left.eval(input) val r = right.eval(input) @@ -210,117 +241,45 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp case class LessThan(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = "<" - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } - } + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - ordering.lt(evalE1, evalE2) - } - } - } + private lazy val ordering = TypeUtils.getOrdering(left.dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lt(evalE1, evalE2) } case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = "<=" - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } - } + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - ordering.lteq(evalE1, evalE2) - } - } - } + private lazy val ordering = TypeUtils.getOrdering(left.dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lteq(evalE1, evalE2) } case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = ">" - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } - } + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - ordering.gt(evalE1, evalE2) - } - } - } + private lazy val ordering = TypeUtils.getOrdering(left.dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gt(evalE1, evalE2) } case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = ">=" - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } - } + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - ordering.gteq(evalE1, evalE2) - } - } - } + private lazy val ordering = TypeUtils.getOrdering(left.dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gteq(evalE1, evalE2) } case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) @@ -329,16 +288,20 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil override def nullable: Boolean = trueValue.nullable || falseValue.nullable - override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException( - this, - s"Can not resolve due to differing types ${trueValue.dataType}, ${falseValue.dataType}") + override def checkInputDataTypes(): TypeCheckResult = { + if (predicate.dataType != BooleanType) { + TypeCheckResult.TypeCheckFailure( + s"type of predicate expression in If should be boolean, not ${predicate.dataType}") + } else if (trueValue.dataType != falseValue.dataType) { + TypeCheckResult.TypeCheckFailure( + s"differing types in If (${trueValue.dataType} and ${falseValue.dataType}).") + } else { + TypeCheckResult.TypeCheckSuccess } - trueValue.dataType } + override def dataType: DataType = trueValue.dataType + override def eval(input: Row): Any = { if (true == predicate.eval(input)) { trueValue.eval(input) @@ -364,17 +327,23 @@ trait CaseWhenLike extends Expression { branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq val elseValue = if (branches.length % 2 == 0) None else Option(branches.last) - // both then and else val should be considered. + // both then and else expressions should be considered. def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType) def valueTypesEqual: Boolean = valueTypes.distinct.size == 1 - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, "cannot resolve due to differing types in some branches") + override def checkInputDataTypes(): TypeCheckResult = { + if (valueTypesEqual) { + checkTypesInternal() + } else { + TypeCheckResult.TypeCheckFailure( + "THEN and ELSE expressions should all be same type or coercible to a common type") } - valueTypes.head } + protected def checkTypesInternal(): TypeCheckResult + + override def dataType: DataType = thenList.head.dataType + override def nullable: Boolean = { // If no value is nullable and no elseValue is provided, the whole statement defaults to null. thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) @@ -395,10 +364,16 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { override def children: Seq[Expression] = branches - override lazy val resolved: Boolean = - childrenResolved && - whenList.forall(_.dataType == BooleanType) && - valueTypesEqual + override protected def checkTypesInternal(): TypeCheckResult = { + if (whenList.forall(_.dataType == BooleanType)) { + TypeCheckResult.TypeCheckSuccess + } else { + val index = whenList.indexWhere(_.dataType != BooleanType) + TypeCheckResult.TypeCheckFailure( + s"WHEN expressions in CaseWhen should all be boolean type, " + + s"but the ${index + 1}th when expression's type is ${whenList(index)}") + } + } /** Written in imperative fashion for performance considerations. */ override def eval(input: Row): Any = { @@ -441,9 +416,14 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW override def children: Seq[Expression] = key +: branches - override lazy val resolved: Boolean = - childrenResolved && valueTypesEqual && - (key +: whenList).map(_.dataType).distinct.size == 1 + override protected def checkTypesInternal(): TypeCheckResult = { + if ((key +: whenList).map(_.dataType).distinct.size > 1) { + TypeCheckResult.TypeCheckFailure( + "key and WHEN expressions should all be same type or coercible to a common type") + } else { + TypeCheckResult.TypeCheckSuccess + } + } /** Written in imperative fashion for performance considerations. */ override def eval(input: Row): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b25fb48f55e2b..5c6379b8d44b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -273,6 +273,10 @@ object NullPropagation extends Rule[LogicalPlan] { case e @ Substring(_, Literal(null, _), _) => Literal.create(null, e.dataType) case e @ Substring(_, _, Literal(null, _)) => Literal.create(null, e.dataType) + // MaxOf and MinOf can't do null propagation + case e: MaxOf => e + case e: MinOf => e + // Put exceptional cases above if any case e: BinaryArithmetic => e.children match { case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala index 3f92be4a55d7d..ad649acf536f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala @@ -24,7 +24,7 @@ import java.util.{Calendar, TimeZone} import org.apache.spark.sql.catalyst.expressions.Cast /** - * helper function to convert between Int value of days since 1970-01-01 and java.sql.Date + * Helper function to convert between Int value of days since 1970-01-01 and java.sql.Date */ object DateUtils { private val MILLIS_PER_DAY = 86400000 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala new file mode 100644 index 0000000000000..0bb12d2039ffc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.types._ + +/** + * Helper function to check for valid data types + */ +object TypeUtils { + def checkForNumericExpr(t: DataType, caller: String): TypeCheckResult = { + if (t.isInstanceOf[NumericType] || t == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"$caller accepts numeric types, not $t") + } + } + + def checkForBitwiseExpr(t: DataType, caller: String): TypeCheckResult = { + if (t.isInstanceOf[IntegralType] || t == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"$caller accepts integral types, not $t") + } + } + + def checkForOrderingExpr(t: DataType, caller: String): TypeCheckResult = { + if (t.isInstanceOf[AtomicType] || t == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"$caller accepts non-complex types, not $t") + } + } + + def getNumeric(t: DataType): Numeric[Any] = + t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] + + def getOrdering(t: DataType): Ordering[Any] = + t.asInstanceOf[AtomicType].ordering.asInstanceOf[Ordering[Any]] +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 1ba3a2686639f..74677ddfcad65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -107,7 +107,7 @@ protected[sql] abstract class AtomicType extends DataType { abstract class NumericType extends AtomicType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a - // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets + // type parameter and add a numeric annotation (i.e., [JvmType : Numeric]). This gets // desugared by the compiler into an argument to the objects constructor. This means there is no // longer an no argument constructor and thus the JVM cannot serialize the object anymore. private[sql] val numeric: Numeric[InternalType] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 1b8d18ded2257..7bac97b7894f5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -92,8 +92,10 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { } test("Comparison operations") { - checkComparison(LessThan(i, d1), DecimalType.Unlimited) - checkComparison(LessThanOrEqual(d1, d2), DecimalType.Unlimited) + checkComparison(EqualTo(i, d1), DecimalType(10, 1)) + checkComparison(EqualNullSafe(d2, d1), DecimalType(5, 2)) + checkComparison(LessThan(i, d1), DecimalType(10, 1)) + checkComparison(LessThanOrEqual(d1, d2), DecimalType(5, 2)) checkComparison(GreaterThan(d2, u), DecimalType.Unlimited) checkComparison(GreaterThanOrEqual(d1, f), DoubleType) checkComparison(GreaterThan(d2, d2), DecimalType(5, 2)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index a0798428db094..0df446636ea89 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -28,11 +28,11 @@ class HiveTypeCoercionSuite extends PlanTest { test("tightest common bound for types") { def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) { - var found = HiveTypeCoercion.findTightestCommonType(t1, t2) + var found = HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2) assert(found == tightestCommon, s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found") // Test both directions to make sure the widening is symmetric. - found = HiveTypeCoercion.findTightestCommonType(t2, t1) + found = HiveTypeCoercion.findTightestCommonTypeOfTwo(t2, t1) assert(found == tightestCommon, s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found") } @@ -140,13 +140,10 @@ class HiveTypeCoercionSuite extends PlanTest { CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) - // Will remove exception expectation in PR#6405 - intercept[RuntimeException] { - ruleTest(cwc, - CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), - CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) - ) - } + ruleTest(cwc, + CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), + CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) + ) } test("type coercion simplification for equal to") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala new file mode 100644 index 0000000000000..dcb3635c5ccae --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.types.StringType + +class ExpressionTypeCheckingSuite extends SparkFunSuite { + + val testRelation = LocalRelation( + 'intField.int, + 'stringField.string, + 'booleanField.boolean, + 'complexField.array(StringType)) + + def assertError(expr: Expression, errorMessage: String): Unit = { + val e = intercept[AnalysisException] { + assertSuccess(expr) + } + assert(e.getMessage.contains( + s"cannot resolve '${expr.prettyString}' due to data type mismatch:")) + assert(e.getMessage.contains(errorMessage)) + } + + def assertSuccess(expr: Expression): Unit = { + val analyzed = testRelation.select(expr.as("c")).analyze + SimpleAnalyzer.checkAnalysis(analyzed) + } + + def assertErrorForDifferingTypes(expr: Expression): Unit = { + assertError(expr, + s"differing types in ${expr.getClass.getSimpleName} (IntegerType and BooleanType).") + } + + test("check types for unary arithmetic") { + assertError(UnaryMinus('stringField), "operator - accepts numeric type") + assertSuccess(Sqrt('stringField)) // We will cast String to Double for sqrt + assertError(Sqrt('booleanField), "function sqrt accepts numeric type") + assertError(Abs('stringField), "function abs accepts numeric type") + assertError(BitwiseNot('stringField), "operator ~ accepts integral type") + } + + test("check types for binary arithmetic") { + // We will cast String to Double for binary arithmetic + assertSuccess(Add('intField, 'stringField)) + assertSuccess(Subtract('intField, 'stringField)) + assertSuccess(Multiply('intField, 'stringField)) + assertSuccess(Divide('intField, 'stringField)) + assertSuccess(Remainder('intField, 'stringField)) + // checkAnalysis(BitwiseAnd('intField, 'stringField)) + + assertErrorForDifferingTypes(Add('intField, 'booleanField)) + assertErrorForDifferingTypes(Subtract('intField, 'booleanField)) + assertErrorForDifferingTypes(Multiply('intField, 'booleanField)) + assertErrorForDifferingTypes(Divide('intField, 'booleanField)) + assertErrorForDifferingTypes(Remainder('intField, 'booleanField)) + assertErrorForDifferingTypes(BitwiseAnd('intField, 'booleanField)) + assertErrorForDifferingTypes(BitwiseOr('intField, 'booleanField)) + assertErrorForDifferingTypes(BitwiseXor('intField, 'booleanField)) + assertErrorForDifferingTypes(MaxOf('intField, 'booleanField)) + assertErrorForDifferingTypes(MinOf('intField, 'booleanField)) + + assertError(Add('booleanField, 'booleanField), "operator + accepts numeric type") + assertError(Subtract('booleanField, 'booleanField), "operator - accepts numeric type") + assertError(Multiply('booleanField, 'booleanField), "operator * accepts numeric type") + assertError(Divide('booleanField, 'booleanField), "operator / accepts numeric type") + assertError(Remainder('booleanField, 'booleanField), "operator % accepts numeric type") + + assertError(BitwiseAnd('booleanField, 'booleanField), "operator & accepts integral type") + assertError(BitwiseOr('booleanField, 'booleanField), "operator | accepts integral type") + assertError(BitwiseXor('booleanField, 'booleanField), "operator ^ accepts integral type") + + assertError(MaxOf('complexField, 'complexField), "function maxOf accepts non-complex type") + assertError(MinOf('complexField, 'complexField), "function minOf accepts non-complex type") + } + + test("check types for predicates") { + // We will cast String to Double for binary comparison + assertSuccess(EqualTo('intField, 'stringField)) + assertSuccess(EqualNullSafe('intField, 'stringField)) + assertSuccess(LessThan('intField, 'stringField)) + assertSuccess(LessThanOrEqual('intField, 'stringField)) + assertSuccess(GreaterThan('intField, 'stringField)) + assertSuccess(GreaterThanOrEqual('intField, 'stringField)) + + // We will transform EqualTo with numeric and boolean types to CaseKeyWhen + assertSuccess(EqualTo('intField, 'booleanField)) + assertSuccess(EqualNullSafe('intField, 'booleanField)) + + assertError(EqualTo('intField, 'complexField), "differing types") + assertError(EqualNullSafe('intField, 'complexField), "differing types") + + assertErrorForDifferingTypes(LessThan('intField, 'booleanField)) + assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField)) + assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) + assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) + + assertError( + LessThan('complexField, 'complexField), "operator < accepts non-complex type") + assertError( + LessThanOrEqual('complexField, 'complexField), "operator <= accepts non-complex type") + assertError( + GreaterThan('complexField, 'complexField), "operator > accepts non-complex type") + assertError( + GreaterThanOrEqual('complexField, 'complexField), "operator >= accepts non-complex type") + + assertError( + If('intField, 'stringField, 'stringField), + "type of predicate expression in If should be boolean") + assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField)) + + assertError( + CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'complexField)), + "THEN and ELSE expressions should all be same type or coercible to a common type") + assertError( + CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'complexField)), + "THEN and ELSE expressions should all be same type or coercible to a common type") + assertError( + CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)), + "WHEN expressions in CaseWhen should all be boolean type") + + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala index 06aa19ef09bd2..565d10247f10e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala @@ -147,7 +147,7 @@ private[sql] object InferSchema { * Returns the most general data type for two given data types. */ private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { - HiveTypeCoercion.findTightestCommonType(t1, t2).getOrElse { + HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. (t1, t2) match { case (other: DataType, NullType) => other diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 95eb1174b1dd6..7e1e21f5fbb99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -155,7 +155,7 @@ private[sql] object JsonRDD extends Logging { * Returns the most general data type for two given data types. */ private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { - HiveTypeCoercion.findTightestCommonType(t1, t2) match { + HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2) match { case Some(commonType) => commonType case None => // t1 or t2 is a StructType, ArrayType, or an unexpected type. From 28dbde3874ccdd44b73675938719b69336d23dac Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 3 Jun 2015 13:15:57 +0200 Subject: [PATCH 17/18] [SPARK-7983] [MLLIB] Add require for one-based indices in loadLibSVMFile jira: https://issues.apache.org/jira/browse/SPARK-7983 Customers frequently use zero-based indices in their LIBSVM files. No warnings or errors from Spark will be reported during their computation afterwards, and usually it will lead to wired result for many algorithms (like GBDT). add a quick check. Author: Yuhao Yang Closes #6538 from hhbyyh/loadSVM and squashes the following commits: 79d9c11 [Yuhao Yang] optimization as respond to comments 4310710 [Yuhao Yang] merge conflict 96460f1 [Yuhao Yang] merge conflict 20a2811 [Yuhao Yang] use require 6e4f8ca [Yuhao Yang] add check for ascending order 9956365 [Yuhao Yang] add ut for 0-based loadlibsvm exception 5bd1f9a [Yuhao Yang] add require for one-based in loadLIBSVM --- .../org/apache/spark/mllib/util/MLUtils.scala | 12 +++++++ .../spark/mllib/util/MLUtilsSuite.scala | 35 +++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 541f3288b6c43..52d6468a72af7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -82,6 +82,18 @@ object MLUtils { val value = indexAndValue(1).toDouble (index, value) }.unzip + + // check if indices are one-based and in ascending order + var previous = -1 + var i = 0 + val indicesLength = indices.length + while (i < indicesLength) { + val current = indices(i) + require(current > previous, "indices should be one-based and in ascending order" ) + previous = current + i += 1 + } + (label, indices.toArray, values.toArray) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 734b7babec7be..70219e9ad9d3e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -25,6 +25,7 @@ import breeze.linalg.{squaredDistance => breezeSquaredDistance} import com.google.common.base.Charsets import com.google.common.io.Files +import org.apache.spark.SparkException import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint @@ -108,6 +109,40 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { Utils.deleteRecursively(tempDir) } + test("loadLibSVMFile throws IllegalArgumentException when indices is zero-based") { + val lines = + """ + |0 + |0 0:4.0 4:5.0 6:6.0 + """.stripMargin + val tempDir = Utils.createTempDir() + val file = new File(tempDir.getPath, "part-00000") + Files.write(lines, file, Charsets.US_ASCII) + val path = tempDir.toURI.toString + + intercept[SparkException] { + loadLibSVMFile(sc, path).collect() + } + Utils.deleteRecursively(tempDir) + } + + test("loadLibSVMFile throws IllegalArgumentException when indices is not in ascending order") { + val lines = + """ + |0 + |0 3:4.0 2:5.0 6:6.0 + """.stripMargin + val tempDir = Utils.createTempDir() + val file = new File(tempDir.getPath, "part-00000") + Files.write(lines, file, Charsets.US_ASCII) + val path = tempDir.toURI.toString + + intercept[SparkException] { + loadLibSVMFile(sc, path).collect() + } + Utils.deleteRecursively(tempDir) + } + test("saveAsLibSVMFile") { val examples = sc.parallelize(Seq( LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))), From f1646e1023bd03e27268a8aa2ea11b6cc284075f Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 3 Jun 2015 09:26:21 -0700 Subject: [PATCH 18/18] [SPARK-7973] [SQL] Increase the timeout of two CliSuite tests. https://issues.apache.org/jira/browse/SPARK-7973 Author: Yin Huai Closes #6525 from yhuai/SPARK-7973 and squashes the following commits: 763b821 [Yin Huai] Also change the timeout of "Single command with -e" to 2 minutes. e598a08 [Yin Huai] Increase the timeout to 3 minutes. --- .../org/apache/spark/sql/hive/thriftserver/CliSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 3732af7870b93..13b0c5951dddc 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -133,7 +133,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { } test("Single command with -e") { - runCliWithin(1.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK") + runCliWithin(2.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK") } test("Single command with --database") { @@ -165,7 +165,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { val dataFilePath = Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") - runCliWithin(1.minute, Seq("--jars", s"$jarFile"))( + runCliWithin(3.minute, Seq("--jars", s"$jarFile"))( """CREATE TABLE t1(key string, val string) |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'; """.stripMargin