diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R index 773b6ecf582d9..7189f1a260934 100644 --- a/R/pkg/inst/profile/shell.R +++ b/R/pkg/inst/profile/shell.R @@ -27,7 +27,21 @@ sc <- SparkR::sparkR.init() assign("sc", sc, envir=.GlobalEnv) sqlContext <- SparkR::sparkRSQL.init(sc) + sparkVer <- SparkR:::callJMethod(sc, "version") assign("sqlContext", sqlContext, envir=.GlobalEnv) - cat("\n Welcome to SparkR!") + cat("\n Welcome to") + cat("\n") + cat(" ____ __", "\n") + cat(" / __/__ ___ _____/ /__", "\n") + cat(" _\\ \\/ _ \\/ _ `/ __/ '_/", "\n") + cat(" /___/ .__/\\_,_/_/ /_/\\_\\") + if (nchar(sparkVer) == 0) { + cat("\n") + } else { + cat(" version ", sparkVer, "\n") + } + cat(" /_/", "\n") + cat("\n") + cat("\n Spark context is available as sc, SQL context is available as sqlContext\n") } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index ad7eb04afcd8c..764578b181422 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -139,6 +139,9 @@ public void write(Iterator> records) throws IOException { @Override public void write(scala.collection.Iterator> records) throws IOException { + // Keep track of success so we know if we ecountered an exception + // We do this rather than a standard try/catch/re-throw to handle + // generic throwables. boolean success = false; try { while (records.hasNext()) { @@ -147,8 +150,19 @@ public void write(scala.collection.Iterator> records) throws IOEx closeAndWriteOutput(); success = true; } finally { - if (!success) { - sorter.cleanupAfterError(); + if (sorter != null) { + try { + sorter.cleanupAfterError(); + } catch (Exception e) { + // Only throw this error if we won't be masking another + // error. + if (success) { + throw e; + } else { + logger.error("In addition to a failure during writing, we failed during " + + "cleanup.", e); + } + } } } } diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 83d109115aa5c..10c3eedbf4b46 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -253,6 +253,23 @@ public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException { createWriter(false).stop(false); } + class PandaException extends RuntimeException { + } + + @Test(expected=PandaException.class) + public void writeFailurePropagates() throws Exception { + class BadRecords extends scala.collection.AbstractIterator> { + @Override public boolean hasNext() { + throw new PandaException(); + } + @Override public Product2 next() { + return null; + } + } + final UnsafeShuffleWriter writer = createWriter(true); + writer.write(new BadRecords()); + } + @Test public void writeEmptyIterator() throws Exception { final UnsafeShuffleWriter writer = createWriter(true); diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 26c036f6648da..2786e3d2cd6bf 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -22,7 +22,7 @@ The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark. 
All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`, `pyspark` shell, or `sparkR` shell. -## Starting Point: `SQLContext` +## Starting Point: SQLContext
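The `UnsafeShuffleWriter.write` change earlier in this diff keeps a `success` flag so that an exception thrown while cleaning up the sorter never masks the original write failure. Below is a rough Scala rendering of that pattern; the `Sorter` trait and method names are illustrative stand-ins, not Spark's actual classes.

```scala
import org.slf4j.LoggerFactory

object CleanupPatternSketch {
  private val logger = LoggerFactory.getLogger(getClass)

  // Illustrative stand-in for the shuffle sorter used in the change above.
  trait Sorter {
    def insert(record: Int): Unit
    def cleanupAfterError(): Unit
  }

  def write(records: Iterator[Int], sorter: Sorter): Unit = {
    // Track success so we know whether an exception is already propagating.
    var success = false
    try {
      records.foreach(sorter.insert)
      success = true
    } finally {
      if (sorter != null) {
        try {
          sorter.cleanupAfterError()
        } catch {
          case e: Exception =>
            if (success) {
              // Nothing else is in flight, so the cleanup failure is the real error.
              throw e
            } else {
              // A write failure is already propagating; log instead of masking it.
              logger.error("In addition to a failure during writing, we failed during cleanup.", e)
            }
        }
      }
    }
  }
}
```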
@@ -1036,6 +1036,15 @@ for (teenName in collect(teenNames)) {
+
+ +{% highlight python %} +# sqlContext is an existing HiveContext +sqlContext.sql("REFRESH TABLE my_table") +{% endhighlight %} + +
+
{% highlight sql %} @@ -1054,7 +1063,7 @@ SELECT * FROM parquetTable
-### Partition discovery +### Partition Discovery Table partitioning is a common optimization approach used in systems like Hive. In a partitioned table, data are usually stored in different directories, with partitioning column values encoded in @@ -1108,7 +1117,7 @@ can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, w `true`. When type inference is disabled, string type will be used for the partitioning columns. -### Schema merging +### Schema Merging Like ProtocolBuffer, Avro, and Thrift, Parquet also supports schema evolution. Users can start with a simple schema, and gradually add more columns to the schema as needed. In this way, users may end @@ -1208,6 +1217,79 @@ printSchema(df3)
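As a quick illustration of the partition discovery and schema merging behavior described above, the following spark-shell style sketch (assuming an existing `sc` and `sqlContext`, and placeholder paths) writes two partially overlapping schemas under one table path and reads them back:

```scala
import sqlContext.implicits._

// Two directories under the same table path, each with a partially overlapping
// schema and the partition column encoded in the directory name (key=1, key=2).
sc.makeRDD(1 to 5).map(i => (i, i * 2)).toDF("single", "double")
  .write.parquet("data/test_table/key=1")
sc.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple")
  .write.parquet("data/test_table/key=2")

// Reading the parent directory discovers the partitioning column `key` and merges
// the two file schemas: single, double, triple, key.
val merged = sqlContext.read.parquet("data/test_table")
merged.printSchema()
```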
+### Hive metastore Parquet table conversion + +When reading from and writing to Hive metastore Parquet tables, Spark SQL will try to use its own +Parquet support instead of Hive SerDe for better performance. This behavior is controlled by the +`spark.sql.hive.convertMetastoreParquet` configuration, and is turned on by default. + +#### Hive/Parquet Schema Reconciliation + +There are two key differences between Hive and Parquet from the perspective of table schema +processing. + +1. Hive is case insensitive, while Parquet is not +1. Hive considers all columns nullable, while nullability in Parquet is significant + +For this reason, we must reconcile Hive metastore schema with Parquet schema when converting a +Hive metastore Parquet table to a Spark SQL Parquet table. The reconciliation rules are: + +1. Fields that have the same name in both schemas must have the same data type regardless of + nullability. The reconciled field should have the data type of the Parquet side, so that + nullability is respected. + +1. The reconciled schema contains exactly those fields defined in Hive metastore schema. + + - Any fields that only appear in the Parquet schema are dropped in the reconciled schema. + - Any fields that only appear in the Hive metastore schema are added as nullable fields in the + reconciled schema. + +#### Metadata Refreshing + +Spark SQL caches Parquet metadata for better performance. When Hive metastore Parquet table +conversion is enabled, metadata of those converted tables is also cached. If these tables are +updated by Hive or other external tools, you need to refresh them manually to ensure consistent +metadata. + +
+ +
+ +{% highlight scala %} +// sqlContext is an existing HiveContext +sqlContext.refreshTable("my_table") +{% endhighlight %} + +
+ +
+ +{% highlight java %} +// sqlContext is an existing HiveContext +sqlContext.refreshTable("my_table") +{% endhighlight %} + +
+ +
+ +{% highlight python %} +# sqlContext is an existing HiveContext +sqlContext.refreshTable("my_table") +{% endhighlight %} + +
+ +
+ +{% highlight sql %} +REFRESH TABLE my_table; +{% endhighlight %} + +
+ +
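A rough Scala sketch of the reconciliation rules listed above; this is not Spark's actual implementation, it only restates the rules in code using Spark SQL's schema types.

```scala
import org.apache.spark.sql.types.{StructField, StructType}

def reconcile(hiveSchema: StructType, parquetSchema: StructType): StructType = {
  // Hive is case insensitive, so match Parquet fields by lower-cased name.
  val parquetByName = parquetSchema.fields.map(f => f.name.toLowerCase -> f).toMap
  val reconciled = hiveSchema.fields.map { hiveField =>
    parquetByName.get(hiveField.name.toLowerCase) match {
      // Field present on both sides: keep the Hive name, take the Parquet data type.
      case Some(parquetField) => hiveField.copy(dataType = parquetField.dataType)
      // Field only in the Hive metastore schema: added as a nullable field.
      case None => hiveField.copy(nullable = true)
    }
  }
  // Fields that only appear in the Parquet schema are dropped, since only the
  // Hive metastore fields are emitted above.
  StructType(reconciled)
}
```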
+ ### Configuration Configuration of Parquet can be done using the `setConf` method on `SQLContext` or by running @@ -1266,6 +1348,34 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` support.
+<tr>
+  <td><code>spark.sql.parquet.output.committer.class</code></td>
+  <td><code>org.apache.parquet.hadoop.ParquetOutputCommitter</code></td>
+  <td>
+    <p>
+      The output committer class used by Parquet. The specified class needs to be a subclass of
+      <code>org.apache.hadoop.mapreduce.OutputCommitter</code>. Typically, it's also a
+      subclass of <code>org.apache.parquet.hadoop.ParquetOutputCommitter</code>.
+    </p>
+    <p>
+      <b>Note:</b>
+    </p>
+    <ul>
+      <li>
+        This option must be set via Hadoop <code>Configuration</code> rather than Spark
+        <code>SQLConf</code>.
+      </li>
+      <li>
+        This option overrides <code>spark.sql.sources.outputCommitterClass</code>.
+      </li>
+    </ul>
+    <p>
+      Spark SQL comes with a builtin
+      <code>org.apache.spark.sql.parquet.DirectParquetOutputCommitter</code>, which can be more
+      efficient than the default Parquet output committer when writing data to S3.
+    </p>
+  </td>
+</tr>
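Since the note above says this option must go through the Hadoop `Configuration` rather than Spark `SQLConf`, a minimal way to select the S3-friendly committer might look like the following sketch (assuming an existing `SparkContext` named `sc`):

```scala
// Set on the Hadoop Configuration, not via SQLContext.setConf / SQLConf.
sc.hadoopConfiguration.set(
  "spark.sql.parquet.output.committer.class",
  "org.apache.spark.sql.parquet.DirectParquetOutputCommitter")
```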
+ + ## JSON Datasets @@ -1445,8 +1555,8 @@ This command builds a new assembly jar that includes Hive. Note that this Hive a on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries (SerDes) in order to access data stored in Hive. -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. Please note when running -the query on a YARN cluster (`yarn-cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory +Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. Please note when running +the query on a YARN cluster (`yarn-cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory and `hive-site.xml` under `conf/` directory need to be available on the driver and all executors launched by the YARN cluster. The convenient way to do this is adding them through the `--jars` option and `--file` option of the `spark-submit` command. @@ -1794,7 +1904,7 @@ that these options will be deprecated in future release as more optimizations ar Configures the number of partitions to use when shuffling data for joins or aggregations. - + spark.sql.planner.externalSort false @@ -1889,7 +1999,7 @@ options. #### DataFrame data reader/writer interface Based on user feedback, we created a new, more fluid API for reading data in (`SQLContext.read`) -and writing data out (`DataFrame.write`), +and writing data out (`DataFrame.write`), and deprecated the old APIs (e.g. `SQLContext.parquetFile`, `SQLContext.jsonFile`). See the API docs for `SQLContext.read` ( diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 01306545fc7cd..1b1d7299fb496 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -26,7 +26,7 @@ import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol} +import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ @@ -41,7 +41,8 @@ import org.apache.spark.util.StatCounter * Params for linear regression. */ private[regression] trait LinearRegressionParams extends PredictorParams - with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol + with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol + with HasFitIntercept /** * :: Experimental :: @@ -72,6 +73,14 @@ class LinearRegression(override val uid: String) def setRegParam(value: Double): this.type = set(regParam, value) setDefault(regParam -> 0.0) + /** + * Set if we should fit the intercept + * Default is true. + * @group setParam + */ + def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) + setDefault(fitIntercept -> true) + /** * Set the ElasticNet mixing parameter. * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. 
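A brief usage sketch of the new `fitIntercept` parameter introduced above; `training` is assumed to be an existing DataFrame of (label, features) rows.

```scala
import org.apache.spark.ml.regression.LinearRegression

val lr = new LinearRegression()
  .setFitIntercept(false)   // force the fit through the origin
  .setElasticNetParam(0.0)  // pure L2 mixing
  .setRegParam(0.0)

val model = lr.fit(training)
// With fitIntercept = false the reported intercept is 0.0.
println(s"intercept = ${model.intercept}, weights = ${model.weights}")
```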
@@ -123,6 +132,7 @@ class LinearRegression(override val uid: String) val numFeatures = summarizer.mean.size val yMean = statCounter.mean val yStd = math.sqrt(statCounter.variance) + // look at glmnet5.m L761 maaaybe that has info // If the yStd is zero, then the intercept is yMean with zero weights; // as a result, training is not needed. @@ -142,7 +152,7 @@ class LinearRegression(override val uid: String) val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam - val costFun = new LeastSquaresCostFun(instances, yStd, yMean, + val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept), featuresStd, featuresMean, effectiveL2RegParam) val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) { @@ -180,7 +190,7 @@ class LinearRegression(override val uid: String) // The intercept in R's GLMNET is computed using closed form after the coefficients are // converged. See the following discussion for detail. // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet - val intercept = yMean - dot(weights, Vectors.dense(featuresMean)) + val intercept = if ($(fitIntercept)) yMean - dot(weights, Vectors.dense(featuresMean)) else 0.0 if (handlePersistence) instances.unpersist() // TODO: Converts to sparse format based on the storage, but may base on the scoring speed. @@ -234,6 +244,7 @@ class LinearRegressionModel private[ml] ( * See this discussion for detail. * http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet * + * When training with intercept enabled, * The objective function in the scaled space is given by * {{{ * L = 1/2n ||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2, @@ -241,6 +252,10 @@ class LinearRegressionModel private[ml] ( * where \bar{x_i} is the mean of x_i, \hat{x_i} is the standard deviation of x_i, * \bar{y} is the mean of label, and \hat{y} is the standard deviation of label. * + * If we fitting the intercept disabled (that is forced through 0.0), + * we can use the same equation except we set \bar{y} and \bar{x_i} to 0 instead + * of the respective means. + * * This can be rewritten as * {{{ * L = 1/2n ||\sum_i (w_i/\hat{x_i})x_i - \sum_i (w_i/\hat{x_i})\bar{x_i} - y / \hat{y} @@ -255,6 +270,7 @@ class LinearRegressionModel private[ml] ( * \sum_i w_i^\prime x_i - y / \hat{y} + offset * }}} * + * * Note that the effective weights and offset don't depend on training dataset, * so they can be precomputed. 
* @@ -301,6 +317,7 @@ private class LeastSquaresAggregator( weights: Vector, labelStd: Double, labelMean: Double, + fitIntercept: Boolean, featuresStd: Array[Double], featuresMean: Array[Double]) extends Serializable { @@ -321,7 +338,7 @@ private class LeastSquaresAggregator( } i += 1 } - (weightsArray, -sum + labelMean / labelStd, weightsArray.length) + (weightsArray, if (fitIntercept) labelMean / labelStd - sum else 0.0, weightsArray.length) } private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray) @@ -404,6 +421,7 @@ private class LeastSquaresCostFun( data: RDD[(Double, Vector)], labelStd: Double, labelMean: Double, + fitIntercept: Boolean, featuresStd: Array[Double], featuresMean: Array[Double], effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] { @@ -412,7 +430,7 @@ private class LeastSquaresCostFun( val w = Vectors.fromBreeze(weights) val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd, - labelMean, featuresStd, featuresMean))( + labelMean, fitIntercept, featuresStd, featuresMean))( seqOp = (c, v) => (c, v) match { case (aggregator, (label, features)) => aggregator.add(label, features) }, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 876a9f9f28242..c98392e310857 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -51,6 +51,7 @@ import org.apache.spark.mllib.tree.loss.Losses import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, RandomForestModel} import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest} import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.storage.StorageLevel @@ -981,7 +982,7 @@ private[python] class PythonMLLibAPI extends Serializable { def estimateKernelDensity( sample: JavaRDD[Double], bandwidth: Double, points: java.util.ArrayList[Double]): Array[Double] = { - return new KernelDensity().setSample(sample).setBandwidth(bandwidth).estimate( + new KernelDensity().setSample(sample).setBandwidth(bandwidth).estimate( points.asScala.toArray) } @@ -1000,6 +1001,35 @@ private[python] class PythonMLLibAPI extends Serializable { List[AnyRef](model.clusterCenters, Vectors.dense(model.clusterWeights)).asJava } + /** + * Wrapper around the generateLinearInput method of LinearDataGenerator. + */ + def generateLinearInputWrapper( + intercept: Double, + weights: JList[Double], + xMean: JList[Double], + xVariance: JList[Double], + nPoints: Int, + seed: Int, + eps: Double): Array[LabeledPoint] = { + LinearDataGenerator.generateLinearInput( + intercept, weights.asScala.toArray, xMean.asScala.toArray, + xVariance.asScala.toArray, nPoints, seed, eps).toArray + } + + /** + * Wrapper around the generateLinearRDD method of LinearDataGenerator. 
+ */ + def generateLinearRDDWrapper( + sc: JavaSparkContext, + nexamples: Int, + nfeatures: Int, + eps: Double, + nparts: Int, + intercept: Double): JavaRDD[LabeledPoint] = { + LinearDataGenerator.generateLinearRDD( + sc, nexamples, nfeatures, eps, nparts, intercept) + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala index 308f7f3578e21..a841c5caf0142 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala @@ -98,6 +98,8 @@ private[mllib] object NumericParser { } } else if (token == ")") { parsing = false + } else if (token.trim.isEmpty){ + // ignore whitespaces between delim chars, e.g. ", [" } else { // expecting a number items.append(parseDouble(token)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 732e2c42be144..ad1e9da692ee2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.{DataFrame, Row} class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var dataset: DataFrame = _ + @transient var datasetWithoutIntercept: DataFrame = _ /** * In `LinearRegressionSuite`, we will make sure that the model trained by SparkML @@ -34,14 +35,24 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { * * import org.apache.spark.mllib.util.LinearDataGenerator * val data = - * sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), 10000, 42), 2) - * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).saveAsTextFile("path") + * sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), + * Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2) + * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1) + * .saveAsTextFile("path") */ override def beforeAll(): Unit = { super.beforeAll() dataset = sqlContext.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) + /** + * datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating + * training model without intercept + */ + datasetWithoutIntercept = sqlContext.createDataFrame( + sc.parallelize(LinearDataGenerator.generateLinearInput( + 0.0, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) + } test("linear regression with intercept without regularization") { @@ -78,6 +89,42 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("linear regression without intercept without regularization") { + val trainer = (new LinearRegression).setFitIntercept(false) + val model = trainer.fit(dataset) + val modelWithoutIntercept = trainer.fit(datasetWithoutIntercept) + + /** + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0, + * intercept = FALSE)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) . + * as.numeric.data.V2. 6.995908 + * as.numeric.data.V3. 
5.275131 + */ + val weightsR = Array(6.995908, 5.275131) + + assert(model.intercept ~== 0 relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + /** + * Then again with the data with no intercept: + * > weightsWithoutIntercept + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) . + * as.numeric.data3.V2. 4.70011 + * as.numeric.data3.V3. 7.19943 + */ + val weightsWithoutInterceptR = Array(4.70011, 7.19943) + + assert(modelWithoutIntercept.intercept ~== 0 relTol 1E-3) + assert(modelWithoutIntercept.weights(0) ~== weightsWithoutInterceptR(0) relTol 1E-3) + assert(modelWithoutIntercept.weights(1) ~== weightsWithoutInterceptR(1) relTol 1E-3) + } + test("linear regression with intercept with L1 regularization") { val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) val model = trainer.fit(dataset) @@ -87,11 +134,11 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { * > weights * 3 x 1 sparse Matrix of class "dgCMatrix" * s0 - * (Intercept) 6.311546 - * as.numeric.data.V2. 2.123522 - * as.numeric.data.V3. 4.605651 + * (Intercept) 6.24300 + * as.numeric.data.V2. 4.024821 + * as.numeric.data.V3. 6.679841 */ - val interceptR = 6.243000 + val interceptR = 6.24300 val weightsR = Array(4.024821, 6.679841) assert(model.intercept ~== interceptR relTol 1E-3) @@ -106,6 +153,36 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("linear regression without intercept with L1 regularization") { + val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) + .setFitIntercept(false) + val model = trainer.fit(dataset) + + /** + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57, + * intercept=FALSE)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) . + * as.numeric.data.V2. 6.299752 + * as.numeric.data.V3. 4.772913 + */ + val interceptR = 0.0 + val weightsR = Array(6.299752, 4.772913) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + test("linear regression with intercept with L2 regularization") { val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) val model = trainer.fit(dataset) @@ -134,6 +211,36 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("linear regression without intercept with L2 regularization") { + val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) + .setFitIntercept(false) + val model = trainer.fit(dataset) + + /** + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, + * intercept = FALSE)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) . + * as.numeric.data.V2. 5.522875 + * as.numeric.data.V3. 
4.214502 + */ + val interceptR = 0.0 + val weightsR = Array(5.522875, 4.214502) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + test("linear regression with intercept with ElasticNet regularization") { val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) val model = trainer.fit(dataset) @@ -161,4 +268,34 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(prediction1 ~== prediction2 relTol 1E-5) } } + + test("linear regression without intercept with ElasticNet regularization") { + val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) + .setFitIntercept(false) + val model = trainer.fit(dataset) + + /** + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6, + * intercept=FALSE)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) . + * as.numeric.dataM.V2. 5.673348 + * as.numeric.dataM.V3. 4.322251 + */ + val interceptR = 0.0 + val weightsR = Array(5.673348, 4.322251) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala index d8364a06de4da..f8d0af8820e64 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala @@ -31,6 +31,11 @@ class LabeledPointSuite extends SparkFunSuite { } } + test("parse labeled points with whitespaces") { + val point = LabeledPoint.parse("(0.0, [1.0, 2.0])") + assert(point === LabeledPoint(0.0, Vectors.dense(1.0, 2.0))) + } + test("parse labeled points with v0.9 format") { val point = LabeledPoint.parse("1.0,1.0 0.0 -2.0") assert(point === LabeledPoint(1.0, Vectors.dense(1.0, 0.0, -2.0))) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala index 8dcb9ba9be108..fa4f74d71b7e7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala @@ -37,4 +37,11 @@ class NumericParserSuite extends SparkFunSuite { } } } + + test("parser with whitespaces") { + val s = "(0.0, [1.0, 2.0])" + val parsed = NumericParser.parse(s).asInstanceOf[Seq[_]] + assert(parsed(0).asInstanceOf[Double] === 0.0) + assert(parsed(1).asInstanceOf[Array[Double]] === Array(1.0, 2.0)) + } } diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 5812b72f0aa78..f16bf989f200b 100644 --- a/project/MimaBuild.scala +++ 
b/project/MimaBuild.scala @@ -91,8 +91,7 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - // TODO: Change this once Spark 1.4.0 is released - val previousSparkVersion = "1.4.0-rc4" + val previousSparkVersion = "1.4.0" val fullId = "spark-" + projectRef.project + "_2.10" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7a748fb5e38bd..f678c69a6dfa9 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -53,6 +53,11 @@ object MimaExcludes { // Removing a testing method from a private class ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"), + // While private MiMa is still not happy about the changes, + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresAggregator.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresCostFun.this"), // SQL execution is considered private. excludePackage("org.apache.spark.sql.execution"), // NanoTime and CatalystTimestampConverter is only used inside catalyst, diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index e01720296fed0..f5f1c9a1a247a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -166,9 +166,8 @@ object SparkBuild extends PomBuild { /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) - // TODO: remove launcher from this list after 1.4.0 allProjects.filterNot(x => Seq(spark, hive, hiveThriftServer, catalyst, repl, - networkCommon, networkShuffle, networkYarn, launcher, unsafe).contains(x)).foreach { + networkCommon, networkShuffle, networkYarn, unsafe).contains(x)).foreach { x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) } diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index bb375d08a3ad6..577ecc947174f 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -49,8 +49,8 @@ from pyspark.mllib.stat import Statistics from pyspark.mllib.feature import Word2Vec from pyspark.mllib.feature import IDF -from pyspark.mllib.feature import StandardScaler -from pyspark.mllib.feature import ElementwiseProduct +from pyspark.mllib.feature import StandardScaler, ElementwiseProduct +from pyspark.mllib.util import LinearDataGenerator from pyspark.mllib.util import MLUtils from pyspark.serializers import PickleSerializer from pyspark.streaming import StreamingContext @@ -1020,6 +1020,24 @@ def collect(rdd): self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]]) +class LinearDataGeneratorTests(MLlibTestCase): + def test_dim(self): + linear_data = LinearDataGenerator.generateLinearInput( + intercept=0.0, weights=[0.0, 0.0, 0.0], + xMean=[0.0, 0.0, 0.0], xVariance=[0.33, 0.33, 0.33], + nPoints=4, seed=0, eps=0.1) + self.assertEqual(len(linear_data), 4) + for point in linear_data: + self.assertEqual(len(point.features), 3) + + linear_data = LinearDataGenerator.generateLinearRDD( + sc=sc, nexamples=6, nfeatures=2, eps=0.1, + nParts=2, intercept=0.0).collect() + self.assertEqual(len(linear_data), 6) + for point in linear_data: + self.assertEqual(len(point.features), 2) + + class MLUtilsTests(MLlibTestCase): def test_append_bias(self): data = [2.0, 2.0, 2.0] diff --git 
a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index bb3eb7ed0bbc6..875d3b2d642c6 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -279,6 +279,41 @@ def load(cls, sc, path): return cls(java_model) +class LinearDataGenerator(object): + """Utils for generating linear data""" + + @staticmethod + def generateLinearInput(intercept, weights, xMean, xVariance, + nPoints, seed, eps): + """ + :param: intercept bias factor, the term c in X'w + c + :param: weights feature vector, the term w in X'w + c + :param: xMean Point around which the data X is centered. + :param: xVariance Variance of the given data + :param: nPoints Number of points to be generated + :param: seed Random Seed + :param: eps Used to scale the noise. If eps is set high, + the amount of gaussian noise added is more. + Returns a list of LabeledPoints of length nPoints + """ + weights = [float(weight) for weight in weights] + xMean = [float(mean) for mean in xMean] + xVariance = [float(var) for var in xVariance] + return list(callMLlibFunc( + "generateLinearInputWrapper", float(intercept), weights, xMean, + xVariance, int(nPoints), int(seed), float(eps))) + + @staticmethod + def generateLinearRDD(sc, nexamples, nfeatures, eps, + nParts=2, intercept=0.0): + """ + Generate a RDD of LabeledPoints. + """ + return callMLlibFunc( + "generateLinearRDDWrapper", sc, int(nexamples), int(nfeatures), + float(eps), int(nParts), float(intercept)) + + def _test(): import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 1ecec5b126505..0a85da7443d3d 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -396,6 +396,11 @@ def over(self, window): jc = self._jc.over(window._jspec) return Column(jc) + def __nonzero__(self): + raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', " + "'~' for 'not' when building DataFrame boolean expressions.") + __bool__ = __nonzero__ + def __repr__(self): return 'Column<%s>' % self._jc.toString().encode('utf8') diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 13f4556943ac8..e6a434e4b2dff 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -164,6 +164,14 @@ def test_explode(self): self.assertEqual(result[0][0], "a") self.assertEqual(result[0][1], "b") + def test_and_in_expression(self): + self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count()) + self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2")) + self.assertEqual(14, self.df.filter((self.df.key <= 3) | (self.df.value < "2")).count()) + self.assertRaises(ValueError, lambda: self.df.key <= 3 or self.df.value < "2") + self.assertEqual(99, self.df.filter(~(self.df.key == 1)).count()) + self.assertRaises(ValueError, lambda: not self.df.key == 1) + def test_udf_with_callable(self): d = [Row(number=i, squared=i**2) for i in range(10)] rdd = self.sc.parallelize(d) @@ -408,7 +416,7 @@ def test_column_operators(self): self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) self.assertTrue(all(isinstance(c, Column) for c in rcc)) - cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs] + cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7] self.assertTrue(all(isinstance(c, Column) for c in cb)) cbool = (ci & ci), (ci | ci), (~ci) self.assertTrue(all(isinstance(c, Column) for c in 
cbool)) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java index 611e02d8fb666..6a2356f1f9c6f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java @@ -155,27 +155,6 @@ public int fieldIndex(String name) { throw new UnsupportedOperationException(); } - /** - * A generic version of Row.equals(Row), which is used for tests. - */ - @Override - public boolean equals(Object other) { - if (other instanceof Row) { - Row row = (Row) other; - int n = size(); - if (n != row.size()) { - return false; - } - for (int i = 0; i < n; i ++) { - if (isNullAt(i) != row.isNullAt(i) || (!isNullAt(i) && !get(i).equals(row.get(i)))) { - return false; - } - } - return true; - } - return false; - } - @Override public InternalRow copy() { final int n = size(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 8aaf5d7d89154..e99d5c87a44fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import scala.util.hashing.MurmurHash3 - import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.types.StructType @@ -365,36 +363,6 @@ trait Row extends Serializable { false } - override def equals(that: Any): Boolean = that match { - case null => false - case that: Row => - if (this.length != that.length) { - return false - } - var i = 0 - val len = this.length - while (i < len) { - if (apply(i) != that.apply(i)) { - return false - } - i += 1 - } - true - case _ => false - } - - override def hashCode: Int = { - // Using Scala's Seq hash code implementation. 
- var n = 0 - var h = MurmurHash3.seqSeed - val len = length - while (n < len) { - h = MurmurHash3.mix(h, apply(n).##) - n += 1 - } - MurmurHash3.finalizeHash(h, n) - } - /* ---------------------- utility methods for Scala ---------------------- */ /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index e3c2cc243310b..d7b537a9fe3bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.sql.catalyst.expressions._ /** * An abstract class for row used internal in Spark SQL, which only contain the columns as @@ -26,7 +26,70 @@ import org.apache.spark.sql.catalyst.expressions.GenericRow */ abstract class InternalRow extends Row { // A default implementation to change the return type - override def copy(): InternalRow = {this} + override def copy(): InternalRow = this + + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[Row]) { + return false + } + + val other = o.asInstanceOf[Row] + if (length != other.length) { + return false + } + + var i = 0 + while (i < length) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = apply(i) + val o2 = other.apply(i) + if (o1.isInstanceOf[Array[Byte]]) { + // handle equality of Array[Byte] + val b1 = o1.asInstanceOf[Array[Byte]] + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + } else if (o1 != o2) { + return false + } + } + i += 1 + } + true + } + + // Custom hashCode function that matches the efficient code generated version. + override def hashCode: Int = { + var result: Int = 37 + var i = 0 + while (i < length) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + apply(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case a: Array[Byte] => java.util.Arrays.hashCode(a) + case other => other.hashCode() + } + } + result = 37 * result + update + i += 1 + } + result + } } object InternalRow { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6311784422a91..0a3f5a7b5cade 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -192,49 +192,17 @@ class Analyzer( Seq.tabulate(1 << c.groupByExprs.length)(i => i) } - /** - * Create an array of Projections for the child projection, and replace the projections' - * expressions which equal GroupBy expressions with Literal(null), if those expressions - * are not set for this grouping set (according to the bit mask). 
- */ - private[this] def expand(g: GroupingSets): Seq[Seq[Expression]] = { - val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]] - - g.bitmasks.foreach { bitmask => - // get the non selected grouping attributes according to the bit mask - val nonSelectedGroupExprs = ArrayBuffer.empty[Expression] - var bit = g.groupByExprs.length - 1 - while (bit >= 0) { - if (((bitmask >> bit) & 1) == 0) nonSelectedGroupExprs += g.groupByExprs(bit) - bit -= 1 - } - - val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown { - case x: Expression if nonSelectedGroupExprs.find(_ semanticEquals x).isDefined => - // if the input attribute in the Invalid Grouping Expression set of for this group - // replace it with constant null - Literal.create(null, expr.dataType) - case x if x == g.gid => - // replace the groupingId with concrete value (the bit mask) - Literal.create(bitmask, IntegerType) - }) - - result += substitution - } - - result.toSeq - } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case a: Cube if a.resolved => - GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid) - case a: Rollup if a.resolved => - GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid) - case x: GroupingSets if x.resolved => + case a: Cube => + GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) + case a: Rollup => + GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) + case x: GroupingSets => + val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() Aggregate( - x.groupByExprs :+ x.gid, + x.groupByExprs :+ VirtualColumn.groupingIdAttribute, x.aggregations, - Expand(expand(x), x.child.output :+ x.gid, x.child)) + Expand(x.bitmasks, x.groupByExprs, gid, x.child)) } } @@ -368,12 +336,7 @@ class Analyzer( case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") - q transformExpressionsUp { - case u @ UnresolvedAttribute(nameParts) if nameParts.length == 1 && - resolver(nameParts(0), VirtualColumn.groupingIdName) && - q.isInstanceOf[GroupingAnalytics] => - // Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics - q.asInstanceOf[GroupingAnalytics].gid + q transformExpressionsUp { case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index bd5475d2066fc..47c5455435ec6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -175,8 +175,10 @@ class CodeGenContext { * Generate code for compare expression in Java */ def genComp(dataType: DataType, c1: String, c2: String): String = dataType match { + // java boolean doesn't support > or < operator + case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))" // use c1 - c2 may overflow - case dt: DataType if isPrimitiveType(dt) => s"(int)($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)" + case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? 
-1 : 0)" case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" case other => s"$c1.compare($c2)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 2e20eda1a3002..e362625469e29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -127,6 +127,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { case FloatType => s"Float.floatToIntBits($col)" case DoubleType => s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))" + case BinaryType => s"java.util.Arrays.hashCode($col)" case _ => s"$col.hashCode()" } s"isNullAt($i) ? 0 : ($nonNull)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 58dbeaf89cad5..9cacdceb13837 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -262,5 +262,5 @@ case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[E object VirtualColumn { val groupingIdName: String = "grouping__id" - def newGroupingId: AttributeReference = AttributeReference(groupingIdName, IntegerType, false)() + val groupingIdAttribute: UnresolvedAttribute = UnresolvedAttribute(groupingIdName) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 1098962ddc018..0d4c9ace5e124 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -121,58 +121,6 @@ class GenericRow(protected[sql] val values: Array[Any]) extends InternalRow { } } - // TODO(davies): add getDate and getDecimal - - // Custom hashCode function that matches the efficient code generated version. 
- override def hashCode: Int = { - var result: Int = 37 - - var i = 0 - while (i < values.length) { - val update: Int = - if (isNullAt(i)) { - 0 - } else { - apply(i) match { - case b: Boolean => if (b) 0 else 1 - case b: Byte => b.toInt - case s: Short => s.toInt - case i: Int => i - case l: Long => (l ^ (l >>> 32)).toInt - case f: Float => java.lang.Float.floatToIntBits(f) - case d: Double => - val b = java.lang.Double.doubleToLongBits(d) - (b ^ (b >>> 32)).toInt - case other => other.hashCode() - } - } - result = 37 * result + update - i += 1 - } - result - } - - override def equals(o: Any): Boolean = o match { - case other: InternalRow => - if (values.length != other.length) { - return false - } - - var i = 0 - while (i < values.length) { - if (isNullAt(i) != other.isNullAt(i)) { - return false - } - if (apply(i) != other.apply(i)) { - return false - } - i += 1 - } - true - - case _ => false - } - override def copy(): InternalRow = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9132a786f77a7..98b4476076854 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -121,6 +121,10 @@ object UnionPushdown extends Rule[LogicalPlan] { */ object ColumnPruning extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case a @ Aggregate(_, _, e @ Expand(_, groupByExprs, _, child)) + if (child.outputSet -- AttributeSet(groupByExprs) -- a.references).nonEmpty => + a.copy(child = e.copy(child = prunedChild(child, AttributeSet(groupByExprs) ++ a.references))) + // Eliminate attributes that are not needed to calculate the specified aggregates. case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => a.copy(child = Project(a.references.toSeq, child)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 7814e51628db6..fae339808c233 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ +import org.apache.spark.util.collection.OpenHashSet case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) @@ -228,24 +229,76 @@ case class Window( /** * Apply the all of the GroupExpressions to every input row, hence we will get * multiple output rows for a input row. 
- * @param projections The group of expressions, all of the group expressions should - * output the same schema specified by the parameter `output` - * @param output The output Schema + * @param bitmasks The bitmask set represents the grouping sets + * @param groupByExprs The grouping by expressions * @param child Child operator */ case class Expand( - projections: Seq[Seq[Expression]], - output: Seq[Attribute], + bitmasks: Seq[Int], + groupByExprs: Seq[Expression], + gid: Attribute, child: LogicalPlan) extends UnaryNode { override def statistics: Statistics = { val sizeInBytes = child.statistics.sizeInBytes * projections.length Statistics(sizeInBytes = sizeInBytes) } + + val projections: Seq[Seq[Expression]] = expand() + + /** + * Extract attribute set according to the grouping id + * @param bitmask bitmask to represent the selected of the attribute sequence + * @param exprs the attributes in sequence + * @return the attributes of non selected specified via bitmask (with the bit set to 1) + */ + private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression]) + : OpenHashSet[Expression] = { + val set = new OpenHashSet[Expression](2) + + var bit = exprs.length - 1 + while (bit >= 0) { + if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit)) + bit -= 1 + } + + set + } + + /** + * Create an array of Projections for the child projection, and replace the projections' + * expressions which equal GroupBy expressions with Literal(null), if those expressions + * are not set for this grouping set (according to the bit mask). + */ + private[this] def expand(): Seq[Seq[Expression]] = { + val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]] + + bitmasks.foreach { bitmask => + // get the non selected grouping attributes according to the bit mask + val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs) + + val substitution = (child.output :+ gid).map(expr => expr transformDown { + case x: Expression if nonSelectedGroupExprSet.contains(x) => + // if the input attribute in the Invalid Grouping Expression set of for this group + // replace it with constant null + Literal.create(null, expr.dataType) + case x if x == gid => + // replace the groupingId with concrete value (the bit mask) + Literal.create(bitmask, IntegerType) + }) + + result += substitution + } + + result.toSeq + } + + override def output: Seq[Attribute] = { + child.output :+ gid + } } trait GroupingAnalytics extends UnaryNode { self: Product => - def gid: AttributeReference def groupByExprs: Seq[Expression] def aggregations: Seq[NamedExpression] @@ -266,17 +319,12 @@ trait GroupingAnalytics extends UnaryNode { * @param child Child operator * @param aggregations The Aggregation expressions, those non selected group by expressions * will be considered as constant null if it appears in the expressions - * @param gid The attribute represents the virtual column GROUPING__ID, and it's also - * the bitmask indicates the selected GroupBy Expressions for each - * aggregating output row. 
- * The associated output will be one of the value in `bitmasks` */ case class GroupingSets( bitmasks: Seq[Int], groupByExprs: Seq[Expression], child: LogicalPlan, - aggregations: Seq[NamedExpression], - gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics { + aggregations: Seq[NamedExpression]) extends GroupingAnalytics { def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = this.copy(aggregations = aggs) @@ -290,15 +338,11 @@ case class GroupingSets( * @param child Child operator * @param aggregations The Aggregation expressions, those non selected group by expressions * will be considered as constant null if it appears in the expressions - * @param gid The attribute represents the virtual column GROUPING__ID, and it's also - * the bitmask indicates the selected GroupBy Expressions for each - * aggregating output row. */ case class Cube( groupByExprs: Seq[Expression], child: LogicalPlan, - aggregations: Seq[NamedExpression], - gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics { + aggregations: Seq[NamedExpression]) extends GroupingAnalytics { def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = this.copy(aggregations = aggs) @@ -313,15 +357,11 @@ case class Cube( * @param child Child operator * @param aggregations The Aggregation expressions, those non selected group by expressions * will be considered as constant null if it appears in the expressions - * @param gid The attribute represents the virtual column GROUPING__ID, and it's also - * the bitmask indicates the selected GroupBy Expressions for each - * aggregating output row. */ case class Rollup( groupByExprs: Seq[Expression], child: LogicalPlan, - aggregations: Seq[NamedExpression], - gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics { + aggregations: Seq[NamedExpression]) extends GroupingAnalytics { def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = this.copy(aggregations = aggs) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 4bbbbe6c7f091..6c93698f8017b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType} +import org.apache.spark.sql.types.Decimal class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -123,23 +123,39 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } } - test("MaxOf") { - checkEvaluation(MaxOf(1, 2), 2) - checkEvaluation(MaxOf(2, 1), 2) - checkEvaluation(MaxOf(1L, 2L), 2L) - checkEvaluation(MaxOf(2L, 1L), 2L) + test("MaxOf basic") { + testNumericDataTypes { convert => + val small = Literal(convert(1)) + val large = Literal(convert(2)) + checkEvaluation(MaxOf(small, large), convert(2)) + checkEvaluation(MaxOf(large, small), convert(2)) + checkEvaluation(MaxOf(Literal.create(null, small.dataType), large), convert(2)) + checkEvaluation(MaxOf(large, Literal.create(null, small.dataType)), convert(2)) + } + } - checkEvaluation(MaxOf(Literal.create(null, 
IntegerType), 2), 2) - checkEvaluation(MaxOf(2, Literal.create(null, IntegerType)), 2) + test("MaxOf for atomic type") { + checkEvaluation(MaxOf(true, false), true) + checkEvaluation(MaxOf("abc", "bcd"), "bcd") + checkEvaluation(MaxOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)), + Array(1.toByte, 3.toByte)) } - test("MinOf") { - checkEvaluation(MinOf(1, 2), 1) - checkEvaluation(MinOf(2, 1), 1) - checkEvaluation(MinOf(1L, 2L), 1L) - checkEvaluation(MinOf(2L, 1L), 1L) + test("MinOf basic") { + testNumericDataTypes { convert => + val small = Literal(convert(1)) + val large = Literal(convert(2)) + checkEvaluation(MinOf(small, large), convert(1)) + checkEvaluation(MinOf(large, small), convert(1)) + checkEvaluation(MinOf(Literal.create(null, small.dataType), large), convert(2)) + checkEvaluation(MinOf(small, Literal.create(null, small.dataType)), convert(1)) + } + } - checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1) - checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1) + test("MinOf for atomic type") { + checkEvaluation(MinOf(true, false), false) + checkEvaluation(MinOf("abc", "bcd"), "abc") + checkEvaluation(MinOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)), + Array(1.toByte, 2.toByte)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 12d2da8b33986..158f54af13802 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -38,10 +38,23 @@ trait ExpressionEvalHelper { protected def checkEvaluation( expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - checkEvaluationWithoutCodegen(expression, expected, inputRow) - checkEvaluationWithGeneratedMutableProjection(expression, expected, inputRow) - checkEvaluationWithGeneratedProjection(expression, expected, inputRow) - checkEvaluationWithOptimization(expression, expected, inputRow) + val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) + checkEvaluationWithoutCodegen(expression, catalystValue, inputRow) + checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow) + checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow) + checkEvaluationWithOptimization(expression, catalystValue, inputRow) + } + + /** + * Check the equality between result of expression and expected value, it will handle + * Array[Byte]. 
+ */ + protected def checkResult(result: Any, expected: Any): Boolean = { + (result, expected) match { + case (result: Array[Byte], expected: Array[Byte]) => + java.util.Arrays.equals(result, expected) + case _ => result == expected + } } protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = { @@ -55,7 +68,7 @@ trait ExpressionEvalHelper { val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } - if (actual != expected) { + if (!checkResult(actual, expected)) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect evaluation (codegen off): $expression, " + s"actual: $actual, " + @@ -83,7 +96,7 @@ trait ExpressionEvalHelper { } val actual = plan(inputRow).apply(0) - if (actual != expected) { + if (!checkResult(actual, expected)) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") } @@ -109,7 +122,7 @@ trait ExpressionEvalHelper { } val actual = plan(inputRow) - val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected))) + val expectedRow = new GenericRow(Array[Any](expected)) if (actual.hashCode() != expectedRow.hashCode()) { fail( s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index f44f55dfb92d1..d924ff7a102f6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -18,12 +18,26 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types._ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { - // TODO: Add tests for all data types. 
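Regarding the checkResult helper added to ExpressionEvalHelper above: the special case exists because Scala's == between JVM arrays is reference equality, so two equal byte sequences would otherwise compare as different. A standalone sketch (not part of this patch) of the behavior it guards against:

{{{
// Why checkResult delegates to java.util.Arrays.equals for Array[Byte]:
val a = Array[Byte](1, 2, 3)
val b = Array[Byte](1, 2, 3)
a == b                          // false: reference equality on arrays
java.util.Arrays.equals(a, b)   // true: element-wise comparison
}}}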
+ test("null") { + checkEvaluation(Literal.create(null, BooleanType), null) + checkEvaluation(Literal.create(null, ByteType), null) + checkEvaluation(Literal.create(null, ShortType), null) + checkEvaluation(Literal.create(null, IntegerType), null) + checkEvaluation(Literal.create(null, LongType), null) + checkEvaluation(Literal.create(null, FloatType), null) + checkEvaluation(Literal.create(null, DoubleType), null) + checkEvaluation(Literal.create(null, StringType), null) + checkEvaluation(Literal.create(null, BinaryType), null) + checkEvaluation(Literal.create(null, DecimalType()), null) + checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null) + checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null) + checkEvaluation(Literal.create(null, StructType(Seq.empty)), null) + } test("boolean literals") { checkEvaluation(Literal(true), true) @@ -31,25 +45,52 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { } test("int literals") { - checkEvaluation(Literal(1), 1) - checkEvaluation(Literal(0L), 0L) + List(0, 1, Int.MinValue, Int.MaxValue).foreach { d => + checkEvaluation(Literal(d), d) + checkEvaluation(Literal(d.toLong), d.toLong) + checkEvaluation(Literal(d.toShort), d.toShort) + checkEvaluation(Literal(d.toByte), d.toByte) + } + checkEvaluation(Literal(Long.MinValue), Long.MinValue) + checkEvaluation(Literal(Long.MaxValue), Long.MaxValue) } test("double literals") { - List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { - d => { - checkEvaluation(Literal(d), d) - checkEvaluation(Literal(d.toFloat), d.toFloat) - } + List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { d => + checkEvaluation(Literal(d), d) + checkEvaluation(Literal(d.toFloat), d.toFloat) } + checkEvaluation(Literal(Double.MinValue), Double.MinValue) + checkEvaluation(Literal(Double.MaxValue), Double.MaxValue) + checkEvaluation(Literal(Float.MinValue), Float.MinValue) + checkEvaluation(Literal(Float.MaxValue), Float.MaxValue) + } test("string literals") { + checkEvaluation(Literal(""), "") checkEvaluation(Literal("test"), "test") - checkEvaluation(Literal.create(null, StringType), null) + checkEvaluation(Literal("\0"), "\0") } test("sum two literals") { checkEvaluation(Add(Literal(1), Literal(1)), 2) } + + test("binary literals") { + checkEvaluation(Literal.create(new Array[Byte](0), BinaryType), new Array[Byte](0)) + checkEvaluation(Literal.create(new Array[Byte](2), BinaryType), new Array[Byte](2)) + } + + test("decimal") { + List(0.0, 1.2, 1.1111, 5).foreach { d => + checkEvaluation(Literal(Decimal(d)), Decimal(d)) + checkEvaluation(Literal(Decimal(d.toInt)), Decimal(d.toInt)) + checkEvaluation(Literal(Decimal(d.toLong)), Decimal(d.toLong)) + checkEvaluation(Literal(Decimal((d * 1000L).toLong, 10, 1)), + Decimal((d * 1000L).toLong, 10, 1)) + } + } + + // TODO(davies): add tests for ArrayType, MapType and StructType } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala index d363e631540d8..5dbb1d562c1d9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -222,9 +222,6 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringLength(regEx), 5, 
create_row("abdef")) checkEvaluation(StringLength(regEx), 0, create_row("")) checkEvaluation(StringLength(regEx), null, create_row(null)) - // TODO currently bug in codegen, let's temporally disable this - // checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef")) + checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef")) } - - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 492a3321bc0bc..f3f0f5305318e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1049,7 +1049,7 @@ class DataFrame private[sql]( * columns of the input row are implicitly joined with each value that is output by the function. * * {{{ - * df.explode("words", "word")(words: String => words.split(" ")) + * df.explode("words", "word"){words: String => words.split(" ")} * }}} * @group dfops * @since 1.3.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 16493c3d7c19c..265352647fa9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -22,6 +22,8 @@ import java.util.Properties import scala.collection.immutable import scala.collection.JavaConversions._ +import org.apache.parquet.hadoop.ParquetOutputCommitter + import org.apache.spark.sql.catalyst.CatalystConf private[spark] object SQLConf { @@ -252,9 +254,9 @@ private[spark] object SQLConf { val PARQUET_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.parquet.filterPushdown", defaultValue = Some(false), - doc = "Turn on Parquet filter pushdown optimization. This feature is turned off by default" + - " because of a known bug in Paruet 1.6.0rc3 " + - "(PARQUET-136). However, " + + doc = "Turn on Parquet filter pushdown optimization. This feature is turned off by default " + + "because of a known bug in Parquet 1.6.0rc3 " + + "(PARQUET-136, https://issues.apache.org/jira/browse/PARQUET-136). However, " + "if your table doesn't contain any nullable string or binary columns, it's still safe to " + "turn this feature on.") @@ -262,11 +264,21 @@ private[spark] object SQLConf { defaultValue = Some(true), doc = "") + val PARQUET_OUTPUT_COMMITTER_CLASS = stringConf( + key = "spark.sql.parquet.output.committer.class", + defaultValue = Some(classOf[ParquetOutputCommitter].getName), + doc = "The output committer class used by Parquet. The specified class needs to be a " + + "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + + "of org.apache.parquet.hadoop.ParquetOutputCommitter. NOTE: 1. Instead of SQLConf, this " + + "option must be set in Hadoop Configuration. 2. This option overrides " + + "\"spark.sql.sources.outputCommitterClass\"." + ) + val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown", defaultValue = Some(false), doc = "") - val HIVE_VERIFY_PARTITIONPATH = booleanConf("spark.sql.hive.verifyPartitionPath", + val HIVE_VERIFY_PARTITION_PATH = booleanConf("spark.sql.hive.verifyPartitionPath", defaultValue = Some(true), doc = "") @@ -325,9 +337,13 @@ private[spark] object SQLConf { defaultValue = Some(true), doc = "") - // The output committer class used by FSBasedRelation. The specified class needs to be a + // The output committer class used by HadoopFsRelation. 
The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. - // NOTE: This property should be set in Hadoop `Configuration` rather than Spark `SQLConf` + // + // NOTE: + // + // 1. Instead of SQLConf, this option *must be set in Hadoop Configuration*. + // 2. This option can be overridden by "spark.sql.parquet.output.committer.class". val OUTPUT_COMMITTER_CLASS = stringConf("spark.sql.sources.outputCommitterClass", isPublic = false) @@ -415,7 +431,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) /** When true uses verifyPartitionPath to prune the path which is not exists. */ - private[spark] def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITIONPATH) + private[spark] def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) /** When true the planner will use the external sort, which may spill to disk. */ private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5c420eb9d761f..1ff1cc224de8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -308,8 +308,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Project(projectList, planLater(child)) :: Nil case logical.Filter(condition, child) => execution.Filter(condition, planLater(child)) :: Nil - case logical.Expand(projections, output, child) => - execution.Expand(projections, output, planLater(child)) :: Nil + case e @ logical.Expand(_, _, _, child) => + execution.Expand(e.projections, e.output, planLater(child)) :: Nil case logical.Aggregate(group, agg, child) => execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil case logical.Window(projectList, windowExpressions, spec, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala index 62c4e92ebec68..1551afd7b7bf2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala @@ -17,19 +17,35 @@ package org.apache.spark.sql.parquet +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter - +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.parquet.Log import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat} +/** + * An output committer for writing Parquet files. Instead of writing to the `_temporary` folder + * as [[ParquetOutputCommitter]] does, this output committer writes data directly to the + * destination folder. This can be useful for data stored in S3, where directory operations are + * relatively expensive. + * + * To enable this output committer, users may set the "spark.sql.parquet.output.committer.class" + * property via Hadoop [[Configuration]]. 
Note that this property overrides + * "spark.sql.sources.outputCommitterClass". + * + * *NOTE* + * + * NEVER use [[DirectParquetOutputCommitter]] when appending data, because currently there's + * no safe way to undo a failed appending job (that's why both `abortTask()` and `abortJob()` are + * left empty). + */ private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) extends ParquetOutputCommitter(outputPath, context) { val LOG = Log.getLog(classOf[ParquetOutputCommitter]) - override def getWorkPath(): Path = outputPath + override def getWorkPath: Path = outputPath override def abortTask(taskContext: TaskAttemptContext): Unit = {} override def commitTask(taskContext: TaskAttemptContext): Unit = {} override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = true @@ -46,13 +62,11 @@ private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: T val footers = ParquetFileReader.readAllFootersInParallel(configuration, outputStatus) try { ParquetFileWriter.writeMetadataFile(configuration, outputPath, footers) - } catch { - case e: Exception => { - LOG.warn("could not write summary file for " + outputPath, e) - val metadataPath = new Path(outputPath, ParquetFileWriter.PARQUET_METADATA_FILE) - if (fileSystem.exists(metadataPath)) { - fileSystem.delete(metadataPath, true) - } + } catch { case e: Exception => + LOG.warn("could not write summary file for " + outputPath, e) + val metadataPath = new Path(outputPath, ParquetFileWriter.PARQUET_METADATA_FILE) + if (fileSystem.exists(metadataPath)) { + fileSystem.delete(metadataPath, true) + } } } catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index e049d54bf55dc..1d353bd8e1114 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -178,11 +178,11 @@ private[sql] class ParquetRelation2( val committerClass = conf.getClass( - "spark.sql.parquet.output.committer.class", + SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, classOf[ParquetOutputCommitter], classOf[ParquetOutputCommitter]) - if (conf.get("spark.sql.parquet.output.committer.class") == null) { + if (conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) == null) { logInfo("Using default output committer for Parquet: " + classOf[ParquetOutputCommitter].getCanonicalName) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala index c6f535dde7676..8b2a45d8e970a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala @@ -84,7 +84,7 @@ private[sql] object PartitioningUtils { } else { // This dataset is partitioned. We need to check whether all partitions have the same // partition columns and resolve potential type conflicts. - val resolvedPartitionValues = resolvePartitions(pathsWithPartitionValues.map(_._2)) + val resolvedPartitionValues = resolvePartitions(pathsWithPartitionValues) // Creates the StructType which represents the partition columns. 
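Stepping back to the DirectParquetOutputCommitter and spark.sql.parquet.output.committer.class changes above, a minimal usage sketch might look as follows (illustrative only; it assumes an existing SparkContext sc and DataFrame df, and the output path is made up). The key has to be set on the Hadoop Configuration rather than through SQLConf, it overrides spark.sql.sources.outputCommitterClass, and per the scaladoc it must never be used for append jobs.

{{{
// Illustrative sketch, not part of the patch: route Parquet writes through the
// direct committer by setting the new key on the Hadoop Configuration.
sc.hadoopConfiguration.set(
  "spark.sql.parquet.output.committer.class",
  "org.apache.spark.sql.parquet.DirectParquetOutputCommitter")

// Overwrite-style writes only; never use this committer when appending.
df.write.parquet("s3n://my-bucket/output")
}}}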
val fields = { @@ -181,19 +181,18 @@ private[sql] object PartitioningUtils { * StringType * }}} */ - private[sql] def resolvePartitions(values: Seq[PartitionValues]): Seq[PartitionValues] = { - // Column names of all partitions must match - val distinctPartitionsColNames = values.map(_.columnNames).distinct - - if (distinctPartitionsColNames.isEmpty) { + private[sql] def resolvePartitions( + pathsWithPartitionValues: Seq[(Path, PartitionValues)]): Seq[PartitionValues] = { + if (pathsWithPartitionValues.isEmpty) { Seq.empty } else { - assert(distinctPartitionsColNames.size == 1, { - val list = distinctPartitionsColNames.mkString("\t", "\n\t", "") - s"Conflicting partition column names detected:\n$list" - }) + val distinctPartColNames = pathsWithPartitionValues.map(_._2.columnNames).distinct + assert( + distinctPartColNames.size == 1, + listConflictingPartitionColumns(pathsWithPartitionValues)) // Resolves possible type conflicts for each column + val values = pathsWithPartitionValues.map(_._2) val columnCount = values.head.columnNames.size val resolvedValues = (0 until columnCount).map { i => resolveTypeConflicts(values.map(_.literals(i))) @@ -206,6 +205,34 @@ private[sql] object PartitioningUtils { } } + private[sql] def listConflictingPartitionColumns( + pathWithPartitionValues: Seq[(Path, PartitionValues)]): String = { + val distinctPartColNames = pathWithPartitionValues.map(_._2.columnNames).distinct + + def groupByKey[K, V](seq: Seq[(K, V)]): Map[K, Iterable[V]] = + seq.groupBy { case (key, _) => key }.mapValues(_.map { case (_, value) => value }) + + val partColNamesToPaths = groupByKey(pathWithPartitionValues.map { + case (path, partValues) => partValues.columnNames -> path + }) + + val distinctPartColLists = distinctPartColNames.map(_.mkString(", ")).zipWithIndex.map { + case (names, index) => + s"Partition column name list #$index: $names" + } + + // Lists out those non-leaf partition directories that also contain files + val suspiciousPaths = distinctPartColNames.sortBy(_.length).flatMap(partColNamesToPaths) + + s"Conflicting partition column names detected:\n" + + distinctPartColLists.mkString("\n\t", "\n\t", "\n\n") + + "For partitioned table directories, data files should only live in leaf directories.\n" + + "And directories at the same level should have the same partition column name.\n" + + "Please check the following directories for unexpected files or " + + "inconsistent partition column names:\n" + + suspiciousPaths.map("\t" + _).mkString("\n", "\n", "") + } + /** * Converts a string to a [[Literal]] with automatic type inference. 
Currently only supports * [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.Unlimited]], and diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index 01df189d1f3be..d0ebb11b063f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -538,4 +538,49 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { checkAnswer(sqlContext.read.format("parquet").load(dir.getCanonicalPath), df) } } + + test("listConflictingPartitionColumns") { + def makeExpectedMessage(colNameLists: Seq[String], paths: Seq[String]): String = { + val conflictingColNameLists = colNameLists.zipWithIndex.map { case (list, index) => + s"\tPartition column name list #$index: $list" + }.mkString("\n", "\n", "\n") + + // scalastyle:off + s"""Conflicting partition column names detected: + |$conflictingColNameLists + |For partitioned table directories, data files should only live in leaf directories. + |And directories at the same level should have the same partition column name. + |Please check the following directories for unexpected files or inconsistent partition column names: + |${paths.map("\t" + _).mkString("\n", "\n", "")} + """.stripMargin.trim + // scalastyle:on + } + + assert( + listConflictingPartitionColumns( + Seq( + (new Path("file:/tmp/foo/a=1"), PartitionValues(Seq("a"), Seq(Literal(1)))), + (new Path("file:/tmp/foo/b=1"), PartitionValues(Seq("b"), Seq(Literal(1)))))).trim === + makeExpectedMessage(Seq("a", "b"), Seq("file:/tmp/foo/a=1", "file:/tmp/foo/b=1"))) + + assert( + listConflictingPartitionColumns( + Seq( + (new Path("file:/tmp/foo/a=1/_temporary"), PartitionValues(Seq("a"), Seq(Literal(1)))), + (new Path("file:/tmp/foo/a=1"), PartitionValues(Seq("a"), Seq(Literal(1)))))).trim === + makeExpectedMessage( + Seq("a"), + Seq("file:/tmp/foo/a=1/_temporary", "file:/tmp/foo/a=1"))) + + assert( + listConflictingPartitionColumns( + Seq( + (new Path("file:/tmp/foo/a=1"), + PartitionValues(Seq("a"), Seq(Literal(1)))), + (new Path("file:/tmp/foo/a=1/b=foo"), + PartitionValues(Seq("a", "b"), Seq(Literal(1), Literal("foo")))))).trim === + makeExpectedMessage( + Seq("a", "a, b"), + Seq("file:/tmp/foo/a=1", "file:/tmp/foo/a=1/b=foo"))) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 42c2d4c98ffb2..2f771d76793e5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.client import java.io.{BufferedReader, InputStreamReader, File, PrintStream} import java.net.URI import java.util.{ArrayList => JArrayList, Map => JMap, List => JList, Set => JSet} +import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConversions._ import scala.language.reflectiveCalls @@ -136,12 +137,62 @@ private[hive] class ClientWrapper( // TODO: should be a def?s // When we create this val client, the HiveConf of it (conf) is the one associated with state. - private val client = Hive.get(conf) + @GuardedBy("this") + private var client = Hive.get(conf) + + // We use hive's conf for compatibility. 
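The retryLocked helper introduced just below boils down to a retry-with-deadline loop around the Hive call. A simplified standalone sketch of that pattern follows (the name withRetries and its parameters are illustrative, not Spark APIs); unlike this sketch, the real helper only retries when causedByThrift indicates a Thrift-level failure, and it also re-creates the Hive client before the next attempt.

{{{
// Illustrative retry-with-deadline sketch; not the actual ClientWrapper code.
def withRetries[A](maxRetries: Int, delayMillis: Long)(f: => A): A = {
  val deadline = System.nanoTime + maxRetries * delayMillis * 1000000L
  var attempts = 0
  var lastError: Exception = null
  do {
    attempts += 1
    try {
      return f
    } catch {
      case e: Exception =>
        lastError = e
        Thread.sleep(delayMillis)
    }
  } while (attempts <= maxRetries && System.nanoTime < deadline)
  throw lastError
}
}}}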
+ private val retryLimit = conf.getIntVar(HiveConf.ConfVars.METASTORETHRIFTFAILURERETRIES) + private val retryDelayMillis = shim.getMetastoreClientConnectRetryDelayMillis(conf) + + /** + * Runs `f` with multiple retries in case the hive metastore is temporarily unreachable. + */ + private def retryLocked[A](f: => A): A = synchronized { + // Hive sometimes retries internally, so set a deadline to avoid compounding delays. + val deadline = System.nanoTime + (retryLimit * retryDelayMillis * 1e6).toLong + var numTries = 0 + var caughtException: Exception = null + do { + numTries += 1 + try { + return f + } catch { + case e: Exception if causedByThrift(e) => + caughtException = e + logWarning( + "HiveClientWrapper got thrift exception, destroying client and retrying " + + s"(${retryLimit - numTries} tries remaining)", e) + Thread.sleep(retryDelayMillis) + try { + client = Hive.get(state.getConf, true) + } catch { + case e: Exception if causedByThrift(e) => + logWarning("Failed to refresh hive client, will retry.", e) + } + } + } while (numTries <= retryLimit && System.nanoTime < deadline) + if (System.nanoTime > deadline) { + logWarning("Deadline exceeded") + } + throw caughtException + } + + private def causedByThrift(e: Throwable): Boolean = { + var target = e + while (target != null) { + val msg = target.getMessage() + if (msg != null && msg.matches("(?s).*(TApplication|TProtocol|TTransport)Exception.*")) { + return true + } + target = target.getCause() + } + false + } /** * Runs `f` with ThreadLocal session state and classloaders configured for this version of hive. */ - private def withHiveState[A](f: => A): A = synchronized { + private def withHiveState[A](f: => A): A = retryLocked { val original = Thread.currentThread().getContextClassLoader // Set the thread local metastore client to the client associated with this ClientWrapper. 
Hive.set(client) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 5ae2dbb50d86b..e7c1779f80ce6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -21,6 +21,7 @@ import java.lang.{Boolean => JBoolean, Integer => JInteger} import java.lang.reflect.{Method, Modifier} import java.net.URI import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, Set => JSet} +import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ @@ -64,6 +65,8 @@ private[client] sealed abstract class Shim { def getDriverResults(driver: Driver): Seq[String] + def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long + def loadPartition( hive: Hive, loadPath: Path, @@ -192,6 +195,10 @@ private[client] class Shim_v0_12 extends Shim { res.toSeq } + override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { + conf.getIntVar(HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY) * 1000 + } + override def loadPartition( hive: Hive, loadPath: Path, @@ -321,6 +328,12 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { JBoolean.TYPE, JBoolean.TYPE, JBoolean.TYPE) + private lazy val getTimeVarMethod = + findMethod( + classOf[HiveConf], + "getTimeVar", + classOf[HiveConf.ConfVars], + classOf[TimeUnit]) override def loadPartition( hive: Hive, @@ -359,4 +372,10 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean, JBoolean.FALSE) } + override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { + getTimeVarMethod.invoke( + conf, + HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY, + TimeUnit.MILLISECONDS).asInstanceOf[Long] + } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 9871a70a40e69..9302b472925ed 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -17,10 +17,10 @@ package org.apache.spark.unsafe.types; +import javax.annotation.Nonnull; import java.io.Serializable; import java.io.UnsupportedEncodingException; import java.util.Arrays; -import javax.annotation.Nonnull; import org.apache.spark.unsafe.PlatformDependent; @@ -202,10 +202,6 @@ public int compare(final UTF8String other) { public boolean equals(final Object other) { if (other instanceof UTF8String) { return Arrays.equals(bytes, ((UTF8String) other).getBytes()); - } else if (other instanceof String) { - // Used only in unit tests. 
- String s = (String) other; - return bytes.length >= s.length() && length() == s.length() && toString().equals(s); } else { return false; } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 80c179a1b5e75..796cdc9dbebdb 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -28,8 +28,6 @@ private void checkBasic(String str, int len) throws UnsupportedEncodingException Assert.assertEquals(UTF8String.fromString(str).length(), len); Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")).length(), len); - Assert.assertEquals(UTF8String.fromString(str), str); - Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")), str); Assert.assertEquals(UTF8String.fromString(str).toString(), str); Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")).toString(), str); Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")), UTF8String.fromString(str));
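With the String branch removed from UTF8String.equals, a UTF8String never compares equal to a java.lang.String, which is why the two assertions above were dropped from checkBasic. A small sketch of the resulting behavior (standalone, not part of the patch):

{{{
import org.apache.spark.unsafe.types.UTF8String

val s = UTF8String.fromString("café")
s.equals("café")                         // false: no implicit String comparison anymore
s.equals(UTF8String.fromString("café"))  // true: compare UTF8String to UTF8String
s.toString == "café"                     // true: or convert back to a String first
}}}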