MapType |
- enviroment |
+ environment |
list(type="map", keyType=keyType, valueType=valueType, valueContainsNull=[valueContainsNull])
Note: The default value of valueContainsNull is True.
diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md
index c8ab146bcae0a..8d6e74370918f 100644
--- a/docs/streaming-flume-integration.md
+++ b/docs/streaming-flume-integration.md
@@ -99,6 +99,12 @@ Configuring Flume on the chosen machine requires the following two steps.
artifactId = scala-library
version = {{site.SCALA_VERSION}}
+ (iii) *Commons Lang 3 JAR*: Download the Commons Lang 3 JAR. It can be found with the following artifact detail (or, [direct link](http://search.maven.org/remotecontent?filepath=org/apache/commons/commons-lang3/3.3.2/commons-lang3-3.3.2.jar)).
+
+ groupId = org.apache.commons
+ artifactId = commons-lang3
+ version = 3.3.2
+
2. **Configuration file**: On that machine, configure Flume agent to send data to an Avro sink by having the following in the configuration file.
agent.sinks = spark
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 1eb3b30332e4f..b784d59666fec 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -1937,6 +1937,14 @@ JavaPairDStream unifiedStream = streamingContext.union(kafkaStre
unifiedStream.print();
{% endhighlight %}
+
+{% highlight python %}
+numStreams = 5
+kafkaStreams = [KafkaUtils.createStream(...) for _ in range(numStreams)]
+unifiedStream = streamingContext.union(kafkaStreams)
+unifiedStream.print()
+{% endhighlight %}
+
Another parameter that should be considered is the receiver's blocking interval,
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
index ec533d174ebdc..9df26ffca5775 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -156,6 +156,11 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset) {
// Create a model, and return it.
return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this);
}
+
+ @Override
+ public MyJavaLogisticRegression copy(ParamMap extra) {
+ return defaultCopy(extra);
+ }
}
/**
diff --git a/examples/src/main/python/streaming/queue_stream.py b/examples/src/main/python/streaming/queue_stream.py
new file mode 100644
index 0000000000000..dcd6a0fc6ff91
--- /dev/null
+++ b/examples/src/main/python/streaming/queue_stream.py
@@ -0,0 +1,50 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+ Create a queue of RDDs that will be mapped/reduced one at a time in
+ 1 second intervals.
+
+ To run this example use
+ `$ bin/spark-submit examples/src/main/python/streaming/queue_stream.py`
+"""
+import sys
+import time
+
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+
+if __name__ == "__main__":
+
+ sc = SparkContext(appName="PythonStreamingQueueStream")
+ ssc = StreamingContext(sc, 1)
+
+ # Create the queue through which RDDs can be pushed to
+ # a QueueInputDStream
+ rddQueue = []
+ for i in xrange(5):
+ rddQueue += [ssc.sparkContext.parallelize([j for j in xrange(1, 1001)], 10)]
+
+ # Create the QueueInputDStream and use it do some processing
+ inputStream = ssc.queueStream(rddQueue)
+ mappedStream = inputStream.map(lambda x: (x % 10, 1))
+ reducedStream = mappedStream.reduceByKey(lambda a, b: a + b)
+ reducedStream.pprint()
+
+ ssc.start()
+ time.sleep(6)
+ ssc.stop(stopSparkContext=True, stopGraceFully=True)
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
index 3ee456edbe01e..7b8cc21ed8982 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -130,6 +130,8 @@ private class MyLogisticRegression(override val uid: String)
// Create a model, and return it.
new MyLogisticRegressionModel(uid, weights).setParent(this)
}
+
+ override def copy(extra: ParamMap): MyLogisticRegression = defaultCopy(extra)
}
/**
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
index c0669fb336657..3913b711ba28b 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
@@ -32,6 +32,7 @@
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.streaming.Durations;
import org.apache.spark.streaming.api.java.JavaDStream;
@@ -65,8 +66,8 @@ public void tearDown() {
@Test
public void testKafkaStream() throws InterruptedException {
- String topic1 = "topic1";
- String topic2 = "topic2";
+ final String topic1 = "topic1";
+ final String topic2 = "topic2";
String[] topic1data = createTopicAndSendData(topic1);
String[] topic2data = createTopicAndSendData(topic2);
@@ -87,6 +88,16 @@ public void testKafkaStream() throws InterruptedException {
StringDecoder.class,
kafkaParams,
topicToSet(topic1)
+ ).transformToPair(
+ // Make sure you can get offset ranges from the rdd
+ new Function<JavaPairRDD<String, String>, JavaPairRDD<String, String>>() {
+ @Override
+ public JavaPairRDD<String, String> call(JavaPairRDD<String, String> rdd) throws Exception {
+ OffsetRange[] offsets = ((HasOffsetRanges)rdd.rdd()).offsetRanges();
+ Assert.assertEquals(offsets[0].topic(), topic1);
+ return rdd;
+ }
+ }
).map(
new Function<Tuple2<String, String>, String>() {
@Override
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index e9a5d7c0e7988..57e416591de69 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -78,7 +78,5 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
paramMaps.map(fit(dataset, _))
}
- override def copy(extra: ParamMap): Estimator[M] = {
- super.copy(extra).asInstanceOf[Estimator[M]]
- }
+ override def copy(extra: ParamMap): Estimator[M]
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
index 186bf7ae7a2f6..252acc156583f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
@@ -45,8 +45,5 @@ abstract class Model[M <: Model[M]] extends Transformer {
/** Indicates whether this [[Model]] has a corresponding parent. */
def hasParent: Boolean = parent != null
- override def copy(extra: ParamMap): M = {
- // The default implementation of Params.copy doesn't work for models.
- throw new NotImplementedError(s"${this.getClass} doesn't implement copy(extra: ParamMap)")
- }
+ override def copy(extra: ParamMap): M
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index a9bd28df71ee1..a1f3851d804ff 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -66,9 +66,7 @@ abstract class PipelineStage extends Params with Logging {
outputSchema
}
- override def copy(extra: ParamMap): PipelineStage = {
- super.copy(extra).asInstanceOf[PipelineStage]
- }
+ override def copy(extra: ParamMap): PipelineStage
}
/**
@@ -198,6 +196,6 @@ class PipelineModel private[ml] (
}
override def copy(extra: ParamMap): PipelineModel = {
- new PipelineModel(uid, stages)
+ new PipelineModel(uid, stages.map(_.copy(extra)))
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index e752b81a14282..edaa2afb790e6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -90,9 +90,7 @@ abstract class Predictor[
copyValues(train(dataset).setParent(this))
}
- override def copy(extra: ParamMap): Learner = {
- super.copy(extra).asInstanceOf[Learner]
- }
+ override def copy(extra: ParamMap): Learner
/**
* Train a model using the given dataset and parameters.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index f07f733a5ddb5..3c7bcf7590e6d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -67,9 +67,7 @@ abstract class Transformer extends PipelineStage {
*/
def transform(dataset: DataFrame): DataFrame
- override def copy(extra: ParamMap): Transformer = {
- super.copy(extra).asInstanceOf[Transformer]
- }
+ override def copy(extra: ParamMap): Transformer
}
/**
@@ -120,4 +118,6 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
dataset.withColumn($(outputCol),
callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol))))
}
+
+ override def copy(extra: ParamMap): T = defaultCopy(extra)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 263d580fe2dd3..14c285dbfc54a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.classification
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor}
import org.apache.spark.ml.param.shared.HasRawPredictionCol
import org.apache.spark.ml.util.SchemaUtils
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 8030e0728a56c..2dc1824964a42 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -86,6 +86,8 @@ final class DecisionTreeClassifier(override val uid: String)
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity,
subsamplingRate = 1.0)
}
+
+ override def copy(extra: ParamMap): DecisionTreeClassifier = defaultCopy(extra)
}
@Experimental
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 62f4b51f770e9..554e3b8e052b2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -141,6 +141,8 @@ final class GBTClassifier(override val uid: String)
val oldModel = oldGBT.run(oldDataset)
GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures)
}
+
+ override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra)
}
@Experimental
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index f136bcee9cf2b..2e6eedd45ab07 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -220,6 +220,8 @@ class LogisticRegression(override val uid: String)
new LogisticRegressionModel(uid, weights.compressed, intercept)
}
+
+ override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 825f9ed1b54b2..b657882f8ad3f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -24,7 +24,7 @@ import scala.language.existentials
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.param.Param
+import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{DataFrame, Row}
@@ -133,6 +133,12 @@ final class OneVsRestModel private[ml] (
aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
.drop(accColName)
}
+
+ override def copy(extra: ParamMap): OneVsRestModel = {
+ val copied = new OneVsRestModel(
+ uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]]))
+ copyValues(copied, extra)
+ }
}
/**
@@ -209,4 +215,12 @@ final class OneVsRest(override val uid: String)
val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this)
copyValues(model)
}
+
+ override def copy(extra: ParamMap): OneVsRest = {
+ val copied = defaultCopy(extra).asInstanceOf[OneVsRest]
+ if (isDefined(classifier)) {
+ copied.setClassifier($(classifier).copy(extra))
+ }
+ copied
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 852a67e066322..d3c67494a31e4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -97,6 +97,8 @@ final class RandomForestClassifier(override val uid: String)
oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures)
}
+
+ override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
}
@Experimental
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index f695ddaeefc72..4a82b77f0edcb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -79,4 +79,6 @@ class BinaryClassificationEvaluator(override val uid: String)
metrics.unpersist()
metric
}
+
+ override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
index 61e937e693699..e56c946a063e8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
@@ -46,7 +46,5 @@ abstract class Evaluator extends Params {
*/
def evaluate(dataset: DataFrame): Double
- override def copy(extra: ParamMap): Evaluator = {
- super.copy(extra).asInstanceOf[Evaluator]
- }
+ override def copy(extra: ParamMap): Evaluator
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index abb1b35bedea5..8670e9679d055 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.param.{Param, ParamValidators}
+import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.RegressionMetrics
@@ -80,4 +80,6 @@ final class RegressionEvaluator(override val uid: String)
}
metric
}
+
+ override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index b06122d733853..46314854d5e3a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -83,4 +83,6 @@ final class Binarizer(override val uid: String)
val outputFields = inputFields :+ attr.toStructField()
StructType(outputFields)
}
+
+ override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index a3d1f6f65ccaf..67e4785bc3553 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -89,6 +89,8 @@ final class Bucketizer(override val uid: String)
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
SchemaUtils.appendColumn(schema, prepOutputField(schema))
}
+
+ override def copy(extra: ParamMap): Bucketizer = defaultCopy(extra)
}
private[feature] object Bucketizer {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
index 1e758cb775de7..a359cb8f37ec3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.Param
+import org.apache.spark.ml.param.{ParamMap, Param}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index f936aef80f8af..319d23e46cef4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
-import org.apache.spark.ml.param.{IntParam, ParamValidators}
+import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.feature
@@ -74,4 +74,6 @@ class HashingTF(override val uid: String) extends Transformer with HasInputCol w
val attrGroup = new AttributeGroup($(outputCol), $(numFeatures))
SchemaUtils.appendColumn(schema, attrGroup.toStructField())
}
+
+ override def copy(extra: ParamMap): HashingTF = defaultCopy(extra)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index 376b84530cd57..ecde80810580c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -45,9 +45,6 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
/** @group getParam */
def getMinDocFreq: Int = $(minDocFreq)
- /** @group setParam */
- def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
-
/**
* Validate and transform the input schema.
*/
@@ -72,6 +69,9 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
+ /** @group setParam */
+ def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
+
override def fit(dataset: DataFrame): IDFModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
@@ -82,6 +82,8 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
+
+ override def copy(extra: ParamMap): IDF = defaultCopy(extra)
}
/**
@@ -109,4 +111,9 @@ class IDFModel private[ml] (
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
+
+ override def copy(extra: ParamMap): IDFModel = {
+ val copied = new IDFModel(uid, idfModel)
+ copyValues(copied, extra)
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index 8f34878c8d329..3825942795645 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -165,4 +165,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
dataset.select(col("*"), encode(col(inputColName).cast(DoubleType)).as(outputColName, metadata))
}
+
+ override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
index 442e95820217a..d85e468562d4a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{IntParam, ParamValidators}
+import org.apache.spark.ml.param.{ParamMap, IntParam, ParamValidators}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg._
import org.apache.spark.sql.types.DataType
@@ -61,6 +61,8 @@ class PolynomialExpansion(override val uid: String)
}
override protected def outputDataType: DataType = new VectorUDT()
+
+ override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index b0fd06d84fdb3..ca3c1cfb56b7f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -92,6 +92,8 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
StructType(outputFields)
}
+
+ override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra)
}
/**
@@ -125,4 +127,9 @@ class StandardScalerModel private[ml] (
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
StructType(outputFields)
}
+
+ override def copy(extra: ParamMap): StandardScalerModel = {
+ val copied = new StandardScalerModel(uid, scaler)
+ copyValues(copied, extra)
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index f4e250757560a..bf7be363b8224 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -83,6 +83,8 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
+
+ override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra)
}
/**
@@ -144,4 +146,9 @@ class StringIndexerModel private[ml] (
schema
}
}
+
+ override def copy(extra: ParamMap): StringIndexerModel = {
+ val copied = new StringIndexerModel(uid, labels)
+ copyValues(copied, extra)
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index 21c15b6c33f6c..5f9f57a2ebcfa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -43,6 +43,8 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S
}
override protected def outputDataType: DataType = new ArrayType(StringType, false)
+
+ override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra)
}
/**
@@ -112,4 +114,6 @@ class RegexTokenizer(override val uid: String)
}
override protected def outputDataType: DataType = new ArrayType(StringType, false)
+
+ override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 229ee27ec5942..9f83c2ee16178 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkException
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
+import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
@@ -117,6 +118,8 @@ class VectorAssembler(override val uid: String)
}
StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false))
}
+
+ override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
}
private object VectorAssembler {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 1d0f23b4fb3db..f4854a5e4b7b7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -25,7 +25,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.param.{IntParam, ParamValidators, Params}
+import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, Params}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT}
@@ -131,6 +131,8 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod
SchemaUtils.checkColumnType(schema, $(inputCol), dataType)
SchemaUtils.appendColumn(schema, $(outputCol), dataType)
}
+
+ override def copy(extra: ParamMap): VectorIndexer = defaultCopy(extra)
}
private object VectorIndexer {
@@ -399,4 +401,9 @@ class VectorIndexerModel private[ml] (
val newAttributeGroup = new AttributeGroup($(outputCol), featureAttributes)
newAttributeGroup.toStructField()
}
+
+ override def copy(extra: ParamMap): VectorIndexerModel = {
+ val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps)
+ copyValues(copied, extra)
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 36f19509f0cfb..6ea6590956300 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -132,6 +132,8 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
+
+ override def copy(extra: ParamMap): Word2Vec = defaultCopy(extra)
}
/**
@@ -180,4 +182,9 @@ class Word2VecModel private[ml] (
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
+
+ override def copy(extra: ParamMap): Word2VecModel = {
+ val copied = new Word2VecModel(uid, wordVectors)
+ copyValues(copied, extra)
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index ba94d6a3a80a9..15ebad8838a2a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -492,13 +492,20 @@ trait Params extends Identifiable with Serializable {
/**
* Creates a copy of this instance with the same UID and some extra params.
- * The default implementation tries to create a new instance with the same UID.
+ * Subclasses should implement this method and set the return type properly.
+ *
+ * @see [[defaultCopy()]]
+ */
+ def copy(extra: ParamMap): Params
+
+ /**
+ * Default implementation of copy with extra params.
+ * It tries to create a new instance with the same UID.
* Then it copies the embedded and extra parameters over and returns the new instance.
- * Subclasses should override this method if the default approach is not sufficient.
*/
- def copy(extra: ParamMap): Params = {
+ protected final def defaultCopy[T <: Params](extra: ParamMap): T = {
val that = this.getClass.getConstructor(classOf[String]).newInstance(uid)
- copyValues(that, extra)
+ copyValues(that, extra).asInstanceOf[T]
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index df009d855ecbb..2e44cd4cc6a22 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -216,6 +216,11 @@ class ALSModel private[ml] (
SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
}
+
+ override def copy(extra: ParamMap): ALSModel = {
+ val copied = new ALSModel(uid, rank, userFactors, itemFactors)
+ copyValues(copied, extra)
+ }
}
@@ -330,6 +335,8 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
+
+ override def copy(extra: ParamMap): ALS = defaultCopy(extra)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 43b68e7bb20fa..be1f8063d41d8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -76,6 +76,8 @@ final class DecisionTreeRegressor(override val uid: String)
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity,
subsamplingRate = 1.0)
}
+
+ override def copy(extra: ParamMap): DecisionTreeRegressor = defaultCopy(extra)
}
@Experimental
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index b7e374bb6cb49..036e3acb07412 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -131,6 +131,8 @@ final class GBTRegressor(override val uid: String)
val oldModel = oldGBT.run(oldDataset)
GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures)
}
+
+ override def copy(extra: ParamMap): GBTRegressor = defaultCopy(extra)
}
@Experimental
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 70cd8e9e87fae..01306545fc7cd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -186,6 +186,8 @@ class LinearRegression(override val uid: String)
// TODO: Converts to sparse format based on the storage, but may base on the scoring speed.
copyValues(new LinearRegressionModel(uid, weights.compressed, intercept))
}
+
+ override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 49a1f7ce8c995..21c59061a02fa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -86,6 +86,8 @@ final class RandomForestRegressor(override val uid: String)
oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
RandomForestRegressionModel.fromOld(oldModel, this, categoricalFeatures)
}
+
+ override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra)
}
@Experimental
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index cb29392e8bc63..e2444ab65b43b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -149,6 +149,17 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
est.copy(paramMap).validateParams()
}
}
+
+ override def copy(extra: ParamMap): CrossValidator = {
+ val copied = defaultCopy(extra).asInstanceOf[CrossValidator]
+ if (copied.isDefined(estimator)) {
+ copied.setEstimator(copied.getEstimator.copy(extra))
+ }
+ if (copied.isDefined(evaluator)) {
+ copied.setEvaluator(copied.getEvaluator.copy(extra))
+ }
+ copied
+ }
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
index efbfeb4059f5a..3fab7ea79befc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
@@ -159,7 +159,7 @@ private object IDF {
* Represents an IDF model that can transform term frequency vectors.
*/
@Experimental
-class IDFModel private[mllib] (val idf: Vector) extends Serializable {
+class IDFModel private[spark] (val idf: Vector) extends Serializable {
/**
* Transforms term frequency (TF) vectors to TF-IDF vectors.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 51546d41c36a6..f087d06d2a46a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -431,7 +431,7 @@ class Word2Vec extends Serializable with Logging {
* Word2Vec model
*/
@Experimental
-class Word2VecModel private[mllib] (
+class Word2VecModel private[spark] (
model: Map[String, Array[Float]]) extends Serializable with Saveable {
// wordList: Ordered list of words obtained from model.
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
index ff5929235ac2c..3ae09d39ef500 100644
--- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
@@ -102,4 +102,9 @@ private void init() {
setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0});
setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0}));
}
+
+ @Override
+ public JavaTestParams copy(ParamMap extra) {
+ return defaultCopy(extra);
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 29394fefcbc43..63d2fa31c7499 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -24,6 +24,7 @@ import org.mockito.Mockito.when
import org.scalatest.mock.MockitoSugar.mock
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.feature.HashingTF
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.DataFrame
@@ -84,6 +85,15 @@ class PipelineSuite extends SparkFunSuite {
}
}
+ test("PipelineModel.copy") {
+ val hashingTF = new HashingTF()
+ .setNumFeatures(100)
+ val model = new PipelineModel("pipeline", Array[Transformer](hashingTF))
+ val copied = model.copy(ParamMap(hashingTF.numFeatures -> 10))
+ require(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10,
+ "copy should handle extra stage params")
+ }
+
test("pipeline model constructors") {
val transform0 = mock[Transformer]
val model1 = mock[MyModel]
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index ae40b0b8ff854..73b4805c4c597 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -19,15 +19,15 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
- DecisionTreeSuite => OldDecisionTreeSuite}
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-
class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
import DecisionTreeClassifierSuite.compareAPIs
@@ -55,6 +55,12 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures())
}
+ test("params") {
+ ParamsSuite.checkParams(new DecisionTreeClassifier)
+ val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))
+ ParamsSuite.checkParams(model)
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests calling train()
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 1302da3c373ff..82c345491bb3c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -19,6 +19,9 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.regression.DecisionTreeRegressionModel
+import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -51,6 +54,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2)
}
+ test("params") {
+ ParamsSuite.checkParams(new GBTClassifier)
+ val model = new GBTClassificationModel("gbtc",
+ Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0))),
+ Array(1.0))
+ ParamsSuite.checkParams(model)
+ }
+
test("Binary classification with continuous features: Log Loss") {
val categoricalFeatures = Map.empty[Int, Int]
testCombinations.foreach {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index a755cac3ea76e..5a6265ea992c6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -18,8 +18,9 @@
package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
@@ -62,6 +63,12 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+ test("params") {
+ ParamsSuite.checkParams(new LogisticRegression)
+ val model = new LogisticRegressionModel("logReg", Vectors.dense(0.0), 0.0)
+ ParamsSuite.checkParams(model)
+ }
+
test("logistic regression: default params") {
val lr = new LogisticRegression
assert(lr.getLabelCol === "label")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 1d04ccb509057..75cf5bd4ead4f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -19,15 +19,18 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.NominalAttribute
+import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.Metadata
class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -52,6 +55,13 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
dataset = sqlContext.createDataFrame(rdd)
}
+ test("params") {
+ ParamsSuite.checkParams(new OneVsRest)
+ val lrModel = new LogisticRegressionModel("lr", Vectors.dense(0.0), 0.0)
+ val model = new OneVsRestModel("ovr", Metadata.empty, Array(lrModel))
+ ParamsSuite.checkParams(model)
+ }
+
test("one-vs-rest: default params") {
val numClasses = 3
val ova = new OneVsRest()
@@ -102,6 +112,26 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
val output = ovr.fit(dataset).transform(dataset)
assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
}
+
+ test("OneVsRest.copy and OneVsRestModel.copy") {
+ val lr = new LogisticRegression()
+ .setMaxIter(1)
+
+ val ovr = new OneVsRest()
+ withClue("copy with classifier unset should work") {
+ ovr.copy(ParamMap(lr.maxIter -> 10))
+ }
+ ovr.setClassifier(lr)
+ val ovr1 = ovr.copy(ParamMap(lr.maxIter -> 10))
+ require(ovr.getClassifier.getOrDefault(lr.maxIter) === 1, "copy should have no side-effects")
+ require(ovr1.getClassifier.getOrDefault(lr.maxIter) === 10,
+ "copy should handle extra classifier params")
+
+ val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.threshold -> 0.1))
+ ovrModel.models.foreach { case m: LogisticRegressionModel =>
+ require(m.getThreshold === 0.1, "copy should handle extra model params")
+ }
+ }
}
private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index eee9355a67be3..1b6b69c7dc71e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
@@ -27,7 +29,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-
/**
* Test suite for [[RandomForestClassifier]].
*/
@@ -62,6 +63,13 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeatures, numClasses)
}
+ test("params") {
+ ParamsSuite.checkParams(new RandomForestClassifier)
+ val model = new RandomForestClassificationModel("rfc",
+ Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))))
+ ParamsSuite.checkParams(model)
+ }
+
test("Binary classification with continuous features:" +
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
val rf = new RandomForestClassifier()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
new file mode 100644
index 0000000000000..def869fe66777
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+
+class BinaryClassificationEvaluatorSuite extends SparkFunSuite {
+
+ test("params") {
+ ParamsSuite.checkParams(new BinaryClassificationEvaluator)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
index 36a1ac6b7996d..aa722da323935 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
@@ -18,12 +18,17 @@
package org.apache.spark.ml.evaluation
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext {
+ test("params") {
+ ParamsSuite.checkParams(new RegressionEvaluator)
+ }
+
test("Regression Evaluator: default params") {
/**
* Here is the instruction describing how to export the test data into CSV format
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
index 7953bd0417191..2086043983661 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@@ -30,6 +31,10 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4)
}
+ test("params") {
+ ParamsSuite.checkParams(new Binarizer)
+ }
+
test("Binarize continuous features with default parameter") {
val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
val dataFrame: DataFrame = sqlContext.createDataFrame(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 507a8a7db24c7..ec85e0d151e07 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.feature
import scala.util.Random
import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@@ -27,6 +28,10 @@ import org.apache.spark.sql.{DataFrame, Row}
class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+ test("params") {
+ ParamsSuite.checkParams(new Bucketizer)
+ }
+
test("Bucket continuous features, without -inf,inf") {
// Check a set of valid feature values.
val splits = Array(-0.5, 0.0, 0.5)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
index 7b2d70e644005..4157b84b29d01 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -28,8 +28,7 @@ import org.apache.spark.util.Utils
class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
test("params") {
- val hashingTF = new HashingTF
- ParamsSuite.checkParams(hashingTF, 3)
+ ParamsSuite.checkParams(new HashingTF)
}
test("hashingTF") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
index d83772e8be755..08f80af03429b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
@@ -18,6 +18,8 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@@ -38,6 +40,12 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+ test("params") {
+ ParamsSuite.checkParams(new IDF)
+ val model = new IDFModel("idf", new OldIDFModel(Vectors.dense(1.0)))
+ ParamsSuite.checkParams(model)
+ }
+
test("compute IDF with default parameter") {
val numOfFeatures = 4
val data = Array(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index 2e5036a844562..65846a846b7b4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
@@ -36,6 +37,10 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
indexer.transform(df)
}
+ test("params") {
+ ParamsSuite.checkParams(new OneHotEncoder)
+ }
+
test("OneHotEncoder dropLast = false") {
val transformed = stringIndexed()
val encoder = new OneHotEncoder()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
index feca866cd711d..29eebd8960ebc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ml.feature
+import org.apache.spark.ml.param.ParamsSuite
import org.scalatest.exceptions.TestFailedException
import org.apache.spark.SparkFunSuite
@@ -27,6 +28,10 @@ import org.apache.spark.sql.Row
class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext {
+ test("params") {
+ ParamsSuite.checkParams(new PolynomialExpansion)
+ }
+
test("Polynomial expansion with default parameter") {
val data = Array(
Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 5f557e16e5150..99f82bea42688 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -19,10 +19,17 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
+ test("params") {
+ ParamsSuite.checkParams(new StringIndexer)
+ val model = new StringIndexerModel("indexer", Array("a", "b"))
+ ParamsSuite.checkParams(model)
+ }
+
test("StringIndexer") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
index ac279cb3215c2..e5fd21c3f6fca 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -20,15 +20,27 @@ package org.apache.spark.ml.feature
import scala.beans.BeanInfo
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@BeanInfo
case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
+class TokenizerSuite extends SparkFunSuite {
+
+ test("params") {
+ ParamsSuite.checkParams(new Tokenizer)
+ }
+}
+
class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext {
import org.apache.spark.ml.feature.RegexTokenizerSuite._
+ test("params") {
+ ParamsSuite.checkParams(new RegexTokenizer)
+ }
+
test("RegexTokenizer") {
val tokenizer0 = new RegexTokenizer()
.setGaps(false)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index 489abb5af7130..bb4d5b983e0d4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
@@ -26,6 +27,10 @@ import org.apache.spark.sql.functions.col
class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext {
+ test("params") {
+ ParamsSuite.checkParams(new VectorAssembler)
+ }
+
test("assemble") {
import org.apache.spark.ml.feature.VectorAssembler.assemble
assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index 06affc7305cf5..8c85c96d5c6d8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -21,6 +21,7 @@ import scala.beans.{BeanInfo, BeanProperty}
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute._
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
@@ -91,6 +92,12 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
private def getIndexer: VectorIndexer =
new VectorIndexer().setInputCol("features").setOutputCol("indexed")
+ test("params") {
+ ParamsSuite.checkParams(new VectorIndexer)
+ val model = new VectorIndexerModel("indexer", 1, Map.empty)
+ ParamsSuite.checkParams(model)
+ }
+
test("Cannot fit an empty DataFrame") {
val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData))
val vectorIndexer = getIndexer
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 94ebc3aebfa37..aa6ce533fd885 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -18,13 +18,21 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel}
class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
+ test("params") {
+ ParamsSuite.checkParams(new Word2Vec)
+ val model = new Word2VecModel("w2v", new OldWord2VecModel(Map("a" -> Array(0.0f))))
+ ParamsSuite.checkParams(model)
+ }
+
test("Word2Vec") {
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 96094d7a099aa..050d4170ea017 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -205,19 +205,27 @@ class ParamsSuite extends SparkFunSuite {
object ParamsSuite extends SparkFunSuite {
/**
- * Checks common requirements for [[Params.params]]: 1) number of params; 2) params are ordered
- * by names; 3) param parent has the same UID as the object's UID; 4) param name is the same as
- * the param method name.
+ * Checks common requirements for [[Params.params]]:
+ * - params are ordered by names
+ * - param parent has the same UID as the object's UID
+ * - param name is the same as the param method name
+ * - obj.copy should return the same type as the obj
*/
- def checkParams(obj: Params, expectedNumParams: Int): Unit = {
+ def checkParams(obj: Params): Unit = {
+ val clazz = obj.getClass
+
val params = obj.params
- require(params.length === expectedNumParams,
- s"Expect $expectedNumParams params but got ${params.length}: ${params.map(_.name).toSeq}.")
val paramNames = params.map(_.name)
- require(paramNames === paramNames.sorted)
+ require(paramNames === paramNames.sorted, "params must be ordered by names")
params.foreach { p =>
assert(p.parent === obj.uid)
assert(obj.getParam(p.name) === p)
+ // TODO: Check that setters return self, which needs special handling for generic types.
}
+
+ val copyMethod = clazz.getMethod("copy", classOf[ParamMap])
+ val copyReturnType = copyMethod.getReturnType
+ require(copyReturnType === obj.getClass,
+ s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.")
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
index a9e78366ad98f..2759248344531 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -38,7 +38,5 @@ class TestParams(override val uid: String) extends Params with HasMaxIter with H
require(isDefined(inputCol))
}
- override def copy(extra: ParamMap): TestParams = {
- super.copy(extra).asInstanceOf[TestParams]
- }
+ override def copy(extra: ParamMap): TestParams = defaultCopy(extra)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
index eb5408d3fee7c..b3af81a3c60b6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
@@ -18,13 +18,15 @@
package org.apache.spark.ml.param.shared
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.param.Params
+import org.apache.spark.ml.param.{ParamMap, Params}
class SharedParamsSuite extends SparkFunSuite {
test("outputCol") {
- class Obj(override val uid: String) extends Params with HasOutputCol
+ class Obj(override val uid: String) extends Params with HasOutputCol {
+ override def copy(extra: ParamMap): Obj = defaultCopy(extra)
+ }
val obj = new Obj("obj")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 9b3619f0046ea..36af4b34a9e40 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -18,7 +18,6 @@
package org.apache.spark.ml.tuning
import org.apache.spark.SparkFunSuite
-
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator}
@@ -98,6 +97,8 @@ object CrossValidatorSuite {
override def transformSchema(schema: StructType): StructType = {
throw new UnsupportedOperationException
}
+
+ override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
}
class MyEvaluator extends Evaluator {
@@ -107,5 +108,7 @@ object CrossValidatorSuite {
}
override val uid: String = "eval"
+
+ override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra)
}
}
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index d8cdcda3a3783..7f9d0a338d31e 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -272,7 +272,7 @@ def dump_stream(self, iterator, stream):
if size < best:
batch *= 2
elif size > best * 10 and batch > 1:
- batch /= 2
+ batch //= 2
def __repr__(self):
return "AutoBatchedSerializer(%s)" % self.serializer
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index 81c420ce16541..67752c0d150b9 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -486,7 +486,7 @@ def sorted(self, iterator, key=None, reverse=False):
goes above the limit.
"""
global MemoryBytesSpilled, DiskBytesSpilled
- batch, limit = 100, self.memory_limit
+ batch, limit = 100, self._next_limit()
chunks, current_chunk = [], []
iterator = iter(iterator)
while True:
@@ -512,9 +512,6 @@ def load(f):
f.close()
chunks.append(load(open(path, 'rb')))
current_chunk = []
- gc.collect()
- batch //= 2
- limit = self._next_limit()
MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20
DiskBytesSpilled += os.path.getsize(path)
os.unlink(path) # data will be deleted after close
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index bbf465aca8d4d..acdb01d3d3f5f 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -18,6 +18,7 @@
"""
A collections of builtin functions
"""
+import math
import sys
if sys.version < "3":
@@ -143,7 +144,7 @@ def _():
'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' +
'polar coordinates (r, theta).',
'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.',
- 'pow': 'Returns the value of the first argument raised to the power of the second argument.'
+ 'pow': 'Returns the value of the first argument raised to the power of the second argument.',
}
_window_functions = {
@@ -403,6 +404,26 @@ def when(condition, value):
return Column(jc)
+@since(1.5)
+def log(arg1, arg2=None):
+ """Returns the logarithm of the second argument, using the first argument as the base.
+
+ If there is only one argument, then this takes the natural logarithm of the argument.
+
+ >>> df.select(log(10.0, df.age).alias('ten')).map(lambda l: str(l.ten)[:7]).collect()
+ ['0.30102', '0.69897']
+
+ >>> df.select(log(df.age).alias('e')).map(lambda l: str(l.e)[:7]).collect()
+ ['0.69314', '1.60943']
+ """
+ sc = SparkContext._active_spark_context
+ if arg2 is None:
+ jc = sc._jvm.functions.log(_to_java_column(arg1))
+ else:
+ jc = sc._jvm.functions.log(arg1, _to_java_column(arg2))
+ return Column(jc)
+
+
@since(1.4)
def lag(col, count=1, default=None):
"""
diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
index ff097985fae3e..8dcb9645cdc6b 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -176,7 +176,7 @@ def takeAndPrint(time, rdd):
print(record)
if len(taken) > num:
print("...")
- print()
+ print("")
self.foreachRDD(takeAndPrint)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 11b402e6df6c1..78265423682b0 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -179,9 +179,12 @@ def test_in_memory_sort(self):
list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
def test_external_sort(self):
+ class CustomizedSorter(ExternalSorter):
+ def _next_limit(self):
+ return self.memory_limit
l = list(range(1024))
random.shuffle(l)
- sorter = ExternalSorter(1)
+ sorter = CustomizedSorter(1)
self.assertEqual(sorted(l), list(sorter.sorted(l)))
self.assertGreater(shuffle.DiskBytesSpilled, 0)
last = shuffle.DiskBytesSpilled
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 97b123ec2f6d9..13b2bb05f5280 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -112,6 +112,7 @@ object FunctionRegistry {
expression[Expm1]("expm1"),
expression[Floor]("floor"),
expression[Hypot]("hypot"),
+ expression[Logarithm]("log"),
expression[Log]("ln"),
expression[Log10]("log10"),
expression[Log1p]("log1p"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 189451d0d9ad7..8012b224eb444 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -307,7 +307,6 @@ trait HiveTypeCoercion {
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
- case Sqrt(e @ StringType()) => Sqrt(Cast(e, DoubleType))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 167e460d5a93e..ace8427c8ddaf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -67,38 +67,6 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic {
protected override def evalInternal(evalE: Any) = evalE
}
-case class Sqrt(child: Expression) extends UnaryArithmetic {
- override def dataType: DataType = DoubleType
- override def nullable: Boolean = true
- override def toString: String = s"SQRT($child)"
-
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForNumericExpr(child.dataType, "function sqrt")
-
- private lazy val numeric = TypeUtils.getNumeric(child.dataType)
-
- protected override def evalInternal(evalE: Any) = {
- val value = numeric.toDouble(evalE)
- if (value < 0) null
- else math.sqrt(value)
- }
-
- override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val eval = child.gen(ctx)
- eval.code + s"""
- boolean ${ev.isNull} = ${eval.isNull};
- ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
- if (!${ev.isNull}) {
- if (${eval.primitive} < 0.0) {
- ${ev.isNull} = true;
- } else {
- ${ev.primitive} = java.lang.Math.sqrt(${eval.primitive});
- }
- }
- """
- }
-}
-
/**
* A function that get the absolute value of the numeric value.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
index 1aaf9b309efc3..72fdcebb4cbc8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -53,7 +53,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
* Returns a Row containing the evaluation of all children expressions.
* TODO: [[CreateStruct]] does not support codegen.
*/
-case class CreateStruct(children: Seq[NamedExpression]) extends Expression {
+case class CreateStruct(children: Seq[Expression]) extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
@@ -62,9 +62,14 @@ case class CreateStruct(children: Seq[NamedExpression]) extends Expression {
override lazy val dataType: StructType = {
assert(resolved,
s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.")
- val fields = children.map { child =>
- StructField(child.name, child.dataType, child.nullable, child.metadata)
- }
+ val fields = children.zipWithIndex.map { case (child, idx) =>
+ child match {
+ case ne: NamedExpression =>
+ StructField(ne.name, ne.dataType, ne.nullable, ne.metadata)
+ case _ =>
+ StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty)
+ }
+ }
StructType(fields)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 42c596b5b31ab..f79bf4aee00d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -193,6 +193,8 @@ case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN")
case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH")
+case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT")
+
case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN")
case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH")
@@ -255,3 +257,27 @@ case class Pow(left: Expression, right: Expression)
"""
}
}
+
+case class Logarithm(left: Expression, right: Expression)
+ extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") {
+
+ /**
+ * Natural log, i.e. using e as the base.
+ */
+ def this(child: Expression) = {
+ this(EulerNumber(), child)
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val logCode = if (left.isInstanceOf[EulerNumber]) {
+ defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2)")
+ } else {
+ defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2) / java.lang.Math.log($c1)")
+ }
+ logCode + s"""
+ if (Double.valueOf(${ev.primitive}).isNaN()) {
+ ${ev.isNull} = true;
+ }
+ """
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 3f4843259e80b..4bbbbe6c7f091 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -142,19 +142,4 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1)
checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1)
}
-
- test("SQRT") {
- val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24))
- val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble))
- val rowSequence = inputSequence.map(l => create_row(l.toDouble))
- val d = 'a.double.at(0)
-
- for ((row, expected) <- rowSequence zip expectedResults) {
- checkEvaluation(Sqrt(d), expected, row)
- }
-
- checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null))
- checkEvaluation(Sqrt(-1), null, EmptyRow)
- checkEvaluation(Sqrt(-1.5), null, EmptyRow)
- }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala
index dcb3635c5ccae..49b111989799b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala
@@ -54,8 +54,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
test("check types for unary arithmetic") {
assertError(UnaryMinus('stringField), "operator - accepts numeric type")
- assertSuccess(Sqrt('stringField)) // We will cast String to Double for sqrt
- assertError(Sqrt('booleanField), "function sqrt accepts numeric type")
assertError(Abs('stringField), "function abs accepts numeric type")
assertError(BitwiseNot('stringField), "operator ~ accepts integral type")
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 864c954ee82cb..21e9b92b7214e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.DoubleType
@@ -191,6 +192,15 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNull = true)
}
+ test("sqrt") {
+ testUnary(Sqrt, math.sqrt, (0 to 20).map(_ * 0.1))
+ testUnary(Sqrt, math.sqrt, (-5 to -1).map(_ * 1.0), expectNull = true)
+
+ checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null))
+ checkEvaluation(Sqrt(Literal(-1.0)), null, EmptyRow)
+ checkEvaluation(Sqrt(Literal(-1.5)), null, EmptyRow)
+ }
+
test("pow") {
testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0)))
testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true)
@@ -204,4 +214,22 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
testBinary(Atan2, math.atan2)
}
+ test("binary log") {
+ val f = (c1: Double, c2: Double) => math.log(c2) / math.log(c1)
+ val domain = (1 to 20).map(v => (v * 0.1, v * 0.2))
+
+ domain.foreach { case (v1, v2) =>
+ checkEvaluation(Logarithm(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow)
+ checkEvaluation(Logarithm(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow)
+ checkEvaluation(new Logarithm(Literal(v1)), f(math.E, v1 + 0.0), EmptyRow)
+ }
+ checkEvaluation(
+ Logarithm(Literal.create(null, DoubleType), Literal(1.0)),
+ null,
+ create_row(null))
+ checkEvaluation(
+ Logarithm(Literal(1.0), Literal.create(null, DoubleType)),
+ null,
+ create_row(null))
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index d3efa83380d04..b4e008a6e8480 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -621,7 +621,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
* @since 1.3.0
*/
@scala.annotation.varargs
- def in(list: Column*): Column = In(expr, list.map(_.expr))
+ def in(list: Any*): Column = In(expr, list.map(lit(_).expr))
/**
* SQL like expression.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 55ab6b3358e3c..16493c3d7c19c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -25,74 +25,333 @@ import scala.collection.JavaConversions._
import org.apache.spark.sql.catalyst.CatalystConf
private[spark] object SQLConf {
- val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed"
- val COLUMN_BATCH_SIZE = "spark.sql.inMemoryColumnarStorage.batchSize"
- val IN_MEMORY_PARTITION_PRUNING = "spark.sql.inMemoryColumnarStorage.partitionPruning"
- val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold"
- val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes"
- val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
- val CODEGEN_ENABLED = "spark.sql.codegen"
- val UNSAFE_ENABLED = "spark.sql.unsafe.enabled"
- val DIALECT = "spark.sql.dialect"
- val CASE_SENSITIVE = "spark.sql.caseSensitive"
-
- val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString"
- val PARQUET_INT96_AS_TIMESTAMP = "spark.sql.parquet.int96AsTimestamp"
- val PARQUET_CACHE_METADATA = "spark.sql.parquet.cacheMetadata"
- val PARQUET_COMPRESSION = "spark.sql.parquet.compression.codec"
- val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.parquet.filterPushdown"
- val PARQUET_USE_DATA_SOURCE_API = "spark.sql.parquet.useDataSourceApi"
-
- val ORC_FILTER_PUSHDOWN_ENABLED = "spark.sql.orc.filterPushdown"
-
- val HIVE_VERIFY_PARTITIONPATH = "spark.sql.hive.verifyPartitionPath"
-
- val COLUMN_NAME_OF_CORRUPT_RECORD = "spark.sql.columnNameOfCorruptRecord"
- val BROADCAST_TIMEOUT = "spark.sql.broadcastTimeout"
+
+ private val sqlConfEntries = java.util.Collections.synchronizedMap(
+ new java.util.HashMap[String, SQLConfEntry[_]]())
+
+ /**
+ * An entry contains all meta information for a configuration.
+ *
+ * @param key the key for the configuration
+ * @param defaultValue the default value for the configuration
+ * @param valueConverter how to convert a string to the value. It should throw an exception if the
+ * string does not have the required format.
+ * @param stringConverter how to convert a value to a string that the user can use it as a valid
+ * string value. It's usually `toString`. But sometimes, a custom converter
+ * is necessary. E.g., if T is List[String], `a, b, c` is better than
+ * `List(a, b, c)`.
+ * @param doc the document for the configuration
+ * @param isPublic if this configuration is public to the user. If it's `false`, this
+ * configuration is only used internally and we should not expose it to the user.
+ * @tparam T the value type
+ */
+ private[sql] class SQLConfEntry[T] private(
+ val key: String,
+ val defaultValue: Option[T],
+ val valueConverter: String => T,
+ val stringConverter: T => String,
+ val doc: String,
+ val isPublic: Boolean) {
+
+ def defaultValueString: String = defaultValue.map(stringConverter).getOrElse("")
+
+ override def toString: String = {
+ s"SQLConfEntry(key = $key, defaultValue=$defaultValueString, doc=$doc, isPublic = $isPublic)"
+ }
+ }
+
+ private[sql] object SQLConfEntry {
+
+ private def apply[T](
+ key: String,
+ defaultValue: Option[T],
+ valueConverter: String => T,
+ stringConverter: T => String,
+ doc: String,
+ isPublic: Boolean): SQLConfEntry[T] =
+ sqlConfEntries.synchronized {
+ if (sqlConfEntries.containsKey(key)) {
+ throw new IllegalArgumentException(s"Duplicate SQLConfEntry. $key has been registered")
+ }
+ val entry =
+ new SQLConfEntry[T](key, defaultValue, valueConverter, stringConverter, doc, isPublic)
+ sqlConfEntries.put(key, entry)
+ entry
+ }
+
+ def intConf(
+ key: String,
+ defaultValue: Option[Int] = None,
+ doc: String = "",
+ isPublic: Boolean = true): SQLConfEntry[Int] =
+ SQLConfEntry(key, defaultValue, { v =>
+ try {
+ v.toInt
+ } catch {
+ case _: NumberFormatException =>
+ throw new IllegalArgumentException(s"$key should be int, but was $v")
+ }
+ }, _.toString, doc, isPublic)
+
+ def longConf(
+ key: String,
+ defaultValue: Option[Long] = None,
+ doc: String = "",
+ isPublic: Boolean = true): SQLConfEntry[Long] =
+ SQLConfEntry(key, defaultValue, { v =>
+ try {
+ v.toLong
+ } catch {
+ case _: NumberFormatException =>
+ throw new IllegalArgumentException(s"$key should be long, but was $v")
+ }
+ }, _.toString, doc, isPublic)
+
+ def doubleConf(
+ key: String,
+ defaultValue: Option[Double] = None,
+ doc: String = "",
+ isPublic: Boolean = true): SQLConfEntry[Double] =
+ SQLConfEntry(key, defaultValue, { v =>
+ try {
+ v.toDouble
+ } catch {
+ case _: NumberFormatException =>
+ throw new IllegalArgumentException(s"$key should be double, but was $v")
+ }
+ }, _.toString, doc, isPublic)
+
+ def booleanConf(
+ key: String,
+ defaultValue: Option[Boolean] = None,
+ doc: String = "",
+ isPublic: Boolean = true): SQLConfEntry[Boolean] =
+ SQLConfEntry(key, defaultValue, { v =>
+ try {
+ v.toBoolean
+ } catch {
+ case _: IllegalArgumentException =>
+ throw new IllegalArgumentException(s"$key should be boolean, but was $v")
+ }
+ }, _.toString, doc, isPublic)
+
+ def stringConf(
+ key: String,
+ defaultValue: Option[String] = None,
+ doc: String = "",
+ isPublic: Boolean = true): SQLConfEntry[String] =
+ SQLConfEntry(key, defaultValue, v => v, v => v, doc, isPublic)
+
+ def enumConf[T](
+ key: String,
+ valueConverter: String => T,
+ validValues: Set[T],
+ defaultValue: Option[T] = None,
+ doc: String = "",
+ isPublic: Boolean = true): SQLConfEntry[T] =
+ SQLConfEntry(key, defaultValue, v => {
+ val _v = valueConverter(v)
+ if (!validValues.contains(_v)) {
+ throw new IllegalArgumentException(
+ s"The value of $key should be one of ${validValues.mkString(", ")}, but was $v")
+ }
+ _v
+ }, _.toString, doc, isPublic)
+
+ def seqConf[T](
+ key: String,
+ valueConverter: String => T,
+ defaultValue: Option[Seq[T]] = None,
+ doc: String = "",
+ isPublic: Boolean = true): SQLConfEntry[Seq[T]] = {
+ SQLConfEntry(
+ key, defaultValue, _.split(",").map(valueConverter), _.mkString(","), doc, isPublic)
+ }
+
+ def stringSeqConf(
+ key: String,
+ defaultValue: Option[Seq[String]] = None,
+ doc: String = "",
+ isPublic: Boolean = true): SQLConfEntry[Seq[String]] = {
+ seqConf(key, s => s, defaultValue, doc, isPublic)
+ }
+ }
+
+ import SQLConfEntry._
+
+ val COMPRESS_CACHED = booleanConf("spark.sql.inMemoryColumnarStorage.compressed",
+ defaultValue = Some(true),
+ doc = "When set to true Spark SQL will automatically select a compression codec for each " +
+ "column based on statistics of the data.")
+
+ val COLUMN_BATCH_SIZE = intConf("spark.sql.inMemoryColumnarStorage.batchSize",
+ defaultValue = Some(10000),
+ doc = "Controls the size of batches for columnar caching. Larger batch sizes can improve " +
+ "memory utilization and compression, but risk OOMs when caching data.")
+
+ val IN_MEMORY_PARTITION_PRUNING =
+ booleanConf("spark.sql.inMemoryColumnarStorage.partitionPruning",
+ defaultValue = Some(false),
+ doc = "")
+
+ val AUTO_BROADCASTJOIN_THRESHOLD = intConf("spark.sql.autoBroadcastJoinThreshold",
+ defaultValue = Some(10 * 1024 * 1024),
+ doc = "Configures the maximum size in bytes for a table that will be broadcast to all worker " +
+ "nodes when performing a join. By setting this value to -1 broadcasting can be disabled. " +
+ "Note that currently statistics are only supported for Hive Metastore tables where the " +
+ "command ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run.")
+
+ val DEFAULT_SIZE_IN_BYTES = longConf("spark.sql.defaultSizeInBytes", isPublic = false)
+
+ val SHUFFLE_PARTITIONS = intConf("spark.sql.shuffle.partitions",
+ defaultValue = Some(200),
+ doc = "Configures the number of partitions to use when shuffling data for joins or " +
+ "aggregations.")
+
+ val CODEGEN_ENABLED = booleanConf("spark.sql.codegen",
+ defaultValue = Some(true),
+ doc = "When true, code will be dynamically generated at runtime for expression evaluation in" +
+ " a specific query. For some queries with complicated expression this option can lead to " +
+ "significant speed-ups. However, for simple queries this can actually slow down query " +
+ "execution.")
+
+ val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled",
+ defaultValue = Some(false),
+ doc = "")
+
+ val DIALECT = stringConf("spark.sql.dialect", defaultValue = Some("sql"), doc = "")
+
+ val CASE_SENSITIVE = booleanConf("spark.sql.caseSensitive",
+ defaultValue = Some(true),
+ doc = "")
+
+ val PARQUET_BINARY_AS_STRING = booleanConf("spark.sql.parquet.binaryAsString",
+ defaultValue = Some(false),
+ doc = "Some other Parquet-producing systems, in particular Impala and older versions of " +
+ "Spark SQL, do not differentiate between binary data and strings when writing out the " +
+ "Parquet schema. This flag tells Spark SQL to interpret binary data as a string to provide " +
+ "compatibility with these systems.")
+
+ val PARQUET_INT96_AS_TIMESTAMP = booleanConf("spark.sql.parquet.int96AsTimestamp",
+ defaultValue = Some(true),
+ doc = "Some Parquet-producing systems, in particular Impala, store Timestamp into INT96. " +
+ "Spark would also store Timestamp as INT96 because we need to avoid precision loss of the " +
+ "nanoseconds field. This flag tells Spark SQL to interpret INT96 data as a timestamp to " +
+ "provide compatibility with these systems.")
+
+ val PARQUET_CACHE_METADATA = booleanConf("spark.sql.parquet.cacheMetadata",
+ defaultValue = Some(true),
+ doc = "Turns on caching of Parquet schema metadata. Can speed up querying of static data.")
+
+ val PARQUET_COMPRESSION = enumConf("spark.sql.parquet.compression.codec",
+ valueConverter = v => v.toLowerCase,
+ validValues = Set("uncompressed", "snappy", "gzip", "lzo"),
+ defaultValue = Some("gzip"),
+ doc = "Sets the compression codec used when writing Parquet files. Acceptable values include: " +
+ "uncompressed, snappy, gzip, lzo.")
+
+ val PARQUET_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.parquet.filterPushdown",
+ defaultValue = Some(false),
+ doc = "Turn on Parquet filter pushdown optimization. This feature is turned off by default" +
+ " because of a known bug in Parquet 1.6.0rc3 " +
+ "(PARQUET-136). However, " +
+ "if your table doesn't contain any nullable string or binary columns, it's still safe to " +
+ "turn this feature on.")
+
+ val PARQUET_USE_DATA_SOURCE_API = booleanConf("spark.sql.parquet.useDataSourceApi",
+ defaultValue = Some(true),
+ doc = "")
+
+ val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown",
+ defaultValue = Some(false),
+ doc = "")
+
+ val HIVE_VERIFY_PARTITIONPATH = booleanConf("spark.sql.hive.verifyPartitionPath",
+ defaultValue = Some(true),
+ doc = "")
+
+ val COLUMN_NAME_OF_CORRUPT_RECORD = stringConf("spark.sql.columnNameOfCorruptRecord",
+ defaultValue = Some("_corrupt_record"),
+ doc = "")
+
+ val BROADCAST_TIMEOUT = intConf("spark.sql.broadcastTimeout",
+ defaultValue = Some(5 * 60),
+ doc = "")
// Options that control which operators can be chosen by the query planner. These should be
// considered hints and may be ignored by future versions of Spark SQL.
- val EXTERNAL_SORT = "spark.sql.planner.externalSort"
- val SORTMERGE_JOIN = "spark.sql.planner.sortMergeJoin"
+ val EXTERNAL_SORT = booleanConf("spark.sql.planner.externalSort",
+ defaultValue = Some(true),
+ doc = "When true, performs sorts spilling to disk as needed; otherwise sorts each partition" +
+ " in memory.")
+
+ val SORTMERGE_JOIN = booleanConf("spark.sql.planner.sortMergeJoin",
+ defaultValue = Some(false),
+ doc = "")
// This is only used for the thriftserver
- val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool"
- val THRIFTSERVER_UI_STATEMENT_LIMIT = "spark.sql.thriftserver.ui.retainedStatements"
- val THRIFTSERVER_UI_SESSION_LIMIT = "spark.sql.thriftserver.ui.retainedSessions"
+ val THRIFTSERVER_POOL = stringConf("spark.sql.thriftserver.scheduler.pool",
+ doc = "Set a Fair Scheduler pool for a JDBC client session")
+
+ val THRIFTSERVER_UI_STATEMENT_LIMIT = intConf("spark.sql.thriftserver.ui.retainedStatements",
+ defaultValue = Some(200),
+ doc = "")
+
+ val THRIFTSERVER_UI_SESSION_LIMIT = intConf("spark.sql.thriftserver.ui.retainedSessions",
+ defaultValue = Some(200),
+ doc = "")
// This is used to set the default data source
- val DEFAULT_DATA_SOURCE_NAME = "spark.sql.sources.default"
+ val DEFAULT_DATA_SOURCE_NAME = stringConf("spark.sql.sources.default",
+ defaultValue = Some("org.apache.spark.sql.parquet"),
+ doc = "")
+
// This is used to control the when we will split a schema's JSON string to multiple pieces
// in order to fit the JSON string in metastore's table property (by default, the value has
// a length restriction of 4000 characters). We will split the JSON string of a schema
// to its length exceeds the threshold.
- val SCHEMA_STRING_LENGTH_THRESHOLD = "spark.sql.sources.schemaStringLengthThreshold"
+ val SCHEMA_STRING_LENGTH_THRESHOLD = intConf("spark.sql.sources.schemaStringLengthThreshold",
+ defaultValue = Some(4000),
+ doc = "")
// Whether to perform partition discovery when loading external data sources. Default to true.
- val PARTITION_DISCOVERY_ENABLED = "spark.sql.sources.partitionDiscovery.enabled"
+ val PARTITION_DISCOVERY_ENABLED = booleanConf("spark.sql.sources.partitionDiscovery.enabled",
+ defaultValue = Some(true),
+ doc = "")
// Whether to perform partition column type inference. Default to true.
- val PARTITION_COLUMN_TYPE_INFERENCE = "spark.sql.sources.partitionColumnTypeInference.enabled"
+ val PARTITION_COLUMN_TYPE_INFERENCE =
+ booleanConf("spark.sql.sources.partitionColumnTypeInference.enabled",
+ defaultValue = Some(true),
+ doc = "")
// The output committer class used by FSBasedRelation. The specified class needs to be a
// subclass of org.apache.hadoop.mapreduce.OutputCommitter.
// NOTE: This property should be set in Hadoop `Configuration` rather than Spark `SQLConf`
- val OUTPUT_COMMITTER_CLASS = "spark.sql.sources.outputCommitterClass"
+ val OUTPUT_COMMITTER_CLASS =
+ stringConf("spark.sql.sources.outputCommitterClass", isPublic = false)
// Whether to perform eager analysis when constructing a dataframe.
// Set to false when debugging requires the ability to look at invalid query plans.
- val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis"
+ val DATAFRAME_EAGER_ANALYSIS = booleanConf("spark.sql.eagerAnalysis",
+ defaultValue = Some(true),
+ doc = "")
// Whether to automatically resolve ambiguity in join conditions for self-joins.
// See SPARK-6231.
- val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = "spark.sql.selfJoinAutoResolveAmbiguity"
+ val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY =
+ booleanConf("spark.sql.selfJoinAutoResolveAmbiguity", defaultValue = Some(true), doc = "")
// Whether to retain group by columns or not in GroupedData.agg.
- val DATAFRAME_RETAIN_GROUP_COLUMNS = "spark.sql.retainGroupColumns"
+ val DATAFRAME_RETAIN_GROUP_COLUMNS = booleanConf("spark.sql.retainGroupColumns",
+ defaultValue = Some(true),
+ doc = "")
- val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2"
+ val USE_SQL_SERIALIZER2 = booleanConf("spark.sql.useSerializer2",
+ defaultValue = Some(true), doc = "")
- val USE_JACKSON_STREAMING_API = "spark.sql.json.useJacksonStreamingAPI"
+ val USE_JACKSON_STREAMING_API = booleanConf("spark.sql.json.useJacksonStreamingAPI",
+ defaultValue = Some(true), doc = "")
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
@@ -131,56 +390,54 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
* Note that the choice of dialect does not affect things like what tables are available or
* how query execution is performed.
*/
- private[spark] def dialect: String = getConf(DIALECT, "sql")
+ private[spark] def dialect: String = getConf(DIALECT)
/** When true tables cached using the in-memory columnar caching will be compressed. */
- private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED, "true").toBoolean
+ private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED)
/** The compression codec for writing to a Parquetfile */
- private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION, "gzip")
+ private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION)
+
+ private[spark] def parquetCacheMetadata: Boolean = getConf(PARQUET_CACHE_METADATA)
/** The number of rows that will be */
- private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE, "10000").toInt
+ private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE)
/** Number of partitions to use for shuffle operators. */
- private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS, "200").toInt
+ private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS)
/** When true predicates will be passed to the parquet record reader when possible. */
- private[spark] def parquetFilterPushDown =
- getConf(PARQUET_FILTER_PUSHDOWN_ENABLED, "false").toBoolean
+ private[spark] def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED)
/** When true uses Parquet implementation based on data source API */
- private[spark] def parquetUseDataSourceApi =
- getConf(PARQUET_USE_DATA_SOURCE_API, "true").toBoolean
+ private[spark] def parquetUseDataSourceApi: Boolean = getConf(PARQUET_USE_DATA_SOURCE_API)
- private[spark] def orcFilterPushDown =
- getConf(ORC_FILTER_PUSHDOWN_ENABLED, "false").toBoolean
+ private[spark] def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED)
/** When true uses verifyPartitionPath to prune the path which is not exists. */
- private[spark] def verifyPartitionPath =
- getConf(HIVE_VERIFY_PARTITIONPATH, "true").toBoolean
+ private[spark] def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITIONPATH)
/** When true the planner will use the external sort, which may spill to disk. */
- private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, "true").toBoolean
+ private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT)
/**
* Sort merge join would sort the two side of join first, and then iterate both sides together
* only once to get all matches. Using sort merge join can save a lot of memory usage compared
* to HashJoin.
*/
- private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN, "false").toBoolean
+ private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN)
/**
* When set to true, Spark SQL will use the Janino at runtime to generate custom bytecode
* that evaluates expressions found in queries. In general this custom code runs much faster
* than interpreted evaluation, but there are some start-up costs (5-10ms) due to compilation.
*/
- private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "true").toBoolean
+ private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED)
/**
* caseSensitive analysis true by default
*/
- def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, "true").toBoolean
+ def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE)
/**
* When set to true, Spark SQL will use managed memory for certain operations. This option only
@@ -188,15 +445,14 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
*
* Defaults to false as this feature is currently experimental.
*/
- private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, "false").toBoolean
+ private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED)
- private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean
+ private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2)
/**
* Selects between the new (true) and old (false) JSON handlers, to be removed in Spark 1.5.0
*/
- private[spark] def useJacksonStreamingAPI: Boolean =
- getConf(USE_JACKSON_STREAMING_API, "true").toBoolean
+ private[spark] def useJacksonStreamingAPI: Boolean = getConf(USE_JACKSON_STREAMING_API)
/**
* Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to
@@ -205,8 +461,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
*
* Hive setting: hive.auto.convert.join.noconditionaltask.size, whose default value is 10000.
*/
- private[spark] def autoBroadcastJoinThreshold: Int =
- getConf(AUTO_BROADCASTJOIN_THRESHOLD, (10 * 1024 * 1024).toString).toInt
+ private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
/**
* The default size in bytes to assign to a logical operator's estimation statistics. By default,
@@ -215,82 +470,116 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
* in joins.
*/
private[spark] def defaultSizeInBytes: Long =
- getConf(DEFAULT_SIZE_IN_BYTES, (autoBroadcastJoinThreshold + 1).toString).toLong
+ getConf(DEFAULT_SIZE_IN_BYTES, autoBroadcastJoinThreshold + 1L)
/**
* When set to true, we always treat byte arrays in Parquet files as strings.
*/
- private[spark] def isParquetBinaryAsString: Boolean =
- getConf(PARQUET_BINARY_AS_STRING, "false").toBoolean
+ private[spark] def isParquetBinaryAsString: Boolean = getConf(PARQUET_BINARY_AS_STRING)
/**
* When set to true, we always treat INT96Values in Parquet files as timestamp.
*/
- private[spark] def isParquetINT96AsTimestamp: Boolean =
- getConf(PARQUET_INT96_AS_TIMESTAMP, "true").toBoolean
+ private[spark] def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP)
/**
* When set to true, partition pruning for in-memory columnar tables is enabled.
*/
- private[spark] def inMemoryPartitionPruning: Boolean =
- getConf(IN_MEMORY_PARTITION_PRUNING, "false").toBoolean
+ private[spark] def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING)
- private[spark] def columnNameOfCorruptRecord: String =
- getConf(COLUMN_NAME_OF_CORRUPT_RECORD, "_corrupt_record")
+ private[spark] def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD)
/**
* Timeout in seconds for the broadcast wait time in hash join
*/
- private[spark] def broadcastTimeout: Int =
- getConf(BROADCAST_TIMEOUT, (5 * 60).toString).toInt
+ private[spark] def broadcastTimeout: Int = getConf(BROADCAST_TIMEOUT)
- private[spark] def defaultDataSourceName: String =
- getConf(DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.parquet")
+ private[spark] def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME)
- private[spark] def partitionDiscoveryEnabled() =
- getConf(SQLConf.PARTITION_DISCOVERY_ENABLED, "true").toBoolean
+ private[spark] def partitionDiscoveryEnabled(): Boolean =
+ getConf(SQLConf.PARTITION_DISCOVERY_ENABLED)
- private[spark] def partitionColumnTypeInferenceEnabled() =
- getConf(SQLConf.PARTITION_COLUMN_TYPE_INFERENCE, "true").toBoolean
+ private[spark] def partitionColumnTypeInferenceEnabled(): Boolean =
+ getConf(SQLConf.PARTITION_COLUMN_TYPE_INFERENCE)
// Do not use a value larger than 4000 as the default value of this property.
// See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information.
- private[spark] def schemaStringLengthThreshold: Int =
- getConf(SCHEMA_STRING_LENGTH_THRESHOLD, "4000").toInt
+ private[spark] def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD)
- private[spark] def dataFrameEagerAnalysis: Boolean =
- getConf(DATAFRAME_EAGER_ANALYSIS, "true").toBoolean
+ private[spark] def dataFrameEagerAnalysis: Boolean = getConf(DATAFRAME_EAGER_ANALYSIS)
private[spark] def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
- getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY, "true").toBoolean
+ getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)
- private[spark] def dataFrameRetainGroupColumns: Boolean =
- getConf(DATAFRAME_RETAIN_GROUP_COLUMNS, "true").toBoolean
+ private[spark] def dataFrameRetainGroupColumns: Boolean = getConf(DATAFRAME_RETAIN_GROUP_COLUMNS)
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
def setConf(props: Properties): Unit = settings.synchronized {
- props.foreach { case (k, v) => settings.put(k, v) }
+ props.foreach { case (k, v) => setConfString(k, v) }
}
- /** Set the given Spark SQL configuration property. */
- def setConf(key: String, value: String): Unit = {
+ /** Set the given Spark SQL configuration property using a `string` value. */
+ def setConfString(key: String, value: String): Unit = {
require(key != null, "key cannot be null")
require(value != null, s"value cannot be null for key: $key")
+ val entry = sqlConfEntries.get(key)
+ if (entry != null) {
+ // Only verify configs in the SQLConf object
+ entry.valueConverter(value)
+ }
settings.put(key, value)
}
+ /** Set the given Spark SQL configuration property. */
+ def setConf[T](entry: SQLConfEntry[T], value: T): Unit = {
+ require(entry != null, "entry cannot be null")
+ require(value != null, s"value cannot be null for key: ${entry.key}")
+ require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered")
+ settings.put(entry.key, entry.stringConverter(value))
+ }
+
/** Return the value of Spark SQL configuration property for the given key. */
- def getConf(key: String): String = {
- Option(settings.get(key)).getOrElse(throw new NoSuchElementException(key))
+ def getConfString(key: String): String = {
+ Option(settings.get(key)).
+ orElse {
+ // Try to use the default value
+ Option(sqlConfEntries.get(key)).map(_.defaultValueString)
+ }.
+ getOrElse(throw new NoSuchElementException(key))
+ }
+
+ /**
+ * Return the value of Spark SQL configuration property for the given key. If the key is not set
+ * yet, return `defaultValue`. This is useful when `defaultValue` in SQLConfEntry is not the
+ * desired one.
+ */
+ def getConf[T](entry: SQLConfEntry[T], defaultValue: T): T = {
+ require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered")
+ Option(settings.get(entry.key)).map(entry.valueConverter).getOrElse(defaultValue)
}
/**
* Return the value of Spark SQL configuration property for the given key. If the key is not set
- * yet, return `defaultValue`.
+ * yet, return `defaultValue` in [[SQLConfEntry]].
+ */
+ def getConf[T](entry: SQLConfEntry[T]): T = {
+ require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered")
+ Option(settings.get(entry.key)).map(entry.valueConverter).orElse(entry.defaultValue).
+ getOrElse(throw new NoSuchElementException(entry.key))
+ }
+
+ /**
+ * Return the `string` value of Spark SQL configuration property for the given key. If the key is
+ * not set yet, return `defaultValue`.
*/
- def getConf(key: String, defaultValue: String): String = {
+ def getConfString(key: String, defaultValue: String): String = {
+ val entry = sqlConfEntries.get(key)
+ if (entry != null && defaultValue != "") {
+ // Only verify configs in the SQLConf object
+ entry.valueConverter(defaultValue)
+ }
Option(settings.get(key)).getOrElse(defaultValue)
}
@@ -300,11 +589,25 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
*/
def getAllConfs: immutable.Map[String, String] = settings.synchronized { settings.toMap }
- private[spark] def unsetConf(key: String) {
+ /**
+ * Return all the configuration definitions that have been defined in [[SQLConf]]. Each
+ * definition contains key, defaultValue and doc.
+ */
+ def getAllDefinedConfs: Seq[(String, String, String)] = sqlConfEntries.synchronized {
+ sqlConfEntries.values.filter(_.isPublic).map { entry =>
+ (entry.key, entry.defaultValueString, entry.doc)
+ }.toSeq
+ }
+
+ private[spark] def unsetConf(key: String): Unit = {
settings -= key
}
- private[spark] def clear() {
+ private[spark] def unsetConf(entry: SQLConfEntry[_]): Unit = {
+ settings -= entry.key
+ }
+
+ private[spark] def clear(): Unit = {
settings.clear()
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 6b605f7130167..04fc798bf3738 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -31,6 +31,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLConf.SQLConfEntry
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.catalyst.expressions._
@@ -79,13 +80,16 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
def setConf(props: Properties): Unit = conf.setConf(props)
+ /** Set the given Spark SQL configuration property. */
+ private[sql] def setConf[T](entry: SQLConfEntry[T], value: T): Unit = conf.setConf(entry, value)
+
/**
* Set the given Spark SQL configuration property.
*
* @group config
* @since 1.0.0
*/
- def setConf(key: String, value: String): Unit = conf.setConf(key, value)
+ def setConf(key: String, value: String): Unit = conf.setConfString(key, value)
/**
* Return the value of Spark SQL configuration property for the given key.
@@ -93,7 +97,22 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @group config
* @since 1.0.0
*/
- def getConf(key: String): String = conf.getConf(key)
+ def getConf(key: String): String = conf.getConfString(key)
+
+ /**
+ * Return the value of Spark SQL configuration property for the given key. If the key is not set
+ * yet, return `defaultValue` in [[SQLConfEntry]].
+ */
+ private[sql] def getConf[T](entry: SQLConfEntry[T]): T = conf.getConf(entry)
+
+ /**
+ * Return the value of Spark SQL configuration property for the given key. If the key is not set
+ * yet, return `defaultValue`. This is useful when `defaultValue` in SQLConfEntry is not the
+ * desired one.
+ */
+ private[sql] def getConf[T](entry: SQLConfEntry[T], defaultValue: T): T = {
+ conf.getConf(entry, defaultValue)
+ }
/**
* Return the value of Spark SQL configuration property for the given key. If the key is not set
@@ -102,7 +121,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @group config
* @since 1.0.0
*/
- def getConf(key: String, defaultValue: String): String = conf.getConf(key, defaultValue)
+ def getConf(key: String, defaultValue: String): String = conf.getConfString(key, defaultValue)
/**
* Return all the configuration properties that have been set (i.e. not the default).
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala
index 305b306a79871..e59fa6e162900 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala
@@ -44,8 +44,8 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr
private val pair: Parser[LogicalPlan] =
(key ~ ("=".r ~> value).?).? ^^ {
- case None => SetCommand(None, output)
- case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim)), output)
+ case None => SetCommand(None)
+ case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim)))
}
def apply(input: String): LogicalPlan = parseAll(pair, input) match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index c9dfcea5d051e..5e9951f248ff2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution
+import java.util.NoSuchElementException
+
import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
@@ -75,48 +77,92 @@ private[sql] case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan
* :: DeveloperApi ::
*/
@DeveloperApi
-case class SetCommand(
- kv: Option[(String, Option[String])],
- override val output: Seq[Attribute])
- extends RunnableCommand with Logging {
+case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableCommand with Logging {
+
+ private def keyValueOutput: Seq[Attribute] = {
+ val schema = StructType(
+ StructField("key", StringType, false) ::
+ StructField("value", StringType, false) :: Nil)
+ schema.toAttributes
+ }
- override def run(sqlContext: SQLContext): Seq[Row] = kv match {
+ private val (_output, runFunc): (Seq[Attribute], SQLContext => Seq[Row]) = kv match {
// Configures the deprecated "mapred.reduce.tasks" property.
case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, Some(value))) =>
- logWarning(
- s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
- s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.")
- if (value.toInt < 1) {
- val msg = s"Setting negative ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} for automatically " +
- "determining the number of reducers is not supported."
- throw new IllegalArgumentException(msg)
- } else {
- sqlContext.setConf(SQLConf.SHUFFLE_PARTITIONS, value)
- Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$value"))
+ val runFunc = (sqlContext: SQLContext) => {
+ logWarning(
+ s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
+ s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS.key} instead.")
+ if (value.toInt < 1) {
+ val msg =
+ s"Setting negative ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} for automatically " +
+ "determining the number of reducers is not supported."
+ throw new IllegalArgumentException(msg)
+ } else {
+ sqlContext.setConf(SQLConf.SHUFFLE_PARTITIONS.key, value)
+ Seq(Row(SQLConf.SHUFFLE_PARTITIONS.key, value))
+ }
}
+ (keyValueOutput, runFunc)
// Configures a single property.
case Some((key, Some(value))) =>
- sqlContext.setConf(key, value)
- Seq(Row(s"$key=$value"))
+ val runFunc = (sqlContext: SQLContext) => {
+ sqlContext.setConf(key, value)
+ Seq(Row(key, value))
+ }
+ (keyValueOutput, runFunc)
- // Queries all key-value pairs that are set in the SQLConf of the sqlContext.
- // Notice that different from Hive, here "SET -v" is an alias of "SET".
// (In Hive, "SET" returns all changed properties while "SET -v" returns all properties.)
- case Some(("-v", None)) | None =>
- sqlContext.getAllConfs.map { case (k, v) => Row(s"$k=$v") }.toSeq
+ // Queries all key-value pairs that are set in the SQLConf of the sqlContext.
+ case None =>
+ val runFunc = (sqlContext: SQLContext) => {
+ sqlContext.getAllConfs.map { case (k, v) => Row(k, v) }.toSeq
+ }
+ (keyValueOutput, runFunc)
+
+ // Queries all properties along with their default values and docs that are defined in the
+ // SQLConf of the sqlContext.
+ case Some(("-v", None)) =>
+ val runFunc = (sqlContext: SQLContext) => {
+ sqlContext.conf.getAllDefinedConfs.map { case (key, defaultValue, doc) =>
+ Row(key, defaultValue, doc)
+ }
+ }
+ val schema = StructType(
+ StructField("key", StringType, false) ::
+ StructField("default", StringType, false) ::
+ StructField("meaning", StringType, false) :: Nil)
+ (schema.toAttributes, runFunc)
// Queries the deprecated "mapred.reduce.tasks" property.
case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, None)) =>
- logWarning(
- s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
- s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.")
- Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${sqlContext.conf.numShufflePartitions}"))
+ val runFunc = (sqlContext: SQLContext) => {
+ logWarning(
+ s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
+ s"showing ${SQLConf.SHUFFLE_PARTITIONS.key} instead.")
+ Seq(Row(SQLConf.SHUFFLE_PARTITIONS.key, sqlContext.conf.numShufflePartitions.toString))
+ }
+ (keyValueOutput, runFunc)
// Queries a single property.
case Some((key, None)) =>
- Seq(Row(s"$key=${sqlContext.getConf(key, "")}"))
+ val runFunc = (sqlContext: SQLContext) => {
+ val value =
+ try {
+ sqlContext.getConf(key)
+ } catch {
+ case _: NoSuchElementException => ""
+ }
+ Seq(Row(key, value))
+ }
+ (keyValueOutput, runFunc)
}
+
+ override val output: Seq[Attribute] = _output
+
+ override def run(sqlContext: SQLContext): Seq[Row] = runFunc(sqlContext)
+
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 3ee4033baee2e..2964edac1aba2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -48,7 +48,7 @@ package object debug {
*/
implicit class DebugSQLContext(sqlContext: SQLContext) {
def debug(): Unit = {
- sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false")
+ sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index c5b77724aae17..d8a91bead7c33 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -707,11 +707,19 @@ object functions {
/**
* Computes the square root of the specified float value.
*
- * @group normal_funcs
+ * @group math_funcs
* @since 1.3.0
*/
def sqrt(e: Column): Column = Sqrt(e.expr)
+  /**
+   * Computes the square root of the values in the specified column.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def sqrt(colName: String): Column = sqrt(Column(colName))
+
/**
* Creates a new struct column. The input column must be a column in a [[DataFrame]], or
* a derived column expression that is named (i.e. aliased).
@@ -1083,6 +1091,22 @@ object functions {
*/
def log(columnName: String): Column = log(Column(columnName))
+ /**
+ * Returns the first argument-base logarithm of the second argument.
+ *
+ * @group math_funcs
+ * @since 1.4.0
+ */
+ def log(base: Double, a: Column): Column = Logarithm(lit(base).expr, a.expr)
+
+ /**
+ * Returns the first argument-base logarithm of the second argument.
+ *
+ * @group math_funcs
+ * @since 1.4.0
+ */
+ def log(base: Double, columnName: String): Column = log(base, Column(columnName))
+
/**
* Computes the logarithm of the given value in base 10.
*
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index 39360e13313a3..b30fc171c0af1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -49,7 +49,8 @@ import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, InternalRow, _}
import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
import org.apache.spark.sql.types.StructType
-import org.apache.spark.{Logging, SerializableWritable, TaskContext}
+import org.apache.spark.{Logging, TaskContext}
+import org.apache.spark.util.SerializableConfiguration
/**
* :: DeveloperApi ::
@@ -113,12 +114,12 @@ private[sql] case class ParquetTableScan(
.foreach(ParquetInputFormat.setFilterPredicate(conf, _))
// Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata
- conf.set(
- SQLConf.PARQUET_CACHE_METADATA,
- sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true"))
+ conf.setBoolean(
+ SQLConf.PARQUET_CACHE_METADATA.key,
+ sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, true))
// Use task side metadata in parquet
- conf.setBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true);
+ conf.setBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true)
val baseRDD =
new org.apache.spark.rdd.NewHadoopRDD(
@@ -329,7 +330,7 @@ private[sql] case class InsertIntoParquetTable(
job.setOutputKeyClass(keyType)
job.setOutputValueClass(classOf[InternalRow])
NewFileOutputFormat.setOutputPath(job, new Path(path))
- val wrappedConf = new SerializableWritable(job.getConfiguration)
+ val wrappedConf = new SerializableConfiguration(job.getConfiguration)
val formatter = new SimpleDateFormat("yyyyMMddHHmm")
val jobtrackerID = formatter.format(new Date())
val stageId = sqlContext.sparkContext.newRddId()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index bba6f1ec96aa8..c9de45e0ddfbb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -24,7 +24,6 @@ import scala.collection.JavaConversions._
import scala.util.Try
import com.google.common.base.Objects
-import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
@@ -42,8 +41,8 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{DataType, StructType}
-import org.apache.spark.util.Utils
-import org.apache.spark.{Logging, SerializableWritable, SparkException, Partition => SparkPartition}
+import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.{Logging, SparkException, Partition => SparkPartition}
private[sql] class DefaultSource extends HadoopFsRelationProvider {
override def createRelation(
@@ -220,7 +219,7 @@ private[sql] class ParquetRelation2(
}
conf.setClass(
- SQLConf.OUTPUT_COMMITTER_CLASS,
+ SQLConf.OUTPUT_COMMITTER_CLASS.key,
committerClass,
classOf[ParquetOutputCommitter])
@@ -258,8 +257,8 @@ private[sql] class ParquetRelation2(
requiredColumns: Array[String],
filters: Array[Filter],
inputFiles: Array[FileStatus],
- broadcastedConf: Broadcast[SerializableWritable[Configuration]]): RDD[Row] = {
- val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true").toBoolean
+ broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = {
+ val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA)
val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown
// Create the function to set variable Parquet confs at both driver and executor side.
val initLocalJobFuncOpt =
@@ -498,7 +497,7 @@ private[sql] object ParquetRelation2 extends Logging {
ParquetTypesConverter.convertToString(dataSchema.toAttributes))
// Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata
- conf.set(SQLConf.PARQUET_CACHE_METADATA, useMetadataCache.toString)
+ conf.setBoolean(SQLConf.PARQUET_CACHE_METADATA.key, useMetadataCache)
}
/** This closure sets input paths at the driver side. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
index 4cf67439b9b8d..a8f56f4767407 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.sources
+import org.apache.spark.{Logging, TaskContext}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD}
import org.apache.spark.sql._
@@ -27,9 +28,8 @@ import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{SaveMode, Strategy, execution, sources}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{SerializableConfiguration, Utils}
import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.{Logging, SerializableWritable, TaskContext}
/**
* A Strategy for planning scans over data sources defined using the sources API.
@@ -91,7 +91,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
// broadcast HadoopConf.
val sharedHadoopConf = SparkHadoopUtil.get.conf
val confBroadcast =
- t.sqlContext.sparkContext.broadcast(new SerializableWritable(sharedHadoopConf))
+ t.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf))
pruneFilterProject(
l,
projects,
@@ -126,7 +126,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
// Otherwise, the cost of broadcasting HadoopConf in every RDD will be high.
val sharedHadoopConf = SparkHadoopUtil.get.conf
val confBroadcast =
- relation.sqlContext.sparkContext.broadcast(new SerializableWritable(sharedHadoopConf))
+ relation.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf))
// Builds RDD[Row]s for each selected partition.
val perPartitionRows = partitions.map { case Partition(partitionValues, dir) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala
index ebad0c1564ec0..2bdc341021256 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala
@@ -34,7 +34,7 @@ import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.rdd.{RDD, HadoopRDD}
import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{SerializableConfiguration, Utils}
import scala.reflect.ClassTag
@@ -65,7 +65,7 @@ private[spark] class SqlNewHadoopPartition(
*/
private[sql] class SqlNewHadoopRDD[K, V](
@transient sc : SparkContext,
- broadcastedConf: Broadcast[SerializableWritable[Configuration]],
+ broadcastedConf: Broadcast[SerializableConfiguration],
@transient initDriverSideJobFuncOpt: Option[Job => Unit],
initLocalJobFuncOpt: Option[Job => Unit],
inputFormatClass: Class[_ <: InputFormat[K, V]],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
index 3dbe6faabf453..c16bd9ae52c81 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext, SaveMode}
+import org.apache.spark.util.SerializableConfiguration
private[sql] case class InsertIntoDataSource(
logicalRelation: LogicalRelation,
@@ -260,7 +261,7 @@ private[sql] abstract class BaseWriterContainer(
with Logging
with Serializable {
- protected val serializableConf = new SerializableWritable(ContextUtil.getConfiguration(job))
+ protected val serializableConf = new SerializableConfiguration(ContextUtil.getConfiguration(job))
// This is only used on driver side.
@transient private val jobContext: JobContext = job
@@ -323,7 +324,7 @@ private[sql] abstract class BaseWriterContainer(
private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = {
val committerClass = context.getConfiguration.getClass(
- SQLConf.OUTPUT_COMMITTER_CLASS, null, classOf[OutputCommitter])
+ SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter])
Option(committerClass).map { clazz =>
logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 43d3507d7d2ba..7005c7079af91 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -27,12 +27,12 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
-import org.apache.spark.SerializableWritable
import org.apache.spark.sql.execution.RDDConversions
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.SerializableConfiguration
/**
* ::DeveloperApi::
@@ -518,7 +518,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio
requiredColumns: Array[String],
filters: Array[Filter],
inputPaths: Array[String],
- broadcastedConf: Broadcast[SerializableWritable[Configuration]]): RDD[Row] = {
+ broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = {
val inputStatuses = inputPaths.flatMap { input =>
val path = new Path(input)
@@ -648,7 +648,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio
requiredColumns: Array[String],
filters: Array[Filter],
inputFiles: Array[FileStatus],
- broadcastedConf: Broadcast[SerializableWritable[Configuration]]): RDD[Row] = {
+ broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = {
buildScan(requiredColumns, filters, inputFiles)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
index 356a6100d2cf5..9fa394525d65c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -38,7 +38,7 @@ class LocalSQLContext
protected[sql] class SQLSession extends super.SQLSession {
protected[sql] override lazy val conf: SQLConf = new SQLConf {
/** Fewer partitions to speed up testing. */
- override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt
+ override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 5a08578e7ba4b..88bb743ab0bc9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -296,6 +296,22 @@ class ColumnExpressionSuite extends QueryTest {
checkAnswer(testData.filter($"a".between($"b", $"c")), expectAnswer)
}
+ test("in") {
+ val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
+ checkAnswer(df.filter($"a".in(1, 2)),
+ df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
+ checkAnswer(df.filter($"a".in(3, 2)),
+ df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
+ checkAnswer(df.filter($"a".in(3, 1)),
+ df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))
+ checkAnswer(df.filter($"b".in("y", "x")),
+ df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x"))
+ checkAnswer(df.filter($"b".in("z", "x")),
+ df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x"))
+ checkAnswer(df.filter($"b".in("z", "y")),
+ df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y"))
+ }
+
val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize(
Row(false, false) ::
Row(false, true) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 790b405c72697..b26d3ab253a1d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -68,12 +68,12 @@ class DataFrameAggregateSuite extends QueryTest {
Seq(Row(1, 3), Row(2, 3), Row(3, 3))
)
- ctx.conf.setConf("spark.sql.retainGroupColumns", "false")
+ ctx.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false)
checkAnswer(
testData2.groupBy("a").agg(sum($"b")),
Seq(Row(3), Row(3), Row(3))
)
- ctx.conf.setConf("spark.sql.retainGroupColumns", "true")
+ ctx.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, true)
}
test("agg without groups") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index fa98e23e3d147..ba1d020f22f11 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -33,7 +33,7 @@ class DataFrameSuite extends QueryTest {
test("analysis error should be eagerly reported") {
val oldSetting = ctx.conf.dataFrameEagerAnalysis
// Eager analysis.
- ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true")
+ ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true)
intercept[Exception] { testData.select('nonExistentName) }
intercept[Exception] {
@@ -47,11 +47,11 @@ class DataFrameSuite extends QueryTest {
}
// No more eager analysis once the flag is turned off
- ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false")
+ ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false)
testData.select('nonExistentName)
// Set the flag back to original value before this test.
- ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString)
+ ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting)
}
test("dataframe toString") {
@@ -70,7 +70,7 @@ class DataFrameSuite extends QueryTest {
test("invalid plan toString, debug mode") {
val oldSetting = ctx.conf.dataFrameEagerAnalysis
- ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true")
+ ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true)
// Turn on debug mode so we can see invalid query plans.
import org.apache.spark.sql.execution.debug._
@@ -83,7 +83,7 @@ class DataFrameSuite extends QueryTest {
badPlan.toString)
// Set the flag back to original value before this test.
- ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString)
+ ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting)
}
test("access complex data") {
@@ -556,13 +556,13 @@ class DataFrameSuite extends QueryTest {
test("SPARK-6899") {
val originalValue = ctx.conf.codegenEnabled
- ctx.setConf(SQLConf.CODEGEN_ENABLED, "true")
+ ctx.setConf(SQLConf.CODEGEN_ENABLED, true)
try{
checkAnswer(
decimalData.agg(avg('a)),
Row(new java.math.BigDecimal(2.0)))
} finally {
- ctx.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
+ ctx.setConf(SQLConf.CODEGEN_ENABLED, originalValue)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index ffd26c4f5a7c2..20390a5544304 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -95,14 +95,14 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
classOf[BroadcastNestedLoopJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
try {
- ctx.conf.setConf("spark.sql.planner.sortMergeJoin", "true")
+ ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true)
Seq(
("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
} finally {
- ctx.conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString)
+ ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED)
}
}
@@ -118,7 +118,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
classOf[BroadcastHashJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
try {
- ctx.conf.setConf("spark.sql.planner.sortMergeJoin", "true")
+ ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true)
Seq(
("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]),
("SELECT * FROM testData join testData2 ON key = a and key = 2",
@@ -127,7 +127,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
classOf[BroadcastHashJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
} finally {
- ctx.conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString)
+ ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED)
}
ctx.sql("UNCACHE TABLE testData")
@@ -416,7 +416,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
ctx.sql("CACHE TABLE testData")
val tmp = ctx.conf.autoBroadcastJoinThreshold
- ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000")
+ ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=1000000000")
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
classOf[BroadcastLeftSemiJoinHash])
@@ -424,7 +424,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
case (query, joinClass) => assertJoin(query, joinClass)
}
- ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1")
+ ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1")
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash])
@@ -432,7 +432,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
case (query, joinClass) => assertJoin(query, joinClass)
}
- ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString)
+ ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp)
ctx.sql("UNCACHE TABLE testData")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index e2daaf6b730c5..2768d7dfc8030 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -236,6 +236,19 @@ class MathExpressionsSuite extends QueryTest {
testOneToOneNonNegativeMathFunction(log1p, math.log1p)
}
+ test("binary log") {
+ val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b")
+ checkAnswer(
+ df.select(org.apache.spark.sql.functions.log("a"),
+ org.apache.spark.sql.functions.log(2.0, "a"),
+ org.apache.spark.sql.functions.log("b")),
+ Row(math.log(123), math.log(123) / math.log(2), null))
+
+ checkAnswer(
+ df.selectExpr("log(a)", "log(2.0, a)", "log(b)"),
+ Row(math.log(123), math.log(123) / math.log(2), null))
+ }
+
test("abs") {
val input =
Seq[(java.lang.Double, java.lang.Double)]((null, null), (0.0, 0.0), (1.5, 1.5), (-2.5, 2.5))
@@ -257,6 +270,16 @@ class MathExpressionsSuite extends QueryTest {
checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null))
}
+ test("sqrt") {
+ val df = Seq((1, 4)).toDF("a", "b")
+ checkAnswer(
+ df.select(sqrt("a"), sqrt("b")),
+ Row(1.0, 2.0))
+
+ checkAnswer(ctx.sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null))
+ checkAnswer(df.selectExpr("sqrt(a)", "sqrt(b)", "sqrt(null)"), Row(1.0, 2.0, null))
+ }
+
test("negative") {
checkAnswer(
ctx.sql("SELECT negative(1), negative(0), negative(-1)"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala
new file mode 100644
index 0000000000000..2e33777f14adc
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.SQLConf._
+
+class SQLConfEntrySuite extends SparkFunSuite {
+
+ val conf = new SQLConf
+
+ test("intConf") {
+ val key = "spark.sql.SQLConfEntrySuite.int"
+ val confEntry = SQLConfEntry.intConf(key)
+ assert(conf.getConf(confEntry, 5) === 5)
+
+ conf.setConf(confEntry, 10)
+ assert(conf.getConf(confEntry, 5) === 10)
+
+ conf.setConfString(key, "20")
+ assert(conf.getConfString(key, "5") === "20")
+ assert(conf.getConfString(key) === "20")
+ assert(conf.getConf(confEntry, 5) === 20)
+
+ val e = intercept[IllegalArgumentException] {
+ conf.setConfString(key, "abc")
+ }
+ assert(e.getMessage === s"$key should be int, but was abc")
+ }
+
+ test("longConf") {
+ val key = "spark.sql.SQLConfEntrySuite.long"
+ val confEntry = SQLConfEntry.longConf(key)
+ assert(conf.getConf(confEntry, 5L) === 5L)
+
+ conf.setConf(confEntry, 10L)
+ assert(conf.getConf(confEntry, 5L) === 10L)
+
+ conf.setConfString(key, "20")
+ assert(conf.getConfString(key, "5") === "20")
+ assert(conf.getConfString(key) === "20")
+ assert(conf.getConf(confEntry, 5L) === 20L)
+
+ val e = intercept[IllegalArgumentException] {
+ conf.setConfString(key, "abc")
+ }
+ assert(e.getMessage === s"$key should be long, but was abc")
+ }
+
+ test("booleanConf") {
+ val key = "spark.sql.SQLConfEntrySuite.boolean"
+ val confEntry = SQLConfEntry.booleanConf(key)
+ assert(conf.getConf(confEntry, false) === false)
+
+ conf.setConf(confEntry, true)
+ assert(conf.getConf(confEntry, false) === true)
+
+ conf.setConfString(key, "true")
+ assert(conf.getConfString(key, "false") === "true")
+ assert(conf.getConfString(key) === "true")
+ assert(conf.getConf(confEntry, false) === true)
+
+ val e = intercept[IllegalArgumentException] {
+ conf.setConfString(key, "abc")
+ }
+ assert(e.getMessage === s"$key should be boolean, but was abc")
+ }
+
+ test("doubleConf") {
+ val key = "spark.sql.SQLConfEntrySuite.double"
+ val confEntry = SQLConfEntry.doubleConf(key)
+ assert(conf.getConf(confEntry, 5.0) === 5.0)
+
+ conf.setConf(confEntry, 10.0)
+ assert(conf.getConf(confEntry, 5.0) === 10.0)
+
+ conf.setConfString(key, "20.0")
+ assert(conf.getConfString(key, "5.0") === "20.0")
+ assert(conf.getConfString(key) === "20.0")
+ assert(conf.getConf(confEntry, 5.0) === 20.0)
+
+ val e = intercept[IllegalArgumentException] {
+ conf.setConfString(key, "abc")
+ }
+ assert(e.getMessage === s"$key should be double, but was abc")
+ }
+
+ test("stringConf") {
+ val key = "spark.sql.SQLConfEntrySuite.string"
+ val confEntry = SQLConfEntry.stringConf(key)
+ assert(conf.getConf(confEntry, "abc") === "abc")
+
+ conf.setConf(confEntry, "abcd")
+ assert(conf.getConf(confEntry, "abc") === "abcd")
+
+ conf.setConfString(key, "abcde")
+ assert(conf.getConfString(key, "abc") === "abcde")
+ assert(conf.getConfString(key) === "abcde")
+ assert(conf.getConf(confEntry, "abc") === "abcde")
+ }
+
+ test("enumConf") {
+ val key = "spark.sql.SQLConfEntrySuite.enum"
+ val confEntry = SQLConfEntry.enumConf(key, v => v, Set("a", "b", "c"), defaultValue = Some("a"))
+ assert(conf.getConf(confEntry) === "a")
+
+ conf.setConf(confEntry, "b")
+ assert(conf.getConf(confEntry) === "b")
+
+ conf.setConfString(key, "c")
+ assert(conf.getConfString(key, "a") === "c")
+ assert(conf.getConfString(key) === "c")
+ assert(conf.getConf(confEntry) === "c")
+
+ val e = intercept[IllegalArgumentException] {
+ conf.setConfString(key, "d")
+ }
+ assert(e.getMessage === s"The value of $key should be one of a, b, c, but was d")
+ }
+
+ test("stringSeqConf") {
+ val key = "spark.sql.SQLConfEntrySuite.stringSeq"
+ val confEntry = SQLConfEntry.stringSeqConf("spark.sql.SQLConfEntrySuite.stringSeq",
+ defaultValue = Some(Nil))
+ assert(conf.getConf(confEntry, Seq("a", "b", "c")) === Seq("a", "b", "c"))
+
+ conf.setConf(confEntry, Seq("a", "b", "c", "d"))
+ assert(conf.getConf(confEntry, Seq("a", "b", "c")) === Seq("a", "b", "c", "d"))
+
+ conf.setConfString(key, "a,b,c,d,e")
+ assert(conf.getConfString(key, "a,b,c") === "a,b,c,d,e")
+ assert(conf.getConfString(key) === "a,b,c,d,e")
+ assert(conf.getConf(confEntry, Seq("a", "b", "c")) === Seq("a", "b", "c", "d", "e"))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
index 76d0dd1744a41..75791e9d53c20 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
@@ -75,6 +75,14 @@ class SQLConfSuite extends QueryTest {
test("deprecated property") {
ctx.conf.clear()
ctx.sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
- assert(ctx.getConf(SQLConf.SHUFFLE_PARTITIONS) === "10")
+ assert(ctx.conf.numShufflePartitions === 10)
+ }
+
+ test("invalid conf value") {
+ ctx.conf.clear()
+ val e = intercept[IllegalArgumentException] {
+ ctx.sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10")
+ }
+ assert(e.getMessage === s"${SQLConf.CASE_SENSITIVE.key} should be boolean, but was 10")
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 30db840166ca6..82f3fdb48b557 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -190,7 +190,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
test("aggregation with codegen") {
val originalValue = sqlContext.conf.codegenEnabled
- sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true")
+ sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
// Prepare a table that we can group some rows.
sqlContext.table("testData")
.unionAll(sqlContext.table("testData"))
@@ -287,7 +287,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
Row(0, null, 0) :: Nil)
} finally {
sqlContext.dropTempTable("testData3x")
- sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
+ sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue)
}
}
@@ -480,41 +480,41 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
test("sorting") {
val before = sqlContext.conf.externalSortEnabled
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, "false")
+ sqlContext.setConf(SQLConf.EXTERNAL_SORT, false)
sortTest()
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, before.toString)
+ sqlContext.setConf(SQLConf.EXTERNAL_SORT, before)
}
test("external sorting") {
val before = sqlContext.conf.externalSortEnabled
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, "true")
+ sqlContext.setConf(SQLConf.EXTERNAL_SORT, true)
sortTest()
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, before.toString)
+ sqlContext.setConf(SQLConf.EXTERNAL_SORT, before)
}
test("SPARK-6927 sorting with codegen on") {
val externalbefore = sqlContext.conf.externalSortEnabled
val codegenbefore = sqlContext.conf.codegenEnabled
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, "false")
- sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true")
+ sqlContext.setConf(SQLConf.EXTERNAL_SORT, false)
+ sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
try{
sortTest()
} finally {
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString)
- sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString)
+ sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore)
+ sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore)
}
}
test("SPARK-6927 external sorting with codegen on") {
val externalbefore = sqlContext.conf.externalSortEnabled
val codegenbefore = sqlContext.conf.codegenEnabled
- sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true")
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, "true")
+ sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
+ sqlContext.setConf(SQLConf.EXTERNAL_SORT, true)
try {
sortTest()
} finally {
- sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString)
- sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString)
+ sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore)
+ sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore)
}
}
@@ -908,25 +908,25 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
sql(s"SET $testKey=$testVal")
checkAnswer(
sql("SET"),
- Row(s"$testKey=$testVal")
+ Row(testKey, testVal)
)
sql(s"SET ${testKey + testKey}=${testVal + testVal}")
checkAnswer(
sql("set"),
Seq(
- Row(s"$testKey=$testVal"),
- Row(s"${testKey + testKey}=${testVal + testVal}"))
+ Row(testKey, testVal),
+ Row(testKey + testKey, testVal + testVal))
)
// "set key"
checkAnswer(
sql(s"SET $testKey"),
- Row(s"$testKey=$testVal")
+ Row(testKey, testVal)
)
checkAnswer(
sql(s"SET $nonexistentKey"),
- Row(s"$nonexistentKey=")
+ Row(nonexistentKey, "")
)
sqlContext.conf.clear()
}
@@ -1340,12 +1340,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
}
test("SPARK-4699 case sensitivity SQL query") {
- sqlContext.setConf(SQLConf.CASE_SENSITIVE, "false")
+ sqlContext.setConf(SQLConf.CASE_SENSITIVE, false)
val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil
val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i)))
rdd.toDF().registerTempTable("testTable1")
checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1"))
- sqlContext.setConf(SQLConf.CASE_SENSITIVE, "true")
+ sqlContext.setConf(SQLConf.CASE_SENSITIVE, true)
}
test("SPARK-6145: ORDER BY test for nested fields") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
index 6545c6b314a4c..2c0879927a129 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -32,7 +32,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi
override protected def beforeAll(): Unit = {
// Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
- ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
+ ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 10)
val pruningData = ctx.sparkContext.makeRDD((1 to 100).map { key =>
val string = if (((key - 1) / 10) % 2 == 0) null else key.toString
@@ -41,14 +41,14 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi
pruningData.registerTempTable("pruningData")
// Enable in-memory partition pruning
- ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
+ ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true)
// Enable in-memory table scan accumulators
ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
}
override protected def afterAll(): Unit = {
- ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString)
- ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString)
+ ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
+ ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
}
before {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 3e27f58a92d01..5854ab48db552 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -63,7 +63,7 @@ class PlannerSuite extends SparkFunSuite {
test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = {
- setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold.toString)
+ setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold)
val fields = fieldTypes.zipWithIndex.map {
case (dataType, index) => StructField(s"c${index}", dataType, true)
} :+ StructField("key", IntegerType, true)
@@ -119,12 +119,12 @@ class PlannerSuite extends SparkFunSuite {
checkPlan(complexTypes, newThreshold = 901617)
- setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString)
+ setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
}
test("InMemoryRelation statistics propagation") {
val origThreshold = conf.autoBroadcastJoinThreshold
- setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920.toString)
+ setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920)
testData.limit(3).registerTempTable("tiny")
sql("CACHE TABLE tiny")
@@ -139,6 +139,6 @@ class PlannerSuite extends SparkFunSuite {
assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
- setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString)
+ setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
new file mode 100644
index 0000000000000..a1e3ca11b1ad9
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+
+class SortSuite extends SparkPlanTest {
+
+ // This test was originally added as an example of how to use [[SparkPlanTest]];
+ // it's not designed to be a comprehensive test of ExternalSort.
+ test("basic sorting using ExternalSort") {
+
+ val input = Seq(
+ ("Hello", 4, 2.0),
+ ("Hello", 1, 1.0),
+ ("World", 8, 3.0)
+ )
+
+ checkAnswer(
+ input.toDF("a", "b", "c"),
+ ExternalSort('a.asc :: 'b.asc :: Nil, global = false, _: SparkPlan),
+ input.sorted)
+
+ checkAnswer(
+ input.toDF("a", "b", "c"),
+ ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan),
+ input.sortBy(t => (t._2, t._1)))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
new file mode 100644
index 0000000000000..13f3be8ca28d6
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.language.implicitConversions
+import scala.reflect.runtime.universe.TypeTag
+import scala.util.control.NonFatal
+
+import org.apache.spark.SparkFunSuite
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.BoundReference
+import org.apache.spark.sql.catalyst.util._
+
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.{DataFrameHolder, Row, DataFrame}
+
+/**
+ * Base class for writing tests for individual physical operators. For an example of how this
+ * class's test helper methods can be used, see [[SortSuite]].
+ */
+class SparkPlanTest extends SparkFunSuite {
+
+ /**
+ * Creates a DataFrame from a local Seq of Product.
+ */
+ implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = {
+ TestSQLContext.implicits.localSeqToDataFrameHolder(data)
+ }
+
+ /**
+ * Runs the plan and makes sure the answer matches the expected result.
+ * @param input the input data to be used.
+ * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
+ * the physical operator that's being tested.
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ */
+ protected def checkAnswer(
+ input: DataFrame,
+ planFunction: SparkPlan => SparkPlan,
+ expectedAnswer: Seq[Row]): Unit = {
+ SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer) match {
+ case Some(errorMessage) => fail(errorMessage)
+ case None =>
+ }
+ }
+
+ /**
+ * Runs the plan and makes sure the answer matches the expected result.
+ * @param input the input data to be used.
+ * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
+ * the physical operator that's being tested.
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
+ */
+ protected def checkAnswer[A <: Product : TypeTag](
+ input: DataFrame,
+ planFunction: SparkPlan => SparkPlan,
+ expectedAnswer: Seq[A]): Unit = {
+ val expectedRows = expectedAnswer.map(Row.fromTuple)
+ SparkPlanTest.checkAnswer(input, planFunction, expectedRows) match {
+ case Some(errorMessage) => fail(errorMessage)
+ case None =>
+ }
+ }
+}
+
+/**
+ * Helper methods for writing tests of individual physical operators.
+ */
+object SparkPlanTest {
+
+ /**
+ * Runs the plan and makes sure the answer matches the expected result.
+ * @param input the input data to be used.
+ * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
+ * the physical operator that's being tested.
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ */
+ def checkAnswer(
+ input: DataFrame,
+ planFunction: SparkPlan => SparkPlan,
+ expectedAnswer: Seq[Row]): Option[String] = {
+
+ val outputPlan = planFunction(input.queryExecution.sparkPlan)
+
+ // A very simple resolver to make writing tests easier. In contrast to the real resolver
+ // this is always case sensitive and does not try to handle scoping or complex type resolution.
+ val resolvedPlan = outputPlan transform {
+ case plan: SparkPlan =>
+ val inputMap = plan.children.flatMap(_.output).zipWithIndex.map {
+ case (a, i) =>
+ (a.name, BoundReference(i, a.dataType, a.nullable))
+ }.toMap
+
+ plan.transformExpressions {
+ case UnresolvedAttribute(Seq(u)) =>
+ inputMap.getOrElse(u,
+ sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
+ }
+ }
+
+ def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
+ // Converts data to types that we can do equality comparison using Scala collections.
+ // For BigDecimal type, the Scala type has a better definition of equality test (similar to
+ // Java's java.math.BigDecimal.compareTo).
+ // For binary arrays, we convert it to Seq to avoid calling java.util.Arrays.equals for
+ // equality test.
+ // This function is copied from Catalyst's QueryTest
+ val converted: Seq[Row] = answer.map { s =>
+ Row.fromSeq(s.toSeq.map {
+ case d: java.math.BigDecimal => BigDecimal(d)
+ case b: Array[Byte] => b.toSeq
+ case o => o
+ })
+ }
+ converted.sortBy(_.toString())
+ }
+
+ val sparkAnswer: Seq[Row] = try {
+ resolvedPlan.executeCollect().toSeq
+ } catch {
+ case NonFatal(e) =>
+ val errorMessage =
+ s"""
+ | Exception thrown while executing Spark plan:
+ | $outputPlan
+ | == Exception ==
+ | $e
+ | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
+ """.stripMargin
+ return Some(errorMessage)
+ }
+
+ if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
+ val errorMessage =
+ s"""
+ | Results do not match for Spark plan:
+ | $outputPlan
+ | == Results ==
+ | ${sideBySide(
+ s"== Correct Answer - ${expectedAnswer.size} ==" +:
+ prepareAnswer(expectedAnswer).map(_.toString()),
+ s"== Spark Answer - ${sparkAnswer.size} ==" +:
+ prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")}
+ """.stripMargin
+ return Some(errorMessage)
+ }
+
+ None
+ }
+}
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index fca24364fe6ec..945d4375035fd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -1077,14 +1077,14 @@ class JsonSuite extends QueryTest with TestJsonData {
}
test("SPARK-7565 MapType in JsonRDD") {
- val useStreaming = ctx.getConf(SQLConf.USE_JACKSON_STREAMING_API, "true")
+ val useStreaming = ctx.conf.useJacksonStreamingAPI
val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord
ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed")
val schemaWithSimpleMap = StructType(
StructField("map", MapType(StringType, IntegerType, true), false) :: Nil)
try{
- for (useStreaming <- List("true", "false")) {
+ for (useStreaming <- List(true, false)) {
ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming)
val temp = Utils.createTempDir().getPath
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
index fa5d4eca05d9f..a2763c78b6450 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
@@ -51,7 +51,7 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest {
expected: Seq[Row]): Unit = {
val output = predicate.collect { case a: Attribute => a }.distinct
- withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") {
+ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
val query = df
.select(output.map(e => Column(e)): _*)
.where(Column(predicate))
@@ -314,17 +314,17 @@ class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeA
lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
- sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
+ sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true)
}
override protected def afterAll(): Unit = {
- sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString)
+ sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf)
}
test("SPARK-6554: don't push down predicates which reference partition columns") {
import sqlContext.implicits._
- withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") {
+ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
withTempPath { dir =>
val path = s"${dir.getCanonicalPath}/part=1"
(1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path)
@@ -343,17 +343,17 @@ class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with Before
lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
- sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
+ sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false)
}
override protected def afterAll(): Unit = {
- sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString)
+ sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf)
}
test("SPARK-6742: don't push down predicates which reference partition columns") {
import sqlContext.implicits._
- withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") {
+ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
withTempPath { dir =>
val path = s"${dir.getCanonicalPath}/part=1"
(1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
index fc827bc4ca11b..284d99d4938d1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
@@ -94,8 +94,8 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
val data = (1 to 4).map(i => Tuple1(i.toString))
// Property spark.sql.parquet.binaryAsString shouldn't affect Parquet files written by Spark SQL
// as we store Spark SQL schema in the extra metadata.
- withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "false")(checkParquetFile(data))
- withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "true")(checkParquetFile(data))
+ withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "false")(checkParquetFile(data))
+ withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true")(checkParquetFile(data))
}
test("fixed-length decimals") {
@@ -231,7 +231,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
val data = (0 until 10).map(i => (i, i.toString))
def checkCompressionCodec(codec: CompressionCodecName): Unit = {
- withSQLConf(SQLConf.PARQUET_COMPRESSION -> codec.name()) {
+ withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> codec.name()) {
withParquetFile(data) { path =>
assertResult(sqlContext.conf.parquetCompressionCodec.toUpperCase) {
compressionCodecFor(path)
@@ -408,7 +408,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
val clonedConf = new Configuration(configuration)
configuration.set(
- SQLConf.OUTPUT_COMMITTER_CLASS, classOf[ParquetOutputCommitter].getCanonicalName)
+ SQLConf.OUTPUT_COMMITTER_CLASS.key, classOf[ParquetOutputCommitter].getCanonicalName)
configuration.set(
"spark.sql.parquet.output.committer.class",
@@ -440,11 +440,11 @@ class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterA
private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
- sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
+ sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true)
}
override protected def afterAll(): Unit = {
- sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString)
+ sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API.key, originalConf.toString)
}
test("SPARK-6330 regression test") {
@@ -464,10 +464,10 @@ class ParquetDataSourceOffIOSuite extends ParquetIOSuiteBase with BeforeAndAfter
private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
- sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
+ sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false)
}
override protected def afterAll(): Unit = {
- sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString)
+ sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index be3b34d5b9b70..fafad67dde3a7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -128,11 +128,11 @@ class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAnd
private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
- sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
+ sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true)
}
override protected def afterAll(): Unit = {
- sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString)
+ sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf)
}
}
@@ -140,10 +140,10 @@ class ParquetDataSourceOffQuerySuite extends ParquetQuerySuiteBase with BeforeAn
private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
- sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
+ sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false)
}
override protected def afterAll(): Unit = {
- sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString)
+ sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
index 3f77960d09246..00cc7d5ea580f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
@@ -27,7 +27,7 @@ abstract class DataSourceTest extends QueryTest with BeforeAndAfter {
// We want to test some edge cases.
protected implicit lazy val caseInsensitiveContext = {
val ctx = new SQLContext(TestSQLContext.sparkContext)
- ctx.setConf(SQLConf.CASE_SENSITIVE, "false")
+ ctx.setConf(SQLConf.CASE_SENSITIVE, false)
ctx
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index ac4a00a6f3dac..fa01823e9417c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -37,11 +37,11 @@ trait SQLTestUtils {
*/
protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
val (keys, values) = pairs.unzip
- val currentValues = keys.map(key => Try(sqlContext.conf.getConf(key)).toOption)
- (keys, values).zipped.foreach(sqlContext.conf.setConf)
+ val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption)
+ (keys, values).zipped.foreach(sqlContext.conf.setConfString)
try f finally {
keys.zip(currentValues).foreach {
- case (key, Some(value)) => sqlContext.conf.setConf(key, value)
+ case (key, Some(value)) => sqlContext.conf.setConfString(key, value)
case (key, None) => sqlContext.conf.unsetConf(key)
}
}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
index c9da25253e13f..700d994bb6a83 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
@@ -153,9 +153,9 @@ object HiveThriftServer2 extends Logging {
val sessionList = new mutable.LinkedHashMap[String, SessionInfo]
val executionList = new mutable.LinkedHashMap[String, ExecutionInfo]
val retainedStatements =
- conf.getConf(SQLConf.THRIFTSERVER_UI_STATEMENT_LIMIT, "200").toInt
+ conf.getConf(SQLConf.THRIFTSERVER_UI_STATEMENT_LIMIT)
val retainedSessions =
- conf.getConf(SQLConf.THRIFTSERVER_UI_SESSION_LIMIT, "200").toInt
+ conf.getConf(SQLConf.THRIFTSERVER_UI_SESSION_LIMIT)
var totalRunning = 0
override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
index e071103df925c..e8758887ff3a2 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
@@ -219,7 +219,7 @@ private[hive] class SparkExecuteStatementOperation(
result = hiveContext.sql(statement)
logDebug(result.queryExecution.toString())
result.queryExecution.logical match {
- case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value))), _) =>
+ case SetCommand(Some((SQLConf.THRIFTSERVER_POOL.key, Some(value)))) =>
sessionToActivePool(parentSession.getSessionHandle) = value
logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.")
case _ =>
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
index 178bd1f5cb164..301aa5a6411e2 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
@@ -113,8 +113,8 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
withJdbcStatement { statement =>
val resultSet = statement.executeQuery("SET spark.sql.hive.version")
resultSet.next()
- assert(resultSet.getString(1) ===
- s"spark.sql.hive.version=${HiveContext.hiveExecutionVersion}")
+ assert(resultSet.getString(1) === "spark.sql.hive.version")
+ assert(resultSet.getString(2) === HiveContext.hiveExecutionVersion)
}
}
@@ -238,7 +238,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
// first session, we get the default value of the session status
{ statement =>
- val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS}")
+ val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}")
rs1.next()
defaultV1 = rs1.getString(1)
assert(defaultV1 != "200")
@@ -256,19 +256,21 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
{ statement =>
val queries = Seq(
- s"SET ${SQLConf.SHUFFLE_PARTITIONS}=291",
+ s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}=291",
"SET hive.cli.print.header=true"
)
queries.map(statement.execute)
- val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS}")
+ val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}")
rs1.next()
- assert("spark.sql.shuffle.partitions=291" === rs1.getString(1))
+ assert("spark.sql.shuffle.partitions" === rs1.getString(1))
+ assert("291" === rs1.getString(2))
rs1.close()
val rs2 = statement.executeQuery("SET hive.cli.print.header")
rs2.next()
- assert("hive.cli.print.header=true" === rs2.getString(1))
+ assert("hive.cli.print.header" === rs2.getString(1))
+ assert("true" === rs2.getString(2))
rs2.close()
},
@@ -276,7 +278,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
// default value
{ statement =>
- val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS}")
+ val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}")
rs1.next()
assert(defaultV1 === rs1.getString(1))
rs1.close()
@@ -404,8 +406,8 @@ class HiveThriftHttpServerSuite extends HiveThriftJdbcTest {
withJdbcStatement { statement =>
val resultSet = statement.executeQuery("SET spark.sql.hive.version")
resultSet.next()
- assert(resultSet.getString(1) ===
- s"spark.sql.hive.version=${HiveContext.hiveExecutionVersion}")
+ assert(resultSet.getString(1) === "spark.sql.hive.version")
+ assert(resultSet.getString(2) === HiveContext.hiveExecutionVersion)
}
}
}
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 82c0b494598a8..f88e62763ca70 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -47,17 +47,17 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
// Add Locale setting
Locale.setDefault(Locale.US)
// Set a relatively small column batch size for testing purposes
- TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, "5")
+ TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, 5)
// Enable in-memory partition pruning for testing purposes
- TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
+ TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true)
}
override def afterAll() {
TestHive.cacheTables = false
TimeZone.setDefault(originalTimeZone)
Locale.setDefault(originalLocale)
- TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString)
- TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString)
+ TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
+ TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
}
/** A list of tests deemed out of scope currently and thus completely disregarded. */
@@ -933,7 +933,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_stddev_pop",
"udf_stddev_samp",
"udf_string",
- // "udf_struct", TODO: FIX THIS and enable it.
+ "udf_struct",
"udf_substring",
"udf_subtract",
"udf_sum",
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
index 65d070bd3cbde..f458567e5d7ea 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
@@ -26,11 +26,11 @@ import org.apache.spark.sql.hive.test.TestHive
class SortMergeCompatibilitySuite extends HiveCompatibilitySuite {
override def beforeAll() {
super.beforeAll()
- TestHive.setConf(SQLConf.SORTMERGE_JOIN, "true")
+ TestHive.setConf(SQLConf.SORTMERGE_JOIN, true)
}
override def afterAll() {
- TestHive.setConf(SQLConf.SORTMERGE_JOIN, "false")
+ TestHive.setConf(SQLConf.SORTMERGE_JOIN, false)
super.afterAll()
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index c50835dd8f11d..4a66d6508ae0a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -21,15 +21,13 @@ import java.io.File
import java.net.{URL, URLClassLoader}
import java.sql.Timestamp
-import org.apache.hadoop.hive.common.StatsSetupConst
-import org.apache.hadoop.hive.common.`type`.HiveDecimal
-import org.apache.spark.sql.catalyst.ParserDialect
-
import scala.collection.JavaConversions._
import scala.collection.mutable.HashMap
import scala.language.implicitConversions
import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.hive.common.StatsSetupConst
+import org.apache.hadoop.hive.common.`type`.HiveDecimal
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.metadata.Table
import org.apache.hadoop.hive.ql.parse.VariableSubstitution
@@ -39,6 +37,9 @@ import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable}
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql._
+import org.apache.spark.sql.SQLConf.SQLConfEntry
+import org.apache.spark.sql.SQLConf.SQLConfEntry._
+import org.apache.spark.sql.catalyst.ParserDialect
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand}
@@ -69,13 +70,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
import HiveContext._
+ println("create HiveContext")
+
/**
* When true, enables an experimental feature where metastore tables that use the parquet SerDe
* are automatically converted to use the Spark SQL parquet table scan, instead of the Hive
* SerDe.
*/
- protected[sql] def convertMetastoreParquet: Boolean =
- getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true"
+ protected[sql] def convertMetastoreParquet: Boolean = getConf(CONVERT_METASTORE_PARQUET)
/**
* When true, also tries to merge possibly different but compatible Parquet schemas in different
@@ -84,7 +86,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
* This configuration is only effective when "spark.sql.hive.convertMetastoreParquet" is true.
*/
protected[sql] def convertMetastoreParquetWithSchemaMerging: Boolean =
- getConf("spark.sql.hive.convertMetastoreParquet.mergeSchema", "false") == "true"
+ getConf(CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING)
/**
* When true, a table created by a Hive CTAS statement (no USING clause) will be
@@ -98,8 +100,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
* - The CTAS statement specifies SequenceFile (STORED AS SEQUENCEFILE) as the file format
* and no SerDe is specified (no ROW FORMAT SERDE clause).
*/
- protected[sql] def convertCTAS: Boolean =
- getConf("spark.sql.hive.convertCTAS", "false").toBoolean
+ protected[sql] def convertCTAS: Boolean = getConf(CONVERT_CTAS)
/**
* The version of the hive client that will be used to communicate with the metastore. Note that
@@ -117,8 +118,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
* option is only valid when using the execution version of Hive.
* - maven - download the correct version of hive on demand from maven.
*/
- protected[hive] def hiveMetastoreJars: String =
- getConf(HIVE_METASTORE_JARS, "builtin")
+ protected[hive] def hiveMetastoreJars: String = getConf(HIVE_METASTORE_JARS)
/**
* A comma separated list of class prefixes that should be loaded using the classloader that
@@ -128,11 +128,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
* custom appenders that are used by log4j.
*/
protected[hive] def hiveMetastoreSharedPrefixes: Seq[String] =
- getConf("spark.sql.hive.metastore.sharedPrefixes", jdbcPrefixes)
- .split(",").filterNot(_ == "")
-
- private def jdbcPrefixes = Seq(
- "com.mysql.jdbc", "org.postgresql", "com.microsoft.sqlserver", "oracle.jdbc").mkString(",")
+ getConf(HIVE_METASTORE_SHARED_PREFIXES).filterNot(_ == "")
/**
* A comma separated list of class prefixes that should explicitly be reloaded for each version
@@ -140,14 +136,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
* prefix that typically would be shared (i.e. org.apache.spark.*)
*/
protected[hive] def hiveMetastoreBarrierPrefixes: Seq[String] =
- getConf("spark.sql.hive.metastore.barrierPrefixes", "")
- .split(",").filterNot(_ == "")
+ getConf(HIVE_METASTORE_BARRIER_PREFIXES).filterNot(_ == "")
/*
* hive thrift server use background spark sql thread pool to execute sql queries
*/
- protected[hive] def hiveThriftServerAsync: Boolean =
- getConf("spark.sql.hive.thriftServer.async", "true").toBoolean
+ protected[hive] def hiveThriftServerAsync: Boolean = getConf(HIVE_THRIFT_SERVER_ASYNC)
@transient
protected[sql] lazy val substitutor = new VariableSubstitution()
@@ -364,7 +358,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
hiveconf.set(key, value)
}
- /* A catalyst metadata catalog that points to the Hive Metastore. */
+ private[sql] override def setConf[T](entry: SQLConfEntry[T], value: T): Unit = {
+ setConf(entry.key, entry.stringConverter(value))
+ }
+
+ /* A catalyst metadata catalog that points to the Hive Metastore. */
@transient
override protected[sql] lazy val catalog =
new HiveMetastoreCatalog(metadataHive, this) with OverrideCatalog
@@ -402,8 +400,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
protected[hive] class SQLSession extends super.SQLSession {
protected[sql] override lazy val conf: SQLConf = new SQLConf {
override def dialect: String = getConf(SQLConf.DIALECT, "hiveql")
- override def caseSensitiveAnalysis: Boolean =
- getConf(SQLConf.CASE_SENSITIVE, "false").toBoolean
+ override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false)
}
/**
@@ -519,7 +516,50 @@ private[hive] object HiveContext {
val hiveExecutionVersion: String = "0.13.1"
val HIVE_METASTORE_VERSION: String = "spark.sql.hive.metastore.version"
- val HIVE_METASTORE_JARS: String = "spark.sql.hive.metastore.jars"
+ val HIVE_METASTORE_JARS = stringConf("spark.sql.hive.metastore.jars",
+ defaultValue = Some("builtin"),
+ doc = "Location of the jars that should be used to instantiate the HiveMetastoreClient. This" +
+ " property can be one of three options: " +
+ "1. \"builtin\" Use Hive 0.13.1, which is bundled with the Spark assembly jar when " +
+ "-Phive is enabled. When this option is chosen, " +
+ "spark.sql.hive.metastore.version must be either 0.13.1 or not defined. " +
+ "2. \"maven\" Use Hive jars of specified version downloaded from Maven repositories." +
+ "3. A classpath in the standard format for both Hive and Hadoop.")
+
+ val CONVERT_METASTORE_PARQUET = booleanConf("spark.sql.hive.convertMetastoreParquet",
+ defaultValue = Some(true),
+ doc = "When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of " +
+ "the built in support.")
+
+ val CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING = booleanConf(
+ "spark.sql.hive.convertMetastoreParquet.mergeSchema",
+ defaultValue = Some(false),
+ doc = "TODO")
+
+ val CONVERT_CTAS = booleanConf("spark.sql.hive.convertCTAS",
+ defaultValue = Some(false),
+ doc = "TODO")
+
+ val HIVE_METASTORE_SHARED_PREFIXES = stringSeqConf("spark.sql.hive.metastore.sharedPrefixes",
+ defaultValue = Some(jdbcPrefixes),
+ doc = "A comma separated list of class prefixes that should be loaded using the classloader " +
+ "that is shared between Spark SQL and a specific version of Hive. An example of classes " +
+ "that should be shared is JDBC drivers that are needed to talk to the metastore. Other " +
+ "classes that need to be shared are those that interact with classes that are already " +
+ "shared. For example, custom appenders that are used by log4j.")
+
+ private def jdbcPrefixes = Seq(
+ "com.mysql.jdbc", "org.postgresql", "com.microsoft.sqlserver", "oracle.jdbc")
+
+ val HIVE_METASTORE_BARRIER_PREFIXES = stringSeqConf("spark.sql.hive.metastore.barrierPrefixes",
+ defaultValue = Some(Seq()),
+ doc = "A comma separated list of class prefixes that should explicitly be reloaded for each " +
+ "version of Hive that Spark SQL is communicating with. For example, Hive UDFs that are " +
+ "declared in a prefix that typically would be shared (i.e. org.apache.spark.* ).")
+
+ val HIVE_THRIFT_SERVER_ASYNC = booleanConf("spark.sql.hive.thriftServer.async",
+ defaultValue = Some(true),
+ doc = "TODO")
/** Constructs a configuration for hive, where the metastore is located in a temp directory. */
def newTemporaryConfiguration(): Map[String, String] = {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index 485810320f3c1..439f39bafc926 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.hive
-import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, PathFilter}
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants._
@@ -30,12 +29,12 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters,
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf}
-import org.apache.spark.{Logging, SerializableWritable}
+import org.apache.spark.{Logging}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DateUtils
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{SerializableConfiguration, Utils}
/**
* A trait for subclasses that handle table scans.
@@ -72,7 +71,7 @@ class HadoopTableReader(
// TODO: set aws s3 credentials.
private val _broadcastedHiveConf =
- sc.sparkContext.broadcast(new SerializableWritable(hiveExtraConf))
+ sc.sparkContext.broadcast(new SerializableConfiguration(hiveExtraConf))
override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] =
makeRDDForTable(
@@ -276,7 +275,7 @@ class HadoopTableReader(
val rdd = new HadoopRDD(
sc.sparkContext,
- _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableWritable[Configuration]]],
+ _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableConfiguration]],
Some(initializeJobConfFunc),
inputFormatClass,
classOf[Writable],
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 1d306c5d10af8..404bb937aaf87 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -35,9 +35,10 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, InternalRow}
import org.apache.spark.sql.execution.{UnaryNode, SparkPlan}
import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc}
import org.apache.spark.sql.hive._
-import org.apache.spark.{SerializableWritable, SparkException, TaskContext}
+import org.apache.spark.{SparkException, TaskContext}
import scala.collection.JavaConversions._
+import org.apache.spark.util.SerializableJobConf
private[hive]
case class InsertIntoHiveTable(
@@ -64,7 +65,7 @@ case class InsertIntoHiveTable(
rdd: RDD[InternalRow],
valueClass: Class[_],
fileSinkConf: FileSinkDesc,
- conf: SerializableWritable[JobConf],
+ conf: SerializableJobConf,
writerContainer: SparkHiveWriterContainer): Unit = {
assert(valueClass != null, "Output value class not set")
conf.value.setOutputValueClass(valueClass)
@@ -172,7 +173,7 @@ case class InsertIntoHiveTable(
}
val jobConf = new JobConf(sc.hiveconf)
- val jobConfSer = new SerializableWritable(jobConf)
+ val jobConfSer = new SerializableJobConf(jobConf)
val writerContainer = if (numDynamicPartitions > 0) {
val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
index ee440e304ec19..0bc69c00c241c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
@@ -37,6 +37,7 @@ import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter}
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc}
import org.apache.spark.sql.types._
+import org.apache.spark.util.SerializableJobConf
/**
* Internal helper class that saves an RDD using a Hive OutputFormat.
@@ -57,7 +58,7 @@ private[hive] class SparkHiveWriterContainer(
PlanUtils.configureOutputJobPropertiesForStorageHandler(tableDesc)
Utilities.copyTableJobPropertiesToConf(tableDesc, jobConf)
}
- protected val conf = new SerializableWritable(jobConf)
+ protected val conf = new SerializableJobConf(jobConf)
private var jobID = 0
private var splitID = 0
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
index f03c4cd54e7e6..dbce39f21d271 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
@@ -39,7 +39,8 @@ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreType
import org.apache.spark.sql.sources.{Filter, _}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Row, SQLContext}
-import org.apache.spark.{Logging, SerializableWritable}
+import org.apache.spark.{Logging}
+import org.apache.spark.util.SerializableConfiguration
/* Implicit conversions */
import scala.collection.JavaConversions._
@@ -110,7 +111,7 @@ private[orc] class OrcOutputWriter(
new OrcOutputFormat().getRecordWriter(
new Path(path, filename).getFileSystem(conf),
conf.asInstanceOf[JobConf],
- new Path(path, filename).toUri.getPath,
+ new Path(path, filename).toString,
Reporter.NULL
).asInstanceOf[RecordWriter[NullWritable, Writable]]
}
@@ -283,7 +284,7 @@ private[orc] case class OrcTableScan(
classOf[Writable]
).asInstanceOf[HadoopRDD[NullWritable, Writable]]
- val wrappedConf = new SerializableWritable(conf)
+ val wrappedConf = new SerializableConfiguration(conf)
rdd.mapPartitionsWithInputSplit { case (split: OrcSplit, iterator) =>
val mutableRow = new SpecificMutableRow(attributes.map(_.dataType))
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 92155096202b3..f901bd8171508 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -112,12 +112,11 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
protected[hive] class SQLSession extends super.SQLSession {
/** Fewer partitions to speed up testing. */
protected[sql] override lazy val conf: SQLConf = new SQLConf {
- override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt
+ override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, 5)
// TODO as in unit test, conf.clear() probably be called, all of the value will be cleared.
// The super.getConf(SQLConf.DIALECT) is "sql" by default, we need to set it as "hiveql"
override def dialect: String = super.getConf(SQLConf.DIALECT, "hiveql")
- override def caseSensitiveAnalysis: Boolean =
- getConf(SQLConf.CASE_SENSITIVE, "false").toBoolean
+ override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false)
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
index a0d80dc39c108..af68615e8e9d6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
@@ -81,11 +81,11 @@ class HiveParquetSuite extends QueryTest with ParquetTest {
}
}
- withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") {
+ withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API.key -> "true") {
run("Parquet data source enabled")
}
- withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "false") {
+ withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API.key -> "false") {
run("Parquet data source disabled")
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index 79a85b24d2f60..cc294bc3e8bc3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -456,7 +456,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA
withTable("savedJsonTable") {
val df = (1 to 10).map(i => i -> s"str$i").toDF("a", "b")
- withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "json") {
+ withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "json") {
// Save the df as a managed table (by not specifying the path).
df.write.saveAsTable("savedJsonTable")
@@ -484,7 +484,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA
}
// Create an external table by specifying the path.
- withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "not a source name") {
+ withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "not a source name") {
df.write
.format("org.apache.spark.sql.json")
.mode(SaveMode.Append)
@@ -508,7 +508,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA
s"""{ "a": $i, "b": "str$i" }"""
}))
- withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "not a source name") {
+ withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "not a source name") {
df.write
.format("json")
.mode(SaveMode.Append)
@@ -516,7 +516,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA
.saveAsTable("savedJsonTable")
}
- withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "json") {
+ withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "json") {
createExternalTable("createdJsonTable", tempPath.toString)
assert(table("createdJsonTable").schema === df.schema)
checkAnswer(sql("SELECT * FROM createdJsonTable"), df)
@@ -533,7 +533,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA
checkAnswer(read.json(tempPath.toString), df)
// Try to specify the schema.
- withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "not a source name") {
+ withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "not a source name") {
val schema = StructType(StructField("b", StringType, true) :: Nil)
createExternalTable(
"createdJsonTable",
@@ -563,8 +563,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA
test("scan a parquet table created through a CTAS statement") {
withSQLConf(
- "spark.sql.hive.convertMetastoreParquet" -> "true",
- SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") {
+ HiveContext.CONVERT_METASTORE_PARQUET.key -> "true",
+ SQLConf.PARQUET_USE_DATA_SOURCE_API.key -> "true") {
withTempTable("jt") {
(1 to 10).map(i => i -> s"str$i").toDF("a", "b").registerTempTable("jt")
@@ -706,7 +706,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA
}
test("SPARK-6024 wide schema support") {
- withSQLConf(SQLConf.SCHEMA_STRING_LENGTH_THRESHOLD -> "4000") {
+ withSQLConf(SQLConf.SCHEMA_STRING_LENGTH_THRESHOLD.key -> "4000") {
withTable("wide_schema") {
// We will need 80 splits for this schema if the threshold is 4000.
val schema = StructType((1 to 5000).map(i => StructField(s"c_$i", StringType, true)))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 78c94e6490e36..f067ea0d4fc75 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -167,7 +167,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
ctx.conf.settings.synchronized {
val tmp = ctx.conf.autoBroadcastJoinThreshold
- sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""")
+ sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1""")
df = sql(query)
bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j }
assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off")
@@ -176,7 +176,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
assert(shj.size === 1,
"ShuffledHashJoin should be planned when BroadcastHashJoin is turned off")
- sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""")
+ sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp""")
}
after()
@@ -225,7 +225,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
ctx.conf.settings.synchronized {
val tmp = ctx.conf.autoBroadcastJoinThreshold
- sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1")
+ sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1")
df = sql(leftSemiJoinQuery)
bhj = df.queryExecution.sparkPlan.collect {
case j: BroadcastLeftSemiJoinHash => j
@@ -238,7 +238,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
assert(shj.size === 1,
"LeftSemiJoinHash should be planned when BroadcastHashJoin is turned off")
- sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp")
+ sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp")
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 6d8d99ebc8164..51dabc67fa7c1 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -1084,14 +1084,16 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
val testKey = "spark.sql.key.usedfortestonly"
val testVal = "test.val.0"
val nonexistentKey = "nonexistent"
- val KV = "([^=]+)=([^=]*)".r
- def collectResults(df: DataFrame): Set[(String, String)] =
+ def collectResults(df: DataFrame): Set[Any] =
df.collect().map {
case Row(key: String, value: String) => key -> value
- case Row(KV(key, value)) => key -> value
+ case Row(key: String, defaultValue: String, doc: String) => (key, defaultValue, doc)
}.toSet
conf.clear()
+ val expectedConfs = conf.getAllDefinedConfs.toSet
+ assertResult(expectedConfs)(collectResults(sql("SET -v")))
+
// "SET" itself returns all config variables currently specified in SQLConf.
// TODO: Should we be listing the default here always? probably...
assert(sql("SET").collect().size == 0)
@@ -1102,16 +1104,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
assert(hiveconf.get(testKey, "") == testVal)
assertResult(Set(testKey -> testVal))(collectResults(sql("SET")))
- assertResult(Set(testKey -> testVal))(collectResults(sql("SET -v")))
sql(s"SET ${testKey + testKey}=${testVal + testVal}")
assert(hiveconf.get(testKey + testKey, "") == testVal + testVal)
assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) {
collectResults(sql("SET"))
}
- assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) {
- collectResults(sql("SET -v"))
- }
// "SET key"
assertResult(Set(testKey -> testVal)) {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 984d97d27bf54..e1c9926bed524 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.hive.test.TestHive.implicits._
-import org.apache.spark.sql.hive.{HiveQLDialect, MetastoreRelation}
+import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation}
import org.apache.spark.sql.parquet.ParquetRelation2
import org.apache.spark.sql.sources.LogicalRelation
import org.apache.spark.sql.types._
@@ -191,9 +191,9 @@ class SQLQuerySuite extends QueryTest {
}
}
- val originalConf = getConf("spark.sql.hive.convertCTAS", "false")
+ val originalConf = convertCTAS
- setConf("spark.sql.hive.convertCTAS", "true")
+ setConf(HiveContext.CONVERT_CTAS, true)
sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value")
sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value")
@@ -235,7 +235,7 @@ class SQLQuerySuite extends QueryTest {
checkRelation("ctas1", false)
sql("DROP TABLE ctas1")
- setConf("spark.sql.hive.convertCTAS", originalConf)
+ setConf(HiveContext.CONVERT_CTAS, originalConf)
}
test("SQL Dialect Switching") {
@@ -332,7 +332,7 @@ class SQLQuerySuite extends QueryTest {
val origUseParquetDataSource = conf.parquetUseDataSourceApi
try {
- setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
+ setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false)
sql(
"""CREATE TABLE ctas5
| STORED AS parquet AS
@@ -348,7 +348,7 @@ class SQLQuerySuite extends QueryTest {
"MANAGED_TABLE"
)
- val default = getConf("spark.sql.hive.convertMetastoreParquet", "true")
+ val default = convertMetastoreParquet
// use the Hive SerDe for parquet tables
sql("set spark.sql.hive.convertMetastoreParquet = false")
checkAnswer(
@@ -356,7 +356,7 @@ class SQLQuerySuite extends QueryTest {
sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq)
sql(s"set spark.sql.hive.convertMetastoreParquet = $default")
} finally {
- setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource.toString)
+ setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource)
}
}
@@ -603,8 +603,8 @@ class SQLQuerySuite extends QueryTest {
// generates an invalid query plan.
val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}"""))
read.json(rdd).registerTempTable("data")
- val originalConf = getConf("spark.sql.hive.convertCTAS", "false")
- setConf("spark.sql.hive.convertCTAS", "false")
+ val originalConf = convertCTAS
+ setConf(HiveContext.CONVERT_CTAS, false)
sql("CREATE TABLE explodeTest (key bigInt)")
table("explodeTest").queryExecution.analyzed match {
@@ -621,7 +621,7 @@ class SQLQuerySuite extends QueryTest {
sql("DROP TABLE explodeTest")
dropTempTable("data")
- setConf("spark.sql.hive.convertCTAS", originalConf)
+ setConf(HiveContext.CONVERT_CTAS, originalConf)
}
test("sanity test for SPARK-6618") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index 3864349cdbd89..c2e09800933b5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -153,7 +153,7 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest {
val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":[$i, null]}"""))
read.json(rdd2).registerTempTable("jt_array")
- setConf("spark.sql.hive.convertMetastoreParquet", "true")
+ setConf(HiveContext.CONVERT_METASTORE_PARQUET, true)
}
override def afterAll(): Unit = {
@@ -164,7 +164,7 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest {
sql("DROP TABLE normal_parquet")
sql("DROP TABLE IF EXISTS jt")
sql("DROP TABLE IF EXISTS jt_array")
- setConf("spark.sql.hive.convertMetastoreParquet", "false")
+ setConf(HiveContext.CONVERT_METASTORE_PARQUET, false)
}
test(s"conversion is working") {
@@ -199,14 +199,14 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase {
| OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'
""".stripMargin)
- conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
+ conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true)
}
override def afterAll(): Unit = {
super.afterAll()
sql("DROP TABLE IF EXISTS test_parquet")
- setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString)
+ setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf)
}
test("scan an empty parquet table") {
@@ -546,12 +546,12 @@ class ParquetDataSourceOffMetastoreSuite extends ParquetMetastoreSuiteBase {
override def beforeAll(): Unit = {
super.beforeAll()
- conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
+ conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false)
}
override def afterAll(): Unit = {
super.afterAll()
- setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString)
+ setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf)
}
test("MetastoreRelation in InsertIntoTable will not be converted") {
@@ -692,12 +692,12 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase {
override def beforeAll(): Unit = {
super.beforeAll()
- conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
+ conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true)
}
override def afterAll(): Unit = {
super.afterAll()
- setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString)
+ setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf)
}
test("values in arrays and maps stored in parquet are always nullable") {
@@ -750,12 +750,12 @@ class ParquetDataSourceOffSourceSuite extends ParquetSourceSuiteBase {
override def beforeAll(): Unit = {
super.beforeAll()
- conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
+ conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false)
}
override def afterAll(): Unit = {
super.afterAll()
- setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString)
+ setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf)
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
index 6c1fab56740ee..86a8e2beff57c 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
@@ -26,10 +26,9 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path, PathFilter}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
-import org.apache.spark.{SparkConf, SerializableWritable}
import org.apache.spark.rdd.{RDD, UnionRDD}
import org.apache.spark.streaming._
-import org.apache.spark.util.{TimeStampedHashMap, Utils}
+import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Utils}
/**
* This class represents an input stream that monitors a Hadoop-compatible filesystem for new
@@ -78,7 +77,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]](
(implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F])
extends InputDStream[(K, V)](ssc_) {
- private val serializableConfOpt = conf.map(new SerializableWritable(_))
+ private val serializableConfOpt = conf.map(new SerializableConfiguration(_))
/**
* Minimum duration of remembering the information of selected files. Defaults to 60 seconds.
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
index 358e4c66df7ba..71bec96d46c8d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
@@ -24,10 +24,11 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.{JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
-import org.apache.spark.{HashPartitioner, Partitioner, SerializableWritable}
+import org.apache.spark.{HashPartitioner, Partitioner}
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Duration, Time}
import org.apache.spark.streaming.StreamingContext.rddToFileName
+import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf}
/**
* Extra functions available on DStream of (key, value) pairs through an implicit conversion.
@@ -688,7 +689,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)])
conf: JobConf = new JobConf(ssc.sparkContext.hadoopConfiguration)
): Unit = ssc.withScope {
// Wrap conf in SerializableWritable so that ForeachDStream can be serialized for checkpoints
- val serializableConf = new SerializableWritable(conf)
+ val serializableConf = new SerializableJobConf(conf)
val saveFunc = (rdd: RDD[(K, V)], time: Time) => {
val file = rddToFileName(prefix, suffix, time)
rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, serializableConf.value)
@@ -721,7 +722,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)])
conf: Configuration = ssc.sparkContext.hadoopConfiguration
): Unit = ssc.withScope {
// Wrap conf in SerializableWritable so that ForeachDStream can be serialized for checkpoints
- val serializableConf = new SerializableWritable(conf)
+ val serializableConf = new SerializableConfiguration(conf)
val saveFunc = (rdd: RDD[(K, V)], time: Time) => {
val file = rddToFileName(prefix, suffix, time)
rdd.saveAsNewAPIHadoopFile(
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
index ffce6a4c3c74c..31ce8e1ec14d7 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
@@ -23,12 +23,11 @@ import java.util.UUID
import scala.reflect.ClassTag
import scala.util.control.NonFatal
-import org.apache.commons.io.FileUtils
-
import org.apache.spark._
import org.apache.spark.rdd.BlockRDD
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.streaming.util._
+import org.apache.spark.util.SerializableConfiguration
/**
* Partition class for [[org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD]].
@@ -94,7 +93,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
// Hadoop configuration is not serializable, so broadcast it as a serializable.
@transient private val hadoopConfig = sc.hadoopConfiguration
- private val broadcastedHadoopConf = new SerializableWritable(hadoopConfig)
+ private val broadcastedHadoopConf = new SerializableConfiguration(hadoopConfig)
override def isValid(): Boolean = true
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
index 207d64d9414ee..c8dd6e06812dc 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
@@ -32,7 +32,10 @@ import org.apache.spark.{Logging, SparkConf, SparkException}
/** Trait that represents the metadata related to storage of blocks */
private[streaming] trait ReceivedBlockStoreResult {
- def blockId: StreamBlockId // Any implementation of this trait will store a block id
+ // Any implementation of this trait will store a block id
+ def blockId: StreamBlockId
+ // Any implementation of this trait will have to return the number of records
+ def numRecords: Option[Long]
}
/** Trait that represents a class that handles the storage of blocks received by receiver */
@@ -51,7 +54,8 @@ private[streaming] trait ReceivedBlockHandler {
* that stores the metadata related to storage of blocks using
* [[org.apache.spark.streaming.receiver.BlockManagerBasedBlockHandler]]
*/
-private[streaming] case class BlockManagerBasedStoreResult(blockId: StreamBlockId)
+private[streaming] case class BlockManagerBasedStoreResult(
+ blockId: StreamBlockId, numRecords: Option[Long])
extends ReceivedBlockStoreResult
@@ -64,11 +68,20 @@ private[streaming] class BlockManagerBasedBlockHandler(
extends ReceivedBlockHandler with Logging {
def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = {
+
+ var numRecords = None: Option[Long]
+
val putResult: Seq[(BlockId, BlockStatus)] = block match {
case ArrayBufferBlock(arrayBuffer) =>
- blockManager.putIterator(blockId, arrayBuffer.iterator, storageLevel, tellMaster = true)
+ numRecords = Some(arrayBuffer.size.toLong)
+ blockManager.putIterator(blockId, arrayBuffer.iterator, storageLevel,
+ tellMaster = true)
case IteratorBlock(iterator) =>
- blockManager.putIterator(blockId, iterator, storageLevel, tellMaster = true)
+ val countIterator = new CountingIterator(iterator)
+ val putResult = blockManager.putIterator(blockId, countIterator, storageLevel,
+ tellMaster = true)
+ numRecords = countIterator.count
+ putResult
case ByteBufferBlock(byteBuffer) =>
blockManager.putBytes(blockId, byteBuffer, storageLevel, tellMaster = true)
case o =>
@@ -79,7 +92,7 @@ private[streaming] class BlockManagerBasedBlockHandler(
throw new SparkException(
s"Could not store $blockId to block manager with storage level $storageLevel")
}
- BlockManagerBasedStoreResult(blockId)
+ BlockManagerBasedStoreResult(blockId, numRecords)
}
def cleanupOldBlocks(threshTime: Long) {
@@ -96,6 +109,7 @@ private[streaming] class BlockManagerBasedBlockHandler(
*/
private[streaming] case class WriteAheadLogBasedStoreResult(
blockId: StreamBlockId,
+ numRecords: Option[Long],
walRecordHandle: WriteAheadLogRecordHandle
) extends ReceivedBlockStoreResult
@@ -151,12 +165,17 @@ private[streaming] class WriteAheadLogBasedBlockHandler(
*/
def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = {
+ var numRecords = None: Option[Long]
// Serialize the block so that it can be inserted into both
val serializedBlock = block match {
case ArrayBufferBlock(arrayBuffer) =>
+ numRecords = Some(arrayBuffer.size.toLong)
blockManager.dataSerialize(blockId, arrayBuffer.iterator)
case IteratorBlock(iterator) =>
- blockManager.dataSerialize(blockId, iterator)
+ val countIterator = new CountingIterator(iterator)
+ val serializedBlock = blockManager.dataSerialize(blockId, countIterator)
+ numRecords = countIterator.count
+ serializedBlock
case ByteBufferBlock(byteBuffer) =>
byteBuffer
case _ =>
@@ -181,7 +200,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler(
// Combine the futures, wait for both to complete, and return the write ahead log record handle
val combinedFuture = storeInBlockManagerFuture.zip(storeInWriteAheadLogFuture).map(_._2)
val walRecordHandle = Await.result(combinedFuture, blockStoreTimeout)
- WriteAheadLogBasedStoreResult(blockId, walRecordHandle)
+ WriteAheadLogBasedStoreResult(blockId, numRecords, walRecordHandle)
}
def cleanupOldBlocks(threshTime: Long) {
@@ -199,3 +218,23 @@ private[streaming] object WriteAheadLogBasedBlockHandler {
new Path(checkpointDir, new Path("receivedData", streamId.toString)).toString
}
}
+
+/**
+ * A utility that will wrap the Iterator to get the count
+ */
+private class CountingIterator[T](iterator: Iterator[T]) extends Iterator[T] {
+ private var _count = 0
+
+ private def isFullyConsumed: Boolean = !iterator.hasNext
+
+ def hasNext(): Boolean = iterator.hasNext
+
+ def count(): Option[Long] = {
+ if (isFullyConsumed) Some(_count) else None
+ }
+
+ def next(): T = {
+ _count += 1
+ iterator.next()
+ }
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
index 8be732b64e3a3..6078cdf8f8790 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
@@ -137,15 +137,10 @@ private[streaming] class ReceiverSupervisorImpl(
blockIdOption: Option[StreamBlockId]
) {
val blockId = blockIdOption.getOrElse(nextBlockId)
- val numRecords = receivedBlock match {
- case ArrayBufferBlock(arrayBuffer) => Some(arrayBuffer.size.toLong)
- case _ => None
- }
-
val time = System.currentTimeMillis
val blockStoreResult = receivedBlockHandler.storeBlock(blockId, receivedBlock)
logDebug(s"Pushed block $blockId in ${(System.currentTimeMillis - time)} ms")
-
+ val numRecords = blockStoreResult.numRecords
val blockInfo = ReceivedBlockInfo(streamId, numRecords, metadataOption, blockStoreResult)
trackerEndpoint.askWithRetry[Boolean](AddBlock(blockInfo))
logDebug(s"Reported block $blockId")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index f1504b09c9873..e6cdbec11e94c 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -21,10 +21,12 @@ import scala.collection.mutable.{HashMap, SynchronizedMap}
import scala.language.existentials
import org.apache.spark.streaming.util.WriteAheadLogUtils
-import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException}
+import org.apache.spark.{Logging, SparkEnv, SparkException}
import org.apache.spark.rpc._
import org.apache.spark.streaming.{StreamingContext, Time}
-import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, StopReceiver}
+import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl,
+ StopReceiver}
+import org.apache.spark.util.SerializableConfiguration
/**
* Messages used by the NetworkReceiver and the ReceiverTracker to communicate
@@ -294,7 +296,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
}
val checkpointDirOption = Option(ssc.checkpointDir)
- val serializableHadoopConf = new SerializableWritable(ssc.sparkContext.hadoopConfiguration)
+ val serializableHadoopConf =
+ new SerializableConfiguration(ssc.sparkContext.hadoopConfiguration)
// Function to start the receiver on the worker node
val startReceiver = (iterator: Iterator[Receiver[_]]) => {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index cca8cedb1d080..6c0c926755c20 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -49,7 +49,6 @@ class ReceivedBlockHandlerSuite
val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1")
val hadoopConf = new Configuration()
- val storageLevel = StorageLevel.MEMORY_ONLY_SER
val streamId = 1
val securityMgr = new SecurityManager(conf)
val mapOutputTracker = new MapOutputTrackerMaster(conf)
@@ -57,10 +56,12 @@ class ReceivedBlockHandlerSuite
val serializer = new KryoSerializer(conf)
val manualClock = new ManualClock
val blockManagerSize = 10000000
+ val blockManagerBuffer = new ArrayBuffer[BlockManager]()
var rpcEnv: RpcEnv = null
var blockManagerMaster: BlockManagerMaster = null
var blockManager: BlockManager = null
+ var storageLevel: StorageLevel = null
var tempDirectory: File = null
before {
@@ -70,20 +71,21 @@ class ReceivedBlockHandlerSuite
blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager",
new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true)
- blockManager = new BlockManager("bm", rpcEnv, blockManagerMaster, serializer,
- blockManagerSize, conf, mapOutputTracker, shuffleManager,
- new NioBlockTransferService(conf, securityMgr), securityMgr, 0)
- blockManager.initialize("app-id")
+ storageLevel = StorageLevel.MEMORY_ONLY_SER
+ blockManager = createBlockManager(blockManagerSize, conf)
tempDirectory = Utils.createTempDir()
manualClock.setTime(0)
}
after {
- if (blockManager != null) {
- blockManager.stop()
- blockManager = null
+ for ( blockManager <- blockManagerBuffer ) {
+ if (blockManager != null) {
+ blockManager.stop()
+ }
}
+ blockManager = null
+ blockManagerBuffer.clear()
if (blockManagerMaster != null) {
blockManagerMaster.stop()
blockManagerMaster = null
@@ -174,6 +176,130 @@ class ReceivedBlockHandlerSuite
}
}
+ test("Test Block - count messages") {
+ // Test count with BlockManagerBasedBlockHandler
+ testCountWithBlockManagerBasedBlockHandler(true)
+ // Test count with WriteAheadLogBasedBlockHandler
+ testCountWithBlockManagerBasedBlockHandler(false)
+ }
+
+ test("Test Block - isFullyConsumed") {
+ val sparkConf = new SparkConf()
+ sparkConf.set("spark.storage.unrollMemoryThreshold", "512")
+ // spark.storage.unrollFraction set to 0.4 for BlockManager
+ sparkConf.set("spark.storage.unrollFraction", "0.4")
+ // Block Manager with 12000 * 0.4 = 4800 bytes of free space for unroll
+ blockManager = createBlockManager(12000, sparkConf)
+
+ // there is not enough space to store this block in MEMORY,
+ // But BlockManager will be able to serialize this block to WAL
+ // and hence count returns correct value.
+ testRecordcount(false, StorageLevel.MEMORY_ONLY,
+ IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator), blockManager, Some(70))
+
+ // there is not enough space to store this block in MEMORY,
+ // But BlockManager will be able to serialize this block to DISK
+ // and hence count returns correct value.
+ testRecordcount(true, StorageLevel.MEMORY_AND_DISK,
+ IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator), blockManager, Some(70))
+
+ // there is not enough space to store this block With MEMORY_ONLY StorageLevel.
+ // BlockManager will not be able to unroll this block
+ // and hence it will not tryToPut this block, resulting the SparkException
+ storageLevel = StorageLevel.MEMORY_ONLY
+ withBlockManagerBasedBlockHandler { handler =>
+ val thrown = intercept[SparkException] {
+ storeSingleBlock(handler, IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator))
+ }
+ }
+ }
+
+ private def testCountWithBlockManagerBasedBlockHandler(isBlockManagerBasedBlockHandler: Boolean) {
+ // ByteBufferBlock-MEMORY_ONLY
+ testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY,
+ ByteBufferBlock(ByteBuffer.wrap(Array.tabulate(100)(i => i.toByte))), blockManager, None)
+ // ByteBufferBlock-MEMORY_ONLY_SER
+ testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER,
+ ByteBufferBlock(ByteBuffer.wrap(Array.tabulate(100)(i => i.toByte))), blockManager, None)
+ // ArrayBufferBlock-MEMORY_ONLY
+ testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY,
+ ArrayBufferBlock(ArrayBuffer.fill(25)(0)), blockManager, Some(25))
+ // ArrayBufferBlock-MEMORY_ONLY_SER
+ testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER,
+ ArrayBufferBlock(ArrayBuffer.fill(25)(0)), blockManager, Some(25))
+ // ArrayBufferBlock-DISK_ONLY
+ testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.DISK_ONLY,
+ ArrayBufferBlock(ArrayBuffer.fill(50)(0)), blockManager, Some(50))
+ // ArrayBufferBlock-MEMORY_AND_DISK
+ testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_AND_DISK,
+ ArrayBufferBlock(ArrayBuffer.fill(75)(0)), blockManager, Some(75))
+ // IteratorBlock-MEMORY_ONLY
+ testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY,
+ IteratorBlock((ArrayBuffer.fill(100)(0)).iterator), blockManager, Some(100))
+ // IteratorBlock-MEMORY_ONLY_SER
+ testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER,
+ IteratorBlock((ArrayBuffer.fill(100)(0)).iterator), blockManager, Some(100))
+ // IteratorBlock-DISK_ONLY
+ testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.DISK_ONLY,
+ IteratorBlock((ArrayBuffer.fill(125)(0)).iterator), blockManager, Some(125))
+ // IteratorBlock-MEMORY_AND_DISK
+ testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_AND_DISK,
+ IteratorBlock((ArrayBuffer.fill(150)(0)).iterator), blockManager, Some(150))
+ }
+
+ private def createBlockManager(
+ maxMem: Long,
+ conf: SparkConf,
+ name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
+ val transfer = new NioBlockTransferService(conf, securityMgr)
+ val manager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, maxMem, conf,
+ mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
+ manager.initialize("app-id")
+ blockManagerBuffer += manager
+ manager
+ }
+
+ /**
+ * Test storing of data using different types of Handler, StorageLevel and ReceivedBlocks
+ * and verify the correct record count
+ */
+ private def testRecordcount(isBlockManagedBasedBlockHandler: Boolean,
+ sLevel: StorageLevel,
+ receivedBlock: ReceivedBlock,
+ bManager: BlockManager,
+ expectedNumRecords: Option[Long]
+ ) {
+ blockManager = bManager
+ storageLevel = sLevel
+ var bId: StreamBlockId = null
+ try {
+ if (isBlockManagedBasedBlockHandler) {
+ // test received block with BlockManager based handler
+ withBlockManagerBasedBlockHandler { handler =>
+ val (blockId, blockStoreResult) = storeSingleBlock(handler, receivedBlock)
+ bId = blockId
+ assert(blockStoreResult.numRecords === expectedNumRecords,
+ "Message count not matches for a " +
+ receivedBlock.getClass.getName +
+ " being inserted using BlockManagerBasedBlockHandler with " + sLevel)
+ }
+ } else {
+ // test received block with WAL based handler
+ withWriteAheadLogBasedBlockHandler { handler =>
+ val (blockId, blockStoreResult) = storeSingleBlock(handler, receivedBlock)
+ bId = blockId
+ assert(blockStoreResult.numRecords === expectedNumRecords,
+ "Message count not matches for a " +
+ receivedBlock.getClass.getName +
+ " being inserted using WriteAheadLogBasedBlockHandler with " + sLevel)
+ }
+ }
+ } finally {
+ // Remove the block so the same blockManager can be reused by the next test
+ blockManager.removeBlock(bId, true)
+ }
+ }
+
/**
* Test storing of data using different forms of ReceivedBlocks and verify that they succeeded
* using the given verification function
@@ -251,9 +377,21 @@ class ReceivedBlockHandlerSuite
(blockIds, storeResults)
}
+ /** Store single block using a handler */
+ private def storeSingleBlock(
+ handler: ReceivedBlockHandler,
+ block: ReceivedBlock
+ ): (StreamBlockId, ReceivedBlockStoreResult) = {
+ val blockId = generateBlockId
+ val blockStoreResult = handler.storeBlock(blockId, block)
+ logDebug("Done inserting")
+ (blockId, blockStoreResult)
+ }
+
private def getWriteAheadLogFiles(): Seq[String] = {
getLogFilesInDirectory(checkpointDirToLogDir(tempDirectory.toString, streamId))
}
private def generateBlockId(): StreamBlockId = StreamBlockId(streamId, scala.util.Random.nextLong)
}
+
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
index be305b5e0dfea..f793a12843b2f 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
@@ -225,7 +225,7 @@ class ReceivedBlockTrackerSuite
/** Generate blocks infos using random ids */
def generateBlockInfos(): Seq[ReceivedBlockInfo] = {
List.fill(5)(ReceivedBlockInfo(streamId, Some(0L), None,
- BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt)))))
+ BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt)), Some(0L))))
}
/** Get all the data written in the given write ahead log file. */
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala
index 1ace1a97d5156..33f580aaebdc0 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala
@@ -115,8 +115,9 @@ private[spark] class YarnClusterSchedulerBackend(
val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://"
val baseUrl = s"$httpScheme$httpAddress/node/containerlogs/$containerId/$user"
logDebug(s"Base URL for logs: $baseUrl")
- driverLogs = Some(
- Map("stderr" -> s"$baseUrl/stderr?start=0", "stdout" -> s"$baseUrl/stdout?start=0"))
+ driverLogs = Some(Map(
+ "stderr" -> s"$baseUrl/stderr?start=-4096",
+ "stdout" -> s"$baseUrl/stdout?start=-4096"))
}
}
} catch {
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index a0f25ba450068..335e966519c7c 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -376,7 +376,7 @@ private object YarnClusterDriver extends Logging with Matchers {
new URL(urlStr)
val containerId = YarnSparkHadoopUtil.get.getContainerId
val user = Utils.getCurrentUserName()
- assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=0"))
+ assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=-4096"))
}
}
|